[clang] [Serialization] Load Specializations Lazily (PR #76774)

Chuanqi Xu via cfe-commits cfe-commits at lists.llvm.org
Mon Jan 22 01:45:44 PST 2024


https://github.com/ChuanqiXu9 updated https://github.com/llvm/llvm-project/pull/76774

>From 26bd7dc5139811b2b0d8d8642a67b67340eeb1d5 Mon Sep 17 00:00:00 2001
From: Chuanqi Xu <yedeng.yd at linux.alibaba.com>
Date: Wed, 3 Jan 2024 11:33:17 +0800
Subject: [PATCH 1/4] Load Specializations Lazily

---
 clang/include/clang/AST/DeclTemplate.h        |  35 ++--
 clang/include/clang/AST/ExternalASTSource.h   |   6 +
 clang/include/clang/AST/ODRHash.h             |   3 +
 .../clang/Sema/MultiplexExternalSemaSource.h  |   6 +
 .../include/clang/Serialization/ASTBitCodes.h |   3 +
 clang/include/clang/Serialization/ASTReader.h |  28 +++
 clang/include/clang/Serialization/ASTWriter.h |   6 +
 clang/lib/AST/DeclTemplate.cpp                |  67 +++++---
 clang/lib/AST/ExternalASTSource.cpp           |   5 +
 clang/lib/AST/ODRHash.cpp                     |   3 +
 .../lib/Sema/MultiplexExternalSemaSource.cpp  |   6 +
 clang/lib/Serialization/ASTReader.cpp         | 125 +++++++++++++-
 clang/lib/Serialization/ASTReaderDecl.cpp     |  35 +++-
 clang/lib/Serialization/ASTReaderInternals.h  |  80 +++++++++
 clang/lib/Serialization/ASTWriter.cpp         | 144 +++++++++++++++-
 clang/lib/Serialization/ASTWriterDecl.cpp     |  78 ++++++---
 clang/test/Modules/odr_hash.cpp               |   4 +-
 clang/unittests/Serialization/CMakeLists.txt  |   1 +
 .../Serialization/LoadSpecLazily.cpp          | 159 ++++++++++++++++++
 19 files changed, 729 insertions(+), 65 deletions(-)
 create mode 100644 clang/unittests/Serialization/LoadSpecLazily.cpp

diff --git a/clang/include/clang/AST/DeclTemplate.h b/clang/include/clang/AST/DeclTemplate.h
index 832ad2de6b08a8..4699dd17bc182c 100644
--- a/clang/include/clang/AST/DeclTemplate.h
+++ b/clang/include/clang/AST/DeclTemplate.h
@@ -525,8 +525,11 @@ class FunctionTemplateSpecializationInfo final
     return Function.getInt();
   }
 
+  void loadExternalRedecls();
+
 public:
   friend TrailingObjects;
+  friend class ASTReader;
 
   static FunctionTemplateSpecializationInfo *
   Create(ASTContext &C, FunctionDecl *FD, FunctionTemplateDecl *Template,
@@ -789,13 +792,15 @@ class RedeclarableTemplateDecl : public TemplateDecl,
     return SpecIterator<EntryType>(isEnd ? Specs.end() : Specs.begin());
   }
 
-  void loadLazySpecializationsImpl() const;
+  void loadExternalSpecializations() const;
 
   template <class EntryType, typename ...ProfileArguments>
   typename SpecEntryTraits<EntryType>::DeclType*
   findSpecializationImpl(llvm::FoldingSetVector<EntryType> &Specs,
                          void *&InsertPos, ProfileArguments &&...ProfileArgs);
 
+  void loadLazySpecializationsWithArgs(ArrayRef<TemplateArgument> TemplateArgs);
+
   template <class Derived, class EntryType>
   void addSpecializationImpl(llvm::FoldingSetVector<EntryType> &Specs,
                              EntryType *Entry, void *InsertPos);
@@ -814,9 +819,13 @@ class RedeclarableTemplateDecl : public TemplateDecl,
     /// If non-null, points to an array of specializations (including
     /// partial specializations) known only by their external declaration IDs.
     ///
+    /// These specializations need to be loaded at once in
+    /// loadExternalSpecializations to complete the redecl chain or to prepare
+    /// for template resolution.
+    ///
     /// The first value in the array is the number of specializations/partial
     /// specializations that follow.
-    uint32_t *LazySpecializations = nullptr;
+    uint32_t *ExternalSpecializations = nullptr;
 
     /// The set of "injected" template arguments used within this
     /// template.
@@ -850,6 +859,8 @@ class RedeclarableTemplateDecl : public TemplateDecl,
   friend class ASTDeclWriter;
   friend class ASTReader;
   template <class decl_type> friend class RedeclarableTemplate;
+  friend class ClassTemplateSpecializationDecl;
+  friend class VarTemplateSpecializationDecl;
 
   /// Retrieves the canonical declaration of this template.
   RedeclarableTemplateDecl *getCanonicalDecl() override {
@@ -977,6 +988,7 @@ SpecEntryTraits<FunctionTemplateSpecializationInfo> {
 class FunctionTemplateDecl : public RedeclarableTemplateDecl {
 protected:
   friend class FunctionDecl;
+  friend class FunctionTemplateSpecializationInfo;
 
   /// Data that is common to all of the declarations of a given
   /// function template.
@@ -1012,13 +1024,13 @@ class FunctionTemplateDecl : public RedeclarableTemplateDecl {
   void addSpecialization(FunctionTemplateSpecializationInfo* Info,
                          void *InsertPos);
 
+  /// Load any lazily-loaded specializations from the external source.
+  void LoadLazySpecializations() const;
+
 public:
   friend class ASTDeclReader;
   friend class ASTDeclWriter;
 
-  /// Load any lazily-loaded specializations from the external source.
-  void LoadLazySpecializations() const;
-
   /// Get the underlying function declaration of the template.
   FunctionDecl *getTemplatedDecl() const {
     return static_cast<FunctionDecl *>(TemplatedDecl);
@@ -1839,6 +1851,8 @@ class ClassTemplateSpecializationDecl
   LLVM_PREFERRED_TYPE(TemplateSpecializationKind)
   unsigned SpecializationKind : 3;
 
+  void loadExternalRedecls();
+
 protected:
   ClassTemplateSpecializationDecl(ASTContext &Context, Kind DK, TagKind TK,
                                   DeclContext *DC, SourceLocation StartLoc,
@@ -1852,6 +1866,7 @@ class ClassTemplateSpecializationDecl
 public:
   friend class ASTDeclReader;
   friend class ASTDeclWriter;
+  friend class ASTReader;
 
   static ClassTemplateSpecializationDecl *
   Create(ASTContext &Context, TagKind TK, DeclContext *DC,
@@ -2285,9 +2300,7 @@ class ClassTemplateDecl : public RedeclarableTemplateDecl {
   friend class ASTDeclReader;
   friend class ASTDeclWriter;
   friend class TemplateDeclInstantiator;
-
-  /// Load any lazily-loaded specializations from the external source.
-  void LoadLazySpecializations() const;
+  friend class ClassTemplateSpecializationDecl;
 
   /// Get the underlying class declarations of the template.
   CXXRecordDecl *getTemplatedDecl() const {
@@ -2651,6 +2664,8 @@ class VarTemplateSpecializationDecl : public VarDecl,
   LLVM_PREFERRED_TYPE(bool)
   unsigned IsCompleteDefinition : 1;
 
+  void loadExternalRedecls();
+
 protected:
   VarTemplateSpecializationDecl(Kind DK, ASTContext &Context, DeclContext *DC,
                                 SourceLocation StartLoc, SourceLocation IdLoc,
@@ -2664,6 +2679,7 @@ class VarTemplateSpecializationDecl : public VarDecl,
 public:
   friend class ASTDeclReader;
   friend class ASTDeclWriter;
+  friend class ASTReader;
   friend class VarDecl;
 
   static VarTemplateSpecializationDecl *
@@ -3057,8 +3073,7 @@ class VarTemplateDecl : public RedeclarableTemplateDecl {
   friend class ASTDeclReader;
   friend class ASTDeclWriter;
 
-  /// Load any lazily-loaded specializations from the external source.
-  void LoadLazySpecializations() const;
+  friend class VarTemplatePartialSpecializationDecl;
 
   /// Get the underlying variable declarations of the template.
   VarDecl *getTemplatedDecl() const {
diff --git a/clang/include/clang/AST/ExternalASTSource.h b/clang/include/clang/AST/ExternalASTSource.h
index 8e573965b0a336..75ee6f827080d4 100644
--- a/clang/include/clang/AST/ExternalASTSource.h
+++ b/clang/include/clang/AST/ExternalASTSource.h
@@ -150,6 +150,12 @@ class ExternalASTSource : public RefCountedBase<ExternalASTSource> {
   virtual bool
   FindExternalVisibleDeclsByName(const DeclContext *DC, DeclarationName Name);
 
+  /// Load all the external specializations for the Decl and the corresponding
+  /// template arguments.
+  virtual void
+  LoadExternalSpecializations(const Decl *D,
+                              ArrayRef<TemplateArgument> TemplateArgs);
+
   /// Ensures that the table of all visible declarations inside this
   /// context is up to date.
   ///
diff --git a/clang/include/clang/AST/ODRHash.h b/clang/include/clang/AST/ODRHash.h
index a1caa6d39a87c3..d972be7f37c88b 100644
--- a/clang/include/clang/AST/ODRHash.h
+++ b/clang/include/clang/AST/ODRHash.h
@@ -104,6 +104,9 @@ class ODRHash {
 
   void AddStructuralValue(const APValue &);
 
+  // Add integers to ID.
+  void AddInteger(unsigned Value);
+
   static bool isSubDeclToBeProcessed(const Decl *D, const DeclContext *Parent);
 
 private:
diff --git a/clang/include/clang/Sema/MultiplexExternalSemaSource.h b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
index 2bf91cb5212c5e..9345e96e8cb8c3 100644
--- a/clang/include/clang/Sema/MultiplexExternalSemaSource.h
+++ b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
@@ -97,6 +97,12 @@ class MultiplexExternalSemaSource : public ExternalSemaSource {
   bool FindExternalVisibleDeclsByName(const DeclContext *DC,
                                       DeclarationName Name) override;
 
+  /// Load all the external specializations for the Decl and the corresponding
+  /// template args.
+  virtual void
+  LoadExternalSpecializations(const Decl *D,
+                              ArrayRef<TemplateArgument> TemplateArgs) override;
+
   /// Ensures that the table of all visible declarations inside this
   /// context is up to date.
   void completeVisibleDeclsMap(const DeclContext *DC) override;
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index fdd64f2abbe937..23a279de96ab15 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1523,6 +1523,9 @@ enum DeclCode {
   /// An ImplicitConceptSpecializationDecl record.
   DECL_IMPLICIT_CONCEPT_SPECIALIZATION,
 
+  // A record of a decl's specializations.
+  DECL_SPECIALIZATIONS,
+
   DECL_LAST = DECL_IMPLICIT_CONCEPT_SPECIALIZATION
 };
 
diff --git a/clang/include/clang/Serialization/ASTReader.h b/clang/include/clang/Serialization/ASTReader.h
index 21d791f5cd89a2..293d6495d164ef 100644
--- a/clang/include/clang/Serialization/ASTReader.h
+++ b/clang/include/clang/Serialization/ASTReader.h
@@ -340,6 +340,9 @@ class ASTIdentifierLookupTrait;
 /// The on-disk hash table(s) used for DeclContext name lookup.
 struct DeclContextLookupTable;
 
+/// The on-disk hash table(s) used for specialization decls.
+struct SpecializationsLookupTable;
+
 } // namespace reader
 
 } // namespace serialization
@@ -599,6 +602,11 @@ class ASTReader
   llvm::DenseMap<const DeclContext *,
                  serialization::reader::DeclContextLookupTable> Lookups;
 
+  /// Map from decls to specialized decls.
+  llvm::DenseMap<const Decl *,
+                 serialization::reader::SpecializationsLookupTable>
+      SpecializationsLookups;
+
   // Updates for visible decls can occur for other contexts than just the
   // TU, and when we read those update records, the actual context may not
   // be available yet, so have this pending map using the ID as a key. It
@@ -640,6 +648,9 @@ class ASTReader
                                      llvm::BitstreamCursor &Cursor,
                                      uint64_t Offset, serialization::DeclID ID);
 
+  bool ReadSpecializations(ModuleFile &M, llvm::BitstreamCursor &Cursor,
+                           uint64_t Offset, Decl *D);
+
   /// A vector containing identifiers that have already been
   /// loaded.
   ///
@@ -1343,6 +1354,11 @@ class ASTReader
   const serialization::reader::DeclContextLookupTable *
   getLoadedLookupTables(DeclContext *Primary) const;
 
+  /// Get the loaded specializations lookup tables for \p D,
+  /// if any.
+  serialization::reader::SpecializationsLookupTable *
+  getLoadedSpecializationsLookupTables(const Decl *D);
+
 private:
   struct ImportedModule {
     ModuleFile *Mod;
@@ -1982,6 +1998,10 @@ class ASTReader
   bool FindExternalVisibleDeclsByName(const DeclContext *DC,
                                       DeclarationName Name) override;
 
+  void
+  LoadExternalSpecializations(const Decl *D,
+                              ArrayRef<TemplateArgument> TemplateArgs) override;
+
   /// Read all of the declarations lexically stored in a
   /// declaration context.
   ///
@@ -2404,6 +2424,14 @@ class ASTReader
   bool isProcessingUpdateRecords() { return ProcessingUpdateRecords; }
 };
 
+/// Get a stable hash for template arguments across compiler invocations.
+/// This is used for loading corresponding specializations lazily according
+/// to the template arguments.
+/// Given it is fine to load additional specializations, we tolerate
+/// mapping different template arguments to the same hash value as long as
+/// semantically identical template arguments get the same hash value.
+unsigned GetTemplateArgsStableHash(ArrayRef<TemplateArgument> TemplateArgs);
+
 /// A simple helper class to unpack an integer to bits and consuming
 /// the bits in order.
 class BitsUnpacker {
diff --git a/clang/include/clang/Serialization/ASTWriter.h b/clang/include/clang/Serialization/ASTWriter.h
index de69f99003d827..09806b87590766 100644
--- a/clang/include/clang/Serialization/ASTWriter.h
+++ b/clang/include/clang/Serialization/ASTWriter.h
@@ -527,6 +527,10 @@ class ASTWriter : public ASTDeserializationListener,
   bool isLookupResultExternal(StoredDeclsList &Result, DeclContext *DC);
   bool isLookupResultEntirelyExternal(StoredDeclsList &Result, DeclContext *DC);
 
+  uint64_t WriteSpecializationsLookupTable(
+      const NamedDecl *D,
+      llvm::SmallVectorImpl<const NamedDecl *> &Specializations);
+
   void GenerateNameLookupTable(const DeclContext *DC,
                                llvm::SmallVectorImpl<char> &LookupTable);
   uint64_t WriteDeclContextLexicalBlock(ASTContext &Context, DeclContext *DC);
@@ -564,6 +568,8 @@ class ASTWriter : public ASTDeserializationListener,
   unsigned DeclEnumAbbrev = 0;
   unsigned DeclObjCIvarAbbrev = 0;
   unsigned DeclCXXMethodAbbrev = 0;
+  unsigned DeclSpecializationsAbbrev = 0;
+
   unsigned DeclDependentNonTemplateCXXMethodAbbrev = 0;
   unsigned DeclTemplateCXXMethodAbbrev = 0;
   unsigned DeclMemberSpecializedCXXMethodAbbrev = 0;
diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp
index 7d7556e670f951..16b396a5e785d7 100644
--- a/clang/lib/AST/DeclTemplate.cpp
+++ b/clang/lib/AST/DeclTemplate.cpp
@@ -331,14 +331,14 @@ RedeclarableTemplateDecl::CommonBase *RedeclarableTemplateDecl::getCommonPtr() c
   return Common;
 }
 
-void RedeclarableTemplateDecl::loadLazySpecializationsImpl() const {
+void RedeclarableTemplateDecl::loadExternalSpecializations() const {
   // Grab the most recent declaration to ensure we've loaded any lazy
   // redeclarations of this template.
   CommonBase *CommonBasePtr = getMostRecentDecl()->getCommonPtr();
-  if (CommonBasePtr->LazySpecializations) {
+  if (CommonBasePtr->ExternalSpecializations) {
     ASTContext &Context = getASTContext();
-    uint32_t *Specs = CommonBasePtr->LazySpecializations;
-    CommonBasePtr->LazySpecializations = nullptr;
+    uint32_t *Specs = CommonBasePtr->ExternalSpecializations;
+    CommonBasePtr->ExternalSpecializations = nullptr;
     for (uint32_t I = 0, N = *Specs++; I != N; ++I)
       (void)Context.getExternalSource()->GetExternalDecl(Specs[I]);
   }
@@ -358,6 +358,16 @@ RedeclarableTemplateDecl::findSpecializationImpl(
   return Entry ? SETraits::getDecl(Entry)->getMostRecentDecl() : nullptr;
 }
 
+void RedeclarableTemplateDecl::loadLazySpecializationsWithArgs(
+    ArrayRef<TemplateArgument> TemplateArgs) {
+  auto *ExternalSource = getASTContext().getExternalSource();
+  if (!ExternalSource)
+    return;
+
+  ExternalSource->LoadExternalSpecializations(this->getCanonicalDecl(),
+                                              TemplateArgs);
+}
+
 template<class Derived, class EntryType>
 void RedeclarableTemplateDecl::addSpecializationImpl(
     llvm::FoldingSetVector<EntryType> &Specializations, EntryType *Entry,
@@ -430,24 +440,23 @@ FunctionTemplateDecl::newCommon(ASTContext &C) const {
   return CommonPtr;
 }
 
-void FunctionTemplateDecl::LoadLazySpecializations() const {
-  loadLazySpecializationsImpl();
-}
-
 llvm::FoldingSetVector<FunctionTemplateSpecializationInfo> &
 FunctionTemplateDecl::getSpecializations() const {
-  LoadLazySpecializations();
+  loadExternalSpecializations();
   return getCommonPtr()->Specializations;
 }
 
 FunctionDecl *
 FunctionTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                          void *&InsertPos) {
+  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getSpecializations(), InsertPos, Args);
 }
 
 void FunctionTemplateDecl::addSpecialization(
       FunctionTemplateSpecializationInfo *Info, void *InsertPos) {
+  using SETraits = SpecEntryTraits<FunctionTemplateSpecializationInfo>;
+  loadLazySpecializationsWithArgs(SETraits::getTemplateArgs(Info));
   addSpecializationImpl<FunctionTemplateDecl>(getSpecializations(), Info,
                                               InsertPos);
 }
@@ -508,19 +517,15 @@ ClassTemplateDecl *ClassTemplateDecl::CreateDeserialized(ASTContext &C,
                                        DeclarationName(), nullptr, nullptr);
 }
 
-void ClassTemplateDecl::LoadLazySpecializations() const {
-  loadLazySpecializationsImpl();
-}
-
 llvm::FoldingSetVector<ClassTemplateSpecializationDecl> &
 ClassTemplateDecl::getSpecializations() const {
-  LoadLazySpecializations();
+  loadExternalSpecializations();
   return getCommonPtr()->Specializations;
 }
 
 llvm::FoldingSetVector<ClassTemplatePartialSpecializationDecl> &
 ClassTemplateDecl::getPartialSpecializations() const {
-  LoadLazySpecializations();
+  loadExternalSpecializations();
   return getCommonPtr()->PartialSpecializations;
 }
 
@@ -534,11 +539,14 @@ ClassTemplateDecl::newCommon(ASTContext &C) const {
 ClassTemplateSpecializationDecl *
 ClassTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                       void *&InsertPos) {
+  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getSpecializations(), InsertPos, Args);
 }
 
 void ClassTemplateDecl::AddSpecialization(ClassTemplateSpecializationDecl *D,
                                           void *InsertPos) {
+  using SETraits = SpecEntryTraits<ClassTemplateSpecializationDecl>;
+  loadLazySpecializationsWithArgs(SETraits::getTemplateArgs(D));
   addSpecializationImpl<ClassTemplateDecl>(getSpecializations(), D, InsertPos);
 }
 
@@ -546,6 +554,7 @@ ClassTemplatePartialSpecializationDecl *
 ClassTemplateDecl::findPartialSpecialization(
     ArrayRef<TemplateArgument> Args,
     TemplateParameterList *TPL, void *&InsertPos) {
+  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getPartialSpecializations(), InsertPos, Args,
                                 TPL);
 }
@@ -900,6 +909,11 @@ FunctionTemplateSpecializationInfo *FunctionTemplateSpecializationInfo::Create(
       FD, Template, TSK, TemplateArgs, ArgsAsWritten, POI, MSInfo);
 }
 
+void FunctionTemplateSpecializationInfo::loadExternalRedecls() {
+  getTemplate()->loadExternalSpecializations();
+  getTemplate()->loadLazySpecializationsWithArgs(TemplateArguments->asArray());
+}
+
 //===----------------------------------------------------------------------===//
 // ClassTemplateSpecializationDecl Implementation
 //===----------------------------------------------------------------------===//
@@ -1024,6 +1038,12 @@ ClassTemplateSpecializationDecl::getSourceRange() const {
   }
 }
 
+void ClassTemplateSpecializationDecl::loadExternalRedecls() {
+  getSpecializedTemplate()->loadExternalSpecializations();
+  getSpecializedTemplate()->loadLazySpecializationsWithArgs(
+      getTemplateArgs().asArray());
+}
+
 //===----------------------------------------------------------------------===//
 // ConceptDecl Implementation
 //===----------------------------------------------------------------------===//
@@ -1226,19 +1246,15 @@ VarTemplateDecl *VarTemplateDecl::CreateDeserialized(ASTContext &C,
                                      DeclarationName(), nullptr, nullptr);
 }
 
-void VarTemplateDecl::LoadLazySpecializations() const {
-  loadLazySpecializationsImpl();
-}
-
 llvm::FoldingSetVector<VarTemplateSpecializationDecl> &
 VarTemplateDecl::getSpecializations() const {
-  LoadLazySpecializations();
+  loadExternalSpecializations();
   return getCommonPtr()->Specializations;
 }
 
 llvm::FoldingSetVector<VarTemplatePartialSpecializationDecl> &
 VarTemplateDecl::getPartialSpecializations() const {
-  LoadLazySpecializations();
+  loadExternalSpecializations();
   return getCommonPtr()->PartialSpecializations;
 }
 
@@ -1252,17 +1268,21 @@ VarTemplateDecl::newCommon(ASTContext &C) const {
 VarTemplateSpecializationDecl *
 VarTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                     void *&InsertPos) {
+  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getSpecializations(), InsertPos, Args);
 }
 
 void VarTemplateDecl::AddSpecialization(VarTemplateSpecializationDecl *D,
                                         void *InsertPos) {
+  using SETraits = SpecEntryTraits<VarTemplateSpecializationDecl>;
+  loadLazySpecializationsWithArgs(SETraits::getTemplateArgs(D));
   addSpecializationImpl<VarTemplateDecl>(getSpecializations(), D, InsertPos);
 }
 
 VarTemplatePartialSpecializationDecl *
 VarTemplateDecl::findPartialSpecialization(ArrayRef<TemplateArgument> Args,
      TemplateParameterList *TPL, void *&InsertPos) {
+  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getPartialSpecializations(), InsertPos, Args,
                                 TPL);
 }
@@ -1393,6 +1413,11 @@ SourceRange VarTemplateSpecializationDecl::getSourceRange() const {
   return VarDecl::getSourceRange();
 }
 
+void VarTemplateSpecializationDecl::loadExternalRedecls() {
+  getSpecializedTemplate()->loadExternalSpecializations();
+  getSpecializedTemplate()->loadLazySpecializationsWithArgs(
+      getTemplateArgs().asArray());
+}
 
 //===----------------------------------------------------------------------===//
 // VarTemplatePartialSpecializationDecl Implementation
diff --git a/clang/lib/AST/ExternalASTSource.cpp b/clang/lib/AST/ExternalASTSource.cpp
index 090ef02aa4224d..bba73fcbd8fa23 100644
--- a/clang/lib/AST/ExternalASTSource.cpp
+++ b/clang/lib/AST/ExternalASTSource.cpp
@@ -100,6 +100,11 @@ ExternalASTSource::FindExternalVisibleDeclsByName(const DeclContext *DC,
   return false;
 }
 
+void ExternalASTSource::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  return;
+}
+
 void ExternalASTSource::completeVisibleDeclsMap(const DeclContext *DC) {}
 
 void ExternalASTSource::FindExternalLexicalDecls(
diff --git a/clang/lib/AST/ODRHash.cpp b/clang/lib/AST/ODRHash.cpp
index 5b98646a1e8dc0..e470c5bacf6d00 100644
--- a/clang/lib/AST/ODRHash.cpp
+++ b/clang/lib/AST/ODRHash.cpp
@@ -1363,3 +1363,6 @@ void ODRHash::AddStructuralValue(const APValue &Value) {
     Value.Profile(ID);
   }
 }
+
+void ODRHash::AddInteger(unsigned Value) { ID.AddInteger(Value); }
+
diff --git a/clang/lib/Sema/MultiplexExternalSemaSource.cpp b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
index 058e22cb2b814e..b845383675ab36 100644
--- a/clang/lib/Sema/MultiplexExternalSemaSource.cpp
+++ b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
@@ -115,6 +115,12 @@ FindExternalVisibleDeclsByName(const DeclContext *DC, DeclarationName Name) {
   return AnyDeclsFound;
 }
 
+void MultiplexExternalSemaSource::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  for (size_t i = 0; i < Sources.size(); ++i)
+    Sources[i]->LoadExternalSpecializations(D, TemplateArgs);
+}
+
 void MultiplexExternalSemaSource::completeVisibleDeclsMap(const DeclContext *DC){
   for(size_t i = 0; i < Sources.size(); ++i)
     Sources[i]->completeVisibleDeclsMap(DC);
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index fe8782a3eb9e7c..2e280c52b6e59c 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -1223,6 +1223,38 @@ void ASTDeclContextNameLookupTrait::ReadDataInto(internal_key_type,
   }
 }
 
+ModuleFile *SpecializationsLookupTrait::ReadFileRef(const unsigned char *&d) {
+  using namespace llvm::support;
+
+  uint32_t ModuleFileID =
+      endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+  return Reader.getLocalModuleFile(F, ModuleFileID);
+}
+
+SpecializationsLookupTrait::internal_key_type
+SpecializationsLookupTrait::ReadKey(const unsigned char *d, unsigned) {
+  using namespace llvm::support;
+  return endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+}
+
+std::pair<unsigned, unsigned>
+SpecializationsLookupTrait::ReadKeyDataLength(const unsigned char *&d) {
+  return readULEBKeyDataLength(d);
+}
+
+void SpecializationsLookupTrait::ReadDataInto(internal_key_type,
+                                              const unsigned char *d,
+                                              unsigned DataLen,
+                                              data_type_builder &Val) {
+  using namespace llvm::support;
+
+  for (unsigned NumDecls = DataLen / 4; NumDecls; --NumDecls) {
+    uint32_t LocalID =
+        endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+    Val.insert(Reader.getGlobalDeclID(F, LocalID));
+  }
+}
+
 bool ASTReader::ReadLexicalDeclContextStorage(ModuleFile &M,
                                               BitstreamCursor &Cursor,
                                               uint64_t Offset,
@@ -1312,6 +1344,44 @@ bool ASTReader::ReadVisibleDeclContextStorage(ModuleFile &M,
   return false;
 }
 
+bool ASTReader::ReadSpecializations(ModuleFile &M, BitstreamCursor &Cursor,
+                                    uint64_t Offset, Decl *D) {
+  assert(Offset != 0);
+
+  SavedStreamPosition SavedPosition(Cursor);
+  if (llvm::Error Err = Cursor.JumpToBit(Offset)) {
+    Error(std::move(Err));
+    return true;
+  }
+
+  RecordData Record;
+  StringRef Blob;
+  Expected<unsigned> MaybeCode = Cursor.ReadCode();
+  if (!MaybeCode) {
+    Error(MaybeCode.takeError());
+    return true;
+  }
+  unsigned Code = MaybeCode.get();
+
+  Expected<unsigned> MaybeRecCode = Cursor.readRecord(Code, Record, &Blob);
+  if (!MaybeRecCode) {
+    Error(MaybeRecCode.takeError());
+    return true;
+  }
+  unsigned RecCode = MaybeRecCode.get();
+  if (RecCode != DECL_SPECIALIZATIONS) {
+    Error("Expected decl specs block");
+    return true;
+  }
+
+  auto *Data = (const unsigned char *)Blob.data();
+  D = D->getCanonicalDecl();
+  SpecializationsLookups[D].Table.add(
+      &M, Data, reader::SpecializationsLookupTrait(*this, M));
+
+  return false;
+}
+
 void ASTReader::Error(StringRef Msg) const {
   Error(diag::err_fe_pch_malformed, Msg);
   if (PP.getLangOpts().Modules && !Diags.isDiagnosticInFlight() &&
@@ -7545,13 +7615,13 @@ void ASTReader::CompleteRedeclChain(const Decl *D) {
   }
 
   if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(D))
-    CTSD->getSpecializedTemplate()->LoadLazySpecializations();
+    const_cast<ClassTemplateSpecializationDecl *>(CTSD)->loadExternalRedecls();
   if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(D))
-    VTSD->getSpecializedTemplate()->LoadLazySpecializations();
-  if (auto *FD = dyn_cast<FunctionDecl>(D)) {
-    if (auto *Template = FD->getPrimaryTemplate())
-      Template->LoadLazySpecializations();
-  }
+    const_cast<VarTemplateSpecializationDecl *>(VTSD)->loadExternalRedecls();
+  if (auto *FD = dyn_cast<FunctionDecl>(D))
+    if (auto *FDInfo = FD->getTemplateSpecializationInfo())
+      const_cast<FunctionTemplateSpecializationInfo *>(FDInfo)
+          ->loadExternalRedecls();
 }
 
 CXXCtorInitializer **
@@ -7980,6 +8050,42 @@ ASTReader::FindExternalVisibleDeclsByName(const DeclContext *DC,
   return !Decls.empty();
 }
 
+unsigned
+clang::GetTemplateArgsStableHash(ArrayRef<TemplateArgument> TemplateArgs) {
+  // FIXME: ODR hashing may not be the best mechanism to hash the template
+  // arguments. ODR hashing is (or perhaps, should be) about determining whether
+  // two things are spelled the same way and have the same meaning (as required
+  // by the C++ ODR), whereas what we want here is whether they have the same
+  // meaning regardless of spelling. Maybe we can get away with reusing ODR
+  // hashing anyway, on the basis that any canonical, non-dependent template
+  // argument should have the same (invented) spelling in every translation
+  // unit, but that may not be true in all cases. There may still be cases
+  // where the canonical type includes some aspect of "whatever we saw first",
+  // in which case the ODR hash can differ across translation units for
+  // non-dependent, canonical template arguments that are spelled differently
+  // but have the same meaning. But it is not easy to come up with examples.
+  ODRHash Hasher;
+  Hasher.AddInteger(TemplateArgs.size());
+  for (const TemplateArgument &TemplateArg : TemplateArgs)
+    Hasher.AddTemplateArgument(TemplateArg);
+  return Hasher.CalculateHash();
+}
+
+void ASTReader::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  assert(D);
+
+  auto It = SpecializationsLookups.find(D);
+  if (It == SpecializationsLookups.end())
+    return;
+
+  auto HashValue = GetTemplateArgsStableHash(TemplateArgs);
+
+  Deserializing LookupResults(this);
+  for (DeclID ID : It->second.Table.find(HashValue))
+    GetDecl(ID);
+}
+
 void ASTReader::completeVisibleDeclsMap(const DeclContext *DC) {
   if (!DC->hasExternalVisibleStorage())
     return;
@@ -8009,6 +8115,13 @@ ASTReader::getLoadedLookupTables(DeclContext *Primary) const {
   return I == Lookups.end() ? nullptr : &I->second;
 }
 
+serialization::reader::SpecializationsLookupTable *
+ASTReader::getLoadedSpecializationsLookupTables(const Decl *D) {
+  assert(D->isCanonicalDecl());
+  auto I = SpecializationsLookups.find(D);
+  return I == SpecializationsLookups.end() ? nullptr : &I->second;
+}
+
 /// Under non-PCH compilation the consumer receives the objc methods
 /// before receiving the implementation, and codegen depends on this.
 /// We simulate this by deserializing and passing to consumer the methods of the
diff --git a/clang/lib/Serialization/ASTReaderDecl.cpp b/clang/lib/Serialization/ASTReaderDecl.cpp
index a149d82153037f..7d0ca4bfc2cde8 100644
--- a/clang/lib/Serialization/ASTReaderDecl.cpp
+++ b/clang/lib/Serialization/ASTReaderDecl.cpp
@@ -274,9 +274,10 @@ namespace clang {
       // FIXME: We should avoid this pattern of getting the ASTContext.
       ASTContext &C = D->getASTContext();
 
-      auto *&LazySpecializations = D->getCommonPtr()->LazySpecializations;
+      auto *&ExternalSpecializations =
+          D->getCommonPtr()->ExternalSpecializations;
 
-      if (auto &Old = LazySpecializations) {
+      if (auto &Old = ExternalSpecializations) {
         IDs.insert(IDs.end(), Old + 1, Old + 1 + Old[0]);
         llvm::sort(IDs);
         IDs.erase(std::unique(IDs.begin(), IDs.end()), IDs.end());
@@ -286,7 +287,7 @@ namespace clang {
       *Result = IDs.size();
       std::copy(IDs.begin(), IDs.end(), Result + 1);
 
-      LazySpecializations = Result;
+      ExternalSpecializations = Result;
     }
 
     template <typename DeclT>
@@ -426,6 +427,9 @@ namespace clang {
 
     std::pair<uint64_t, uint64_t> VisitDeclContext(DeclContext *DC);
 
+    void ReadSpecializations(ModuleFile &M, Decl *D,
+                             llvm::BitstreamCursor &Cursor);
+
     template<typename T>
     RedeclarableResult VisitRedeclarable(Redeclarable<T> *D);
 
@@ -2431,10 +2435,14 @@ void ASTDeclReader::VisitClassTemplateDecl(ClassTemplateDecl *D) {
   mergeRedeclarableTemplate(D, Redecl);
 
   if (ThisDeclID == Redecl.getFirstID()) {
-    // This ClassTemplateDecl owns a CommonPtr; read it to keep track of all of
-    // the specializations.
+    // This ClassTemplateDecl owns a CommonPtr; read it to keep track of all
+    // of the specializations.
     SmallVector<serialization::DeclID, 32> SpecIDs;
     readDeclIDList(SpecIDs);
+
+    if (Record.readInt())
+      ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor);
+
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
   }
 
@@ -2463,6 +2471,10 @@ void ASTDeclReader::VisitVarTemplateDecl(VarTemplateDecl *D) {
     // the specializations.
     SmallVector<serialization::DeclID, 32> SpecIDs;
     readDeclIDList(SpecIDs);
+
+    if (Record.readInt())
+      ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor);
+
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
   }
 }
@@ -2566,6 +2578,9 @@ void ASTDeclReader::VisitFunctionTemplateDecl(FunctionTemplateDecl *D) {
     SmallVector<serialization::DeclID, 32> SpecIDs;
     readDeclIDList(SpecIDs);
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
+
+    if (Record.readInt())
+      ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor);
   }
 }
 
@@ -2755,6 +2770,14 @@ ASTDeclReader::VisitDeclContext(DeclContext *DC) {
   return std::make_pair(LexicalOffset, VisibleOffset);
 }
 
+void ASTDeclReader::ReadSpecializations(ModuleFile &M, Decl *D,
+                                        llvm::BitstreamCursor &DeclsCursor) {
+  uint64_t Offset = ReadLocalOffset();
+  bool Failed = Reader.ReadSpecializations(M, DeclsCursor, Offset, D);
+  (void)Failed;
+  assert(!Failed);
+}
+
 template <typename T>
 ASTDeclReader::RedeclarableResult
 ASTDeclReader::VisitRedeclarable(Redeclarable<T> *D) {
@@ -3802,6 +3825,7 @@ Decl *ASTReader::ReadDeclRecord(DeclID ID) {
   switch ((DeclCode)MaybeDeclCode.get()) {
   case DECL_CONTEXT_LEXICAL:
   case DECL_CONTEXT_VISIBLE:
+  case DECL_SPECIALIZATIONS:
     llvm_unreachable("Record cannot be de-serialized with readDeclRecord");
   case DECL_TYPEDEF:
     D = TypedefDecl::CreateDeserialized(Context, ID);
@@ -4114,6 +4138,7 @@ Decl *ASTReader::ReadDeclRecord(DeclID ID) {
         ReadVisibleDeclContextStorage(*Loc.F, DeclsCursor, Offsets.second, ID))
       return nullptr;
   }
+
   assert(Record.getIdx() == Record.size());
 
   // Load any relevant update records.
diff --git a/clang/lib/Serialization/ASTReaderInternals.h b/clang/lib/Serialization/ASTReaderInternals.h
index 25a46ddabcb707..9ac3c139815b08 100644
--- a/clang/lib/Serialization/ASTReaderInternals.h
+++ b/clang/lib/Serialization/ASTReaderInternals.h
@@ -119,6 +119,86 @@ struct DeclContextLookupTable {
   MultiOnDiskHashTable<ASTDeclContextNameLookupTrait> Table;
 };
 
+/// Class that performs lookup to specialized decls.
+class SpecializationsLookupTrait {
+  ASTReader &Reader;
+  ModuleFile &F;
+
+public:
+  // Maximum number of lookup tables we allow before condensing the tables.
+  static const int MaxTables = 4;
+
+  /// The lookup result is a list of global declaration IDs.
+  using data_type = SmallVector<DeclID, 4>;
+
+  struct data_type_builder {
+    data_type &Data;
+    llvm::DenseSet<DeclID> Found;
+
+    data_type_builder(data_type &D) : Data(D) {}
+
+    void insert(DeclID ID) {
+      // Append ID unless already present; linear scan while Data is small.
+      if (Found.empty() && !Data.empty()) {
+        if (Data.size() <= 4) {
+          for (auto I : Data) // Found is empty here, so scan Data itself.
+            if (I == ID)
+              return;
+          Data.push_back(ID);
+          return;
+        }
+
+        // Switch to tracking found IDs in the set.
+        Found.insert(Data.begin(), Data.end());
+      }
+
+      if (Found.insert(ID).second)
+        Data.push_back(ID);
+    }
+  };
+  using hash_value_type = unsigned;
+  using offset_type = unsigned;
+  using file_type = ModuleFile *;
+
+  using external_key_type = unsigned;
+  using internal_key_type = unsigned;
+
+  explicit SpecializationsLookupTrait(ASTReader &Reader, ModuleFile &F)
+      : Reader(Reader), F(F) {}
+
+  static bool EqualKey(const internal_key_type &a, const internal_key_type &b) {
+    return a == b;
+  }
+
+  static hash_value_type ComputeHash(const internal_key_type &Key) {
+    return Key;
+  }
+
+  static internal_key_type GetInternalKey(const external_key_type &Name) {
+    return Name;
+  }
+
+  static std::pair<unsigned, unsigned>
+  ReadKeyDataLength(const unsigned char *&d);
+
+  internal_key_type ReadKey(const unsigned char *d, unsigned);
+
+  void ReadDataInto(internal_key_type, const unsigned char *d, unsigned DataLen,
+                    data_type_builder &Val);
+
+  static void MergeDataInto(const data_type &From, data_type_builder &To) {
+    To.Data.reserve(To.Data.size() + From.size());
+    for (DeclID ID : From)
+      To.insert(ID);
+  }
+
+  file_type ReadFileRef(const unsigned char *&d);
+};
+
+struct SpecializationsLookupTable {
+  MultiOnDiskHashTable<SpecializationsLookupTrait> Table;
+};
+
 /// Base class for the trait describing the on-disk hash table for the
 /// identifiers in an AST file.
 ///
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 03bddfe0f5047d..e24950e2be63b5 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -29,6 +29,7 @@
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/LambdaCapture.h"
 #include "clang/AST/NestedNameSpecifier.h"
+#include "clang/AST/ODRHash.h"
 #include "clang/AST/OpenMPClause.h"
 #include "clang/AST/RawCommentList.h"
 #include "clang/AST/TemplateName.h"
@@ -3930,6 +3931,147 @@ class ASTDeclContextNameLookupTrait {
 
 } // namespace
 
+namespace {
+class SpecializationsLookupTrait {
+  ASTWriter &Writer;
+  llvm::SmallVector<DeclID, 64> DeclIDs;
+
+public:
+  using key_type = unsigned;
+  using key_type_ref = key_type;
+
+  /// A start and end index into DeclIDs, representing a sequence of decls.
+  using data_type = std::pair<unsigned, unsigned>;
+  using data_type_ref = const data_type &;
+
+  using hash_value_type = unsigned;
+  using offset_type = unsigned;
+
+  explicit SpecializationsLookupTrait(ASTWriter &Writer) : Writer(Writer) {}
+
+  template <typename Col> data_type getData(Col &&C) {
+    unsigned Start = DeclIDs.size();
+    for (auto *D : C)
+      DeclIDs.push_back(Writer.GetDeclRef(getDeclForLocalLookup(
+          Writer.getLangOpts(), const_cast<NamedDecl *>(D))));
+    return std::make_pair(Start, DeclIDs.size());
+  }
+
+  data_type
+  ImportData(const reader::SpecializationsLookupTrait::data_type &FromReader) {
+    unsigned Start = DeclIDs.size();
+    for (auto ID : FromReader)
+      DeclIDs.push_back(ID);
+    return std::make_pair(Start, DeclIDs.size());
+  }
+
+  static bool EqualKey(key_type_ref a, key_type_ref b) { return a == b; }
+
+  hash_value_type ComputeHash(key_type Name) { return Name; }
+
+  void EmitFileRef(raw_ostream &Out, ModuleFile *F) const {
+    assert(Writer.hasChain() &&
+           "have reference to loaded module file but no chain?");
+
+    using namespace llvm::support;
+    endian::write<uint32_t>(Out, Writer.getChain()->getModuleFileID(F),
+                            llvm::endianness::little);
+  }
+
+  std::pair<unsigned, unsigned> EmitKeyDataLength(raw_ostream &Out,
+                                                  key_type HashValue,
+                                                  data_type_ref Lookup) {
+    // 4 bytes for each slot.
+    unsigned KeyLen = 4;
+    unsigned DataLen = 4 * (Lookup.second - Lookup.first);
+
+    return emitULEBKeyDataLength(KeyLen, DataLen, Out);
+  }
+
+  void EmitKey(raw_ostream &Out, key_type HashValue, unsigned) {
+    using namespace llvm::support;
+
+    endian::Writer LE(Out, llvm::endianness::little);
+    LE.write<uint32_t>(HashValue);
+  }
+
+  void EmitData(raw_ostream &Out, key_type_ref, data_type Lookup,
+                unsigned DataLen) {
+    using namespace llvm::support;
+
+    endian::Writer LE(Out, llvm::endianness::little);
+    uint64_t Start = Out.tell();
+    (void)Start;
+    for (unsigned I = Lookup.first, N = Lookup.second; I != N; ++I)
+      LE.write<uint32_t>(DeclIDs[I]);
+    assert(Out.tell() - Start == DataLen && "Data length is wrong");
+  }
+};
+
+unsigned CalculateODRHashForSpecs(const Decl *Spec) {
+  assert(!isa<ClassTemplatePartialSpecializationDecl>(Spec) &&
+         !isa<VarTemplatePartialSpecializationDecl>(Spec) &&
+         "We shouldn't see partial specializations here.");
+
+  if (auto *FD = dyn_cast<FunctionDecl>(Spec)) {
+    auto *FDInfo = FD->getTemplateSpecializationInfo();
+    assert(FDInfo);
+    return GetTemplateArgsStableHash(FDInfo->TemplateArguments->asArray());
+  }
+
+  if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(Spec))
+    return GetTemplateArgsStableHash(CTSD->getTemplateArgs().asArray());
+
+  if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(Spec))
+    return GetTemplateArgsStableHash(VTSD->getTemplateArgs().asArray());
+
+  llvm_unreachable("Unknown specialization kind?");
+}
+} // namespace
+
+uint64_t ASTWriter::WriteSpecializationsLookupTable(
+    const NamedDecl *D,
+    llvm::SmallVectorImpl<const NamedDecl *> &Specializations) {
+  assert(D->isFirstDecl());
+
+  // Create the on-disk hash table representation.
+  MultiOnDiskHashTableGenerator<reader::SpecializationsLookupTrait,
+                                SpecializationsLookupTrait>
+      Generator;
+  SpecializationsLookupTrait Trait(*this);
+
+  llvm::DenseMap<unsigned, llvm::SmallVector<const NamedDecl *, 4>>
+      SpecializationMaps;
+
+  for (auto *Specialization : Specializations) {
+    unsigned HashedValue = CalculateODRHashForSpecs(Specialization);
+
+    auto Iter = SpecializationMaps.find(HashedValue);
+    if (Iter == SpecializationMaps.end())
+      Iter = SpecializationMaps
+                 .try_emplace(HashedValue,
+                              llvm::SmallVector<const NamedDecl *, 4>())
+                 .first;
+
+    Iter->second.push_back(Specialization);
+  }
+
+  for (auto &Iter : SpecializationMaps)
+    Generator.insert(Iter.first, Trait.getData(Iter.second), Trait);
+
+  uint64_t Offset = Stream.GetCurrentBitNo();
+
+  auto *Lookups =
+      Chain ? Chain->getLoadedSpecializationsLookupTables(D) : nullptr;
+  llvm::SmallString<4096> LookupTable;
+  Generator.emit(LookupTable, Trait, Lookups ? &Lookups->Table : nullptr);
+
+  RecordData::value_type Record[] = {DECL_SPECIALIZATIONS};
+  Stream.EmitRecordWithBlob(DeclSpecializationsAbbrev, Record, LookupTable);
+
+  return Offset;
+}
+
 bool ASTWriter::isLookupResultExternal(StoredDeclsList &Result,
                                        DeclContext *DC) {
   return Result.hasExternalDecls() &&
@@ -5080,7 +5222,7 @@ ASTFileSignature ASTWriter::WriteASTCore(Sema &SemaRef, StringRef isysroot,
 
   // Keep writing types, declarations, and declaration update records
   // until we've emitted all of them.
-  Stream.EnterSubblock(DECLTYPES_BLOCK_ID, /*bits for abbreviations*/5);
+  Stream.EnterSubblock(DECLTYPES_BLOCK_ID, /*bits for abbreviations*/ 6);
   DeclTypesBlockStartOffset = Stream.GetCurrentBitNo();
   WriteTypeAbbrevs();
   WriteDeclAbbrevs();
diff --git a/clang/lib/Serialization/ASTWriterDecl.cpp b/clang/lib/Serialization/ASTWriterDecl.cpp
index bb1f51786d2813..d508f56618f843 100644
--- a/clang/lib/Serialization/ASTWriterDecl.cpp
+++ b/clang/lib/Serialization/ASTWriterDecl.cpp
@@ -207,22 +207,26 @@ namespace clang {
       return std::nullopt;
     }
 
-    template<typename DeclTy>
-    void AddTemplateSpecializations(DeclTy *D) {
+    template <typename DeclTy>
+    void AddTemplateSpecializations(
+        DeclTy *D, llvm::SmallVectorImpl<const NamedDecl *> &OptionalSpecs) {
       auto *Common = D->getCommonPtr();
 
       // If we have any lazy specializations, and the external AST source is
       // our chained AST reader, we can just write out the DeclIDs. Otherwise,
       // we need to resolve them to actual declarations.
       if (Writer.Chain != Writer.Context->getExternalSource() &&
-          Common->LazySpecializations) {
-        D->LoadLazySpecializations();
-        assert(!Common->LazySpecializations);
+          Common->ExternalSpecializations) {
+        D->loadExternalSpecializations();
+        assert(!Common->ExternalSpecializations);
       }
 
-      ArrayRef<DeclID> LazySpecializations;
-      if (auto *LS = Common->LazySpecializations)
-        LazySpecializations = llvm::ArrayRef(LS + 1, LS[0]);
+      for (auto &Entry : Common->Specializations)
+        OptionalSpecs.push_back(getSpecializationDecl(Entry));
+
+      ArrayRef<DeclID> ExternalSpecializations;
+      if (auto *LS = Common->ExternalSpecializations)
+        ExternalSpecializations = llvm::ArrayRef(LS + 1, LS[0]);
 
       // Add a slot to the record for the number of specializations.
       unsigned I = Record.size();
@@ -230,17 +234,20 @@ namespace clang {
 
       // AddFirstDeclFromEachModule might trigger deserialization, invalidating
       // *Specializations iterators.
-      llvm::SmallVector<const Decl*, 16> Specs;
-      for (auto &Entry : Common->Specializations)
-        Specs.push_back(getSpecializationDecl(Entry));
+      //
+      // We need to load all the partial specializations at once when the
+      // template is required, since we can't know whether a partial
+      // specialization is needed before resolving an instantiation request.
+      llvm::SmallVector<const Decl *, 16> PartialSpecs;
       for (auto &Entry : getPartialSpecializations(Common))
-        Specs.push_back(getSpecializationDecl(Entry));
+        PartialSpecs.push_back(getSpecializationDecl(Entry));
 
-      for (auto *D : Specs) {
+      for (auto *D : PartialSpecs) {
         assert(D->isCanonicalDecl() && "non-canonical decl in set");
         AddFirstDeclFromEachModule(D, /*IncludeLocal*/true);
       }
-      Record.append(LazySpecializations.begin(), LazySpecializations.end());
+      Record.append(ExternalSpecializations.begin(),
+                    ExternalSpecializations.end());
 
       // Update the size entry we added earlier.
       Record[I] = Record.size() - I - 1;
@@ -1670,8 +1677,16 @@ void ASTDeclWriter::VisitRedeclarableTemplateDecl(RedeclarableTemplateDecl *D) {
 void ASTDeclWriter::VisitClassTemplateDecl(ClassTemplateDecl *D) {
   VisitRedeclarableTemplateDecl(D);
 
-  if (D->isFirstDecl())
-    AddTemplateSpecializations(D);
+  if (D->isFirstDecl()) {
+    llvm::SmallVector<const NamedDecl *, 16> OptionalSpecs;
+    AddTemplateSpecializations(D, OptionalSpecs);
+    if (!OptionalSpecs.empty()) {
+      Record.push_back(1);
+      Record.AddOffset(
+          Writer.WriteSpecializationsLookupTable(D, OptionalSpecs));
+    } else
+      Record.push_back(0);
+  }
   Code = serialization::DECL_CLASS_TEMPLATE;
 }
 
@@ -1730,8 +1745,17 @@ void ASTDeclWriter::VisitClassTemplatePartialSpecializationDecl(
 void ASTDeclWriter::VisitVarTemplateDecl(VarTemplateDecl *D) {
   VisitRedeclarableTemplateDecl(D);
 
-  if (D->isFirstDecl())
-    AddTemplateSpecializations(D);
+  if (D->isFirstDecl()) {
+    llvm::SmallVector<const NamedDecl *, 16> OptionalSpecs;
+    AddTemplateSpecializations(D, OptionalSpecs);
+    if (!OptionalSpecs.empty()) {
+      Record.push_back(1);
+      Record.AddOffset(
+          Writer.WriteSpecializationsLookupTable(D, OptionalSpecs));
+    } else
+      Record.push_back(0);
+  }
+
   Code = serialization::DECL_VAR_TEMPLATE;
 }
 
@@ -1791,8 +1815,17 @@ void ASTDeclWriter::VisitVarTemplatePartialSpecializationDecl(
 void ASTDeclWriter::VisitFunctionTemplateDecl(FunctionTemplateDecl *D) {
   VisitRedeclarableTemplateDecl(D);
 
-  if (D->isFirstDecl())
-    AddTemplateSpecializations(D);
+  if (D->isFirstDecl()) {
+    llvm::SmallVector<const NamedDecl *, 16> OptionalSpecs;
+    AddTemplateSpecializations(D, OptionalSpecs);
+    if (!OptionalSpecs.empty()) {
+      Record.push_back(1);
+      Record.AddOffset(
+          Writer.WriteSpecializationsLookupTable(D, OptionalSpecs));
+    } else
+      Record.push_back(0);
+  }
+
   Code = serialization::DECL_FUNCTION_TEMPLATE;
 }
 
@@ -2657,6 +2690,11 @@ void ASTWriter::WriteDeclAbbrevs() {
   Abv->Add(BitCodeAbbrevOp(serialization::DECL_CONTEXT_VISIBLE));
   Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Blob));
   DeclContextVisibleLookupAbbrev = Stream.EmitAbbrev(std::move(Abv));
+
+  Abv = std::make_shared<BitCodeAbbrev>();
+  Abv->Add(BitCodeAbbrevOp(serialization::DECL_SPECIALIZATIONS));
+  Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Blob));
+  DeclSpecializationsAbbrev = Stream.EmitAbbrev(std::move(Abv));
 }
 
 /// isRequiredDecl - Check if this is a "required" Decl, which must be seen by
diff --git a/clang/test/Modules/odr_hash.cpp b/clang/test/Modules/odr_hash.cpp
index fa8b2c81ab46e1..52c1fb990700e6 100644
--- a/clang/test/Modules/odr_hash.cpp
+++ b/clang/test/Modules/odr_hash.cpp
@@ -3084,7 +3084,7 @@ struct S5 {
 };
 #else
 S5 s5;
-// expected-error at second.h:* {{'PointersAndReferences::S5::x' from module 'SecondModule' is not present in definition of 'PointersAndReferences::S5' in module 'FirstModule'}}
+// expected-error at first.h:* {{'PointersAndReferences::S5::x' from module 'FirstModule' is not present in definition of 'PointersAndReferences::S5' in module 'SecondModule'}}
 // expected-note at first.h:* {{declaration of 'x' does not match}}
 #endif
 
@@ -4021,7 +4021,7 @@ struct Valid {
 #else
 Invalid::L2<1>::L3<1> invalid;
 // expected-error at second.h:* {{'Types::InjectedClassName::Invalid::L2::L3::x' from module 'SecondModule' is not present in definition of 'L3<>' in module 'FirstModule'}}
-// expected-note at first.h:* {{declaration of 'x' does not match}}
+// expected-note at second.h:* {{declaration of 'x' does not match}}
 Valid::L2<1>::L3<1> valid;
 #endif
 }  // namespace InjectedClassName
diff --git a/clang/unittests/Serialization/CMakeLists.txt b/clang/unittests/Serialization/CMakeLists.txt
index 10d7de970c643d..6276bbb6d0cb4b 100644
--- a/clang/unittests/Serialization/CMakeLists.txt
+++ b/clang/unittests/Serialization/CMakeLists.txt
@@ -8,6 +8,7 @@ set(LLVM_LINK_COMPONENTS
 add_clang_unittest(SerializationTests
   ForceCheckFileInputTest.cpp
   InMemoryModuleCacheTest.cpp
+  LoadSpecLazily.cpp
   ModuleCacheTest.cpp
   NoCommentsTest.cpp
   SourceLocationEncodingTest.cpp
diff --git a/clang/unittests/Serialization/LoadSpecLazily.cpp b/clang/unittests/Serialization/LoadSpecLazily.cpp
new file mode 100644
index 00000000000000..03f3ff3f786555
--- /dev/null
+++ b/clang/unittests/Serialization/LoadSpecLazily.cpp
@@ -0,0 +1,159 @@
+//== unittests/Serialization/LoadSpecLazily.cpp ----------------------========//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/FrontendAction.h"
+#include "clang/Frontend/FrontendActions.h"
+#include "clang/Parse/ParseAST.h"
+#include "clang/Serialization/ASTDeserializationListener.h"
+#include "clang/Tooling/Tooling.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace clang;
+using namespace clang::tooling;
+
+namespace {
+
+class LoadSpecLazilyTest : public ::testing::Test {
+  void SetUp() override {
+    ASSERT_FALSE(
+        sys::fs::createUniqueDirectory("load-spec-lazily-test", TestDir));
+  }
+
+  void TearDown() override { sys::fs::remove_directories(TestDir); }
+
+public:
+  SmallString<256> TestDir;
+
+  void addFile(StringRef Path, StringRef Contents) {
+    ASSERT_FALSE(sys::path::is_absolute(Path));
+
+    SmallString<256> AbsPath(TestDir);
+    sys::path::append(AbsPath, Path);
+
+    ASSERT_FALSE(
+        sys::fs::create_directories(llvm::sys::path::parent_path(AbsPath)));
+
+    std::error_code EC;
+    llvm::raw_fd_ostream OS(AbsPath, EC);
+    ASSERT_FALSE(EC);
+    OS << Contents;
+  }
+
+  std::string GenerateModuleInterface(StringRef ModuleName,
+                                      StringRef Contents) {
+    std::string FileName = llvm::Twine(ModuleName + ".cppm").str();
+    addFile(FileName, Contents);
+
+    IntrusiveRefCntPtr<DiagnosticsEngine> Diags =
+        CompilerInstance::createDiagnostics(new DiagnosticOptions());
+    CreateInvocationOptions CIOpts;
+    CIOpts.Diags = Diags;
+    CIOpts.VFS = llvm::vfs::createPhysicalFileSystem();
+
+    std::string CacheBMIPath =
+        llvm::Twine(TestDir + "/" + ModuleName + ".pcm").str();
+    std::string PrebuiltModulePath =
+        "-fprebuilt-module-path=" + TestDir.str().str();
+    const char *Args[] = {"clang++",
+                          "-std=c++20",
+                          "--precompile",
+                          PrebuiltModulePath.c_str(),
+                          "-working-directory",
+                          TestDir.c_str(),
+                          "-I",
+                          TestDir.c_str(),
+                          FileName.c_str(),
+                          "-o",
+                          CacheBMIPath.c_str()};
+    std::shared_ptr<CompilerInvocation> Invocation =
+        createInvocation(Args, CIOpts);
+    EXPECT_TRUE(Invocation);
+
+    CompilerInstance Instance;
+    Instance.setDiagnostics(Diags.get());
+    Instance.setInvocation(Invocation);
+    GenerateModuleInterfaceAction Action;
+    EXPECT_TRUE(Instance.ExecuteAction(Action));
+    EXPECT_FALSE(Diags->hasErrorOccurred());
+
+    return CacheBMIPath;
+  }
+};
+
+class DeclsReaderListener : public ASTDeserializationListener {
+public:
+  void DeclRead(serialization::DeclID ID, const Decl *D) override {
+    auto *ND = dyn_cast<NamedDecl>(D);
+    if (!ND)
+      return;
+
+    EXPECT_FALSE(ND->getName().contains(ForbiddenName));
+  }
+
+  DeclsReaderListener(StringRef ForbiddenName) : ForbiddenName(ForbiddenName) {}
+
+  StringRef ForbiddenName;
+};
+
+class LoadSpecLazilyConsumer : public ASTConsumer {
+  DeclsReaderListener Listener;
+
+public:
+  LoadSpecLazilyConsumer(StringRef ForbiddenName) : Listener(ForbiddenName) {}
+
+  ASTDeserializationListener *GetASTDeserializationListener() override {
+    return &Listener;
+  }
+};
+
+class CheckLoadSpecLazilyAction : public ASTFrontendAction {
+  StringRef ForbiddenName;
+
+public:
+  std::unique_ptr<ASTConsumer>
+  CreateASTConsumer(CompilerInstance &CI, StringRef /*Unused*/) override {
+    return std::make_unique<LoadSpecLazilyConsumer>(ForbiddenName);
+  }
+
+  CheckLoadSpecLazilyAction(StringRef ForbiddenName)
+      : ForbiddenName(ForbiddenName) {}
+};
+
+TEST_F(LoadSpecLazilyTest, BasicTest) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+
+export class ShouldNotBeLoaded {};
+
+export class Temp {
+   A<ShouldNotBeLoaded> AS;
+};
+  )cpp");
+
+  const char *test_file_contents = R"cpp(
+import M;
+A<int> a;
+  )cpp";
+  std::string DepArg = "-fprebuilt-module-path=" + TestDir.str().str();
+  EXPECT_TRUE(runToolOnCodeWithArgs(
+      std::make_unique<CheckLoadSpecLazilyAction>("ShouldNotBeLoaded"),
+      test_file_contents,
+      {
+          "-std=c++20",
+          DepArg.c_str(),
+          "-I",
+          TestDir.c_str(),
+      },
+      "test.cpp"));
+}
+
+} // namespace

>From 7f2897badc6b794380bf6b2352b535f7ee0e111c Mon Sep 17 00:00:00 2001
From: Chuanqi Xu <yedeng.yd at linux.alibaba.com>
Date: Thu, 4 Jan 2024 16:19:05 +0800
Subject: [PATCH 2/4] Load Specialization Updates Lazily

---
 .../include/clang/Serialization/ASTBitCodes.h |  2 +
 clang/include/clang/Serialization/ASTReader.h | 12 ++++--
 clang/include/clang/Serialization/ASTWriter.h |  8 ++++
 clang/lib/Serialization/ASTCommon.h           |  2 +-
 clang/lib/Serialization/ASTReader.cpp         | 29 ++++++++++---
 clang/lib/Serialization/ASTReaderDecl.cpp     | 41 ++++++++++++------
 clang/lib/Serialization/ASTWriter.cpp         | 43 ++++++++++++++++---
 clang/lib/Serialization/ASTWriterDecl.cpp     |  9 +++-
 clang/test/Modules/cxx-templates.cpp          |  9 ++--
 .../Serialization/LoadSpecLazily.cpp          | 34 +++++++++++++++
 10 files changed, 153 insertions(+), 36 deletions(-)

diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index 23a279de96ab15..212ae7db30faa0 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -695,6 +695,8 @@ enum ASTRecordTypes {
   /// Record code for an unterminated \#pragma clang assume_nonnull begin
   /// recorded in a preamble.
   PP_ASSUME_NONNULL_LOC = 67,
+
+  UPDATE_SPECIALIZATION = 68,
 };
 
 /// Record types used within a source manager block.
diff --git a/clang/include/clang/Serialization/ASTReader.h b/clang/include/clang/Serialization/ASTReader.h
index 293d6495d164ef..10726b440de515 100644
--- a/clang/include/clang/Serialization/ASTReader.h
+++ b/clang/include/clang/Serialization/ASTReader.h
@@ -610,18 +610,22 @@ class ASTReader
   // Updates for visible decls can occur for other contexts than just the
   // TU, and when we read those update records, the actual context may not
   // be available yet, so have this pending map using the ID as a key. It
-  // will be realized when the context is actually loaded.
-  struct PendingVisibleUpdate {
+  // will be realized when the data is actually loaded.
+  struct UpdateData {
     ModuleFile *Mod;
     const unsigned char *Data;
   };
-  using DeclContextVisibleUpdates = SmallVector<PendingVisibleUpdate, 1>;
+  using DeclContextVisibleUpdates = SmallVector<UpdateData, 1>;
 
   /// Updates to the visible declarations of declaration contexts that
   /// haven't been loaded yet.
   llvm::DenseMap<serialization::DeclID, DeclContextVisibleUpdates>
       PendingVisibleUpdates;
 
+  using SpecializationsUpdate = SmallVector<UpdateData, 1>;
+  llvm::DenseMap<serialization::DeclID, SpecializationsUpdate>
+      PendingSpecializationsUpdates;
+
   /// The set of C++ or Objective-C classes that have forward
   /// declarations that have not yet been linked to their definitions.
   llvm::SmallPtrSet<Decl *, 4> PendingDefinitions;
@@ -650,6 +654,8 @@ class ASTReader
 
   bool ReadSpecializations(ModuleFile &M, llvm::BitstreamCursor &Cursor,
                            uint64_t Offset, Decl *D);
+  void AddSpecializations(const Decl *D, const unsigned char *Data,
+                          ModuleFile &M);
 
   /// A vector containing identifiers that have already been
   /// loaded.
diff --git a/clang/include/clang/Serialization/ASTWriter.h b/clang/include/clang/Serialization/ASTWriter.h
index 09806b87590766..0d39f8ace87843 100644
--- a/clang/include/clang/Serialization/ASTWriter.h
+++ b/clang/include/clang/Serialization/ASTWriter.h
@@ -385,6 +385,10 @@ class ASTWriter : public ASTDeserializationListener,
   /// record containing modifications to them.
   DeclUpdateMap DeclUpdates;
 
+  using SpecializationUpdateMap =
+      llvm::MapVector<const NamedDecl *, SmallVector<const NamedDecl *>>;
+  SpecializationUpdateMap SpecializationsUpdates;
+
   using FirstLatestDeclMap = llvm::DenseMap<Decl *, Decl *>;
 
   /// Map of first declarations from a chained PCH that point to the
@@ -527,6 +531,9 @@ class ASTWriter : public ASTDeserializationListener,
   bool isLookupResultExternal(StoredDeclsList &Result, DeclContext *DC);
   bool isLookupResultEntirelyExternal(StoredDeclsList &Result, DeclContext *DC);
 
+  void GenerateSpecializationsLookupTable(
+      const NamedDecl *D, llvm::SmallVectorImpl<const NamedDecl *> &Specs,
+      llvm::SmallVectorImpl<char> &LookupTable);
   uint64_t WriteSpecializationsLookupTable(
       const NamedDecl *D,
       llvm::SmallVectorImpl<const NamedDecl *> &Specializations);
@@ -542,6 +549,7 @@ class ASTWriter : public ASTDeserializationListener,
   void WriteReferencedSelectorsPool(Sema &SemaRef);
   void WriteIdentifierTable(Preprocessor &PP, IdentifierResolver &IdResolver,
                             bool IsModule);
+  void WriteSpecializationsUpdates();
   void WriteDeclUpdatesBlocks(RecordDataImpl &OffsetsRecord);
   void WriteDeclContextVisibleUpdate(const DeclContext *DC);
   void WriteFPPragmaOptions(const FPOptionsOverride &Opts);
diff --git a/clang/lib/Serialization/ASTCommon.h b/clang/lib/Serialization/ASTCommon.h
index 296642e3674a49..9d190c05062444 100644
--- a/clang/lib/Serialization/ASTCommon.h
+++ b/clang/lib/Serialization/ASTCommon.h
@@ -23,7 +23,7 @@ namespace serialization {
 
 enum DeclUpdateKind {
   UPD_CXX_ADDED_IMPLICIT_MEMBER,
-  UPD_CXX_ADDED_TEMPLATE_SPECIALIZATION,
+  UPD_CXX_ADDED_TEMPLATE_PARTIAL_SPECIALIZATION,
   UPD_CXX_ADDED_ANONYMOUS_NAMESPACE,
   UPD_CXX_ADDED_FUNCTION_DEFINITION,
   UPD_CXX_ADDED_VAR_DEFINITION,
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 2e280c52b6e59c..461bf09d2d16d3 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -1340,10 +1340,17 @@ bool ASTReader::ReadVisibleDeclContextStorage(ModuleFile &M,
   // We can't safely determine the primary context yet, so delay attaching the
   // lookup table until we're done with recursive deserialization.
   auto *Data = (const unsigned char*)Blob.data();
-  PendingVisibleUpdates[ID].push_back(PendingVisibleUpdate{&M, Data});
+  PendingVisibleUpdates[ID].push_back(UpdateData{&M, Data});
   return false;
 }
 
+void ASTReader::AddSpecializations(const Decl *D, const unsigned char *Data,
+                                   ModuleFile &M) {
+  D = D->getCanonicalDecl();
+  SpecializationsLookups[D].Table.add(
+      &M, Data, reader::SpecializationsLookupTrait(*this, M));
+}
+
 bool ASTReader::ReadSpecializations(ModuleFile &M, BitstreamCursor &Cursor,
                                     uint64_t Offset, Decl *D) {
   assert(Offset != 0);
@@ -1375,10 +1382,7 @@ bool ASTReader::ReadSpecializations(ModuleFile &M, BitstreamCursor &Cursor,
   }
 
   auto *Data = (const unsigned char *)Blob.data();
-  D = D->getCanonicalDecl();
-  SpecializationsLookups[D].Table.add(
-      &M, Data, reader::SpecializationsLookupTrait(*this, M));
-
+  AddSpecializations(D, Data, M);
   return false;
 }
 
@@ -3478,7 +3482,20 @@ llvm::Error ASTReader::ReadASTBlock(ModuleFile &F,
       unsigned Idx = 0;
       serialization::DeclID ID = ReadDeclID(F, Record, Idx);
       auto *Data = (const unsigned char*)Blob.data();
-      PendingVisibleUpdates[ID].push_back(PendingVisibleUpdate{&F, Data});
+      PendingVisibleUpdates[ID].push_back(UpdateData{&F, Data});
+      // If we've already loaded the decl, perform the updates when we finish
+      // loading this block.
+      if (Decl *D = GetExistingDecl(ID))
+        PendingUpdateRecords.push_back(
+            PendingUpdateRecord(ID, D, /*JustLoaded=*/false));
+      break;
+    }
+
+    case UPDATE_SPECIALIZATION: {
+      unsigned Idx = 0;
+      serialization::DeclID ID = ReadDeclID(F, Record, Idx);
+      auto *Data = (const unsigned char *)Blob.data();
+      PendingSpecializationsUpdates[ID].push_back(UpdateData{&F, Data});
       // If we've already loaded the decl, perform the updates when we finish
       // loading this block.
       if (Decl *D = GetExistingDecl(ID))
diff --git a/clang/lib/Serialization/ASTReaderDecl.cpp b/clang/lib/Serialization/ASTReaderDecl.cpp
index 7d0ca4bfc2cde8..d30948efb9289b 100644
--- a/clang/lib/Serialization/ASTReaderDecl.cpp
+++ b/clang/lib/Serialization/ASTReaderDecl.cpp
@@ -321,7 +321,9 @@ namespace clang {
     void ReadFunctionDefinition(FunctionDecl *FD);
     void Visit(Decl *D);
 
-    void UpdateDecl(Decl *D, SmallVectorImpl<serialization::DeclID> &);
+    void UpdateDecl(
+        Decl *D,
+        SmallVectorImpl<serialization::DeclID> &UpdatedPartialSpecializations);
 
     static void setNextObjCCategory(ObjCCategoryDecl *Cat,
                                     ObjCCategoryDecl *Next) {
@@ -4196,7 +4198,7 @@ void ASTReader::loadDeclUpdateRecords(PendingUpdateRecord &Record) {
   ProcessingUpdatesRAIIObj ProcessingUpdates(*this);
   DeclUpdateOffsetsMap::iterator UpdI = DeclUpdateOffsets.find(ID);
 
-  SmallVector<serialization::DeclID, 8> PendingLazySpecializationIDs;
+  SmallVector<serialization::DeclID, 8> PendingLazyPartialSpecializationIDs;
 
   if (UpdI != DeclUpdateOffsets.end()) {
     auto UpdateOffsets = std::move(UpdI->second);
@@ -4235,7 +4237,7 @@ void ASTReader::loadDeclUpdateRecords(PendingUpdateRecord &Record) {
 
       ASTDeclReader Reader(*this, Record, RecordLocation(F, Offset), ID,
                            SourceLocation());
-      Reader.UpdateDecl(D, PendingLazySpecializationIDs);
+      Reader.UpdateDecl(D, PendingLazyPartialSpecializationIDs);
 
       // We might have made this declaration interesting. If so, remember that
       // we need to hand it off to the consumer.
@@ -4248,16 +4250,16 @@ void ASTReader::loadDeclUpdateRecords(PendingUpdateRecord &Record) {
     }
   }
   // Add the lazy specializations to the template.
-  assert((PendingLazySpecializationIDs.empty() || isa<ClassTemplateDecl>(D) ||
-          isa<FunctionTemplateDecl, VarTemplateDecl>(D)) &&
+  assert((PendingLazyPartialSpecializationIDs.empty() ||
+          isa<ClassTemplateDecl, VarTemplateDecl>(D)) &&
          "Must not have pending specializations");
   if (auto *CTD = dyn_cast<ClassTemplateDecl>(D))
-    ASTDeclReader::AddLazySpecializations(CTD, PendingLazySpecializationIDs);
-  else if (auto *FTD = dyn_cast<FunctionTemplateDecl>(D))
-    ASTDeclReader::AddLazySpecializations(FTD, PendingLazySpecializationIDs);
+    ASTDeclReader::AddLazySpecializations(CTD,
+                                          PendingLazyPartialSpecializationIDs);
   else if (auto *VTD = dyn_cast<VarTemplateDecl>(D))
-    ASTDeclReader::AddLazySpecializations(VTD, PendingLazySpecializationIDs);
-  PendingLazySpecializationIDs.clear();
+    ASTDeclReader::AddLazySpecializations(VTD,
+                                          PendingLazyPartialSpecializationIDs);
+  PendingLazyPartialSpecializationIDs.clear();
 
   // Load the pending visible updates for this decl context, if it has any.
   auto I = PendingVisibleUpdates.find(ID);
@@ -4272,6 +4274,16 @@ void ASTReader::loadDeclUpdateRecords(PendingUpdateRecord &Record) {
           reader::ASTDeclContextNameLookupTrait(*this, *Update.Mod));
     DC->setHasExternalVisibleStorage(true);
   }
+
+  // Load the pending specializations update for this decl, if it has any.
+  if (auto I = PendingSpecializationsUpdates.find(ID);
+      I != PendingSpecializationsUpdates.end()) {
+    auto SpecializationUpdates = std::move(I->second);
+    PendingSpecializationsUpdates.erase(I);
+
+    for (const auto &Update : SpecializationUpdates)
+      AddSpecializations(D, Update.Data, *Update.Mod);
+  }
 }
 
 void ASTReader::loadPendingDeclChain(Decl *FirstLocal, uint64_t LocalOffset) {
@@ -4466,7 +4478,8 @@ static void forAllLaterRedecls(DeclT *D, Fn F) {
 }
 
 void ASTDeclReader::UpdateDecl(Decl *D,
-   llvm::SmallVectorImpl<serialization::DeclID> &PendingLazySpecializationIDs) {
+                               llvm::SmallVectorImpl<serialization::DeclID>
+                                   &PendingLazyPartialSpecializationIDs) {
   while (Record.getIdx() < Record.size()) {
     switch ((DeclUpdateKind)Record.readInt()) {
     case UPD_CXX_ADDED_IMPLICIT_MEMBER: {
@@ -4477,9 +4490,9 @@ void ASTDeclReader::UpdateDecl(Decl *D,
       break;
     }
 
-    case UPD_CXX_ADDED_TEMPLATE_SPECIALIZATION:
-      // It will be added to the template's lazy specialization set.
-      PendingLazySpecializationIDs.push_back(readDeclID());
+    case UPD_CXX_ADDED_TEMPLATE_PARTIAL_SPECIALIZATION:
+      // It will be added to the template's lazy partial specialization set.
+      PendingLazyPartialSpecializationIDs.push_back(readDeclID());
       break;
 
     case UPD_CXX_ADDED_ANONYMOUS_NAMESPACE: {
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index e24950e2be63b5..90053583ce67fe 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -4029,9 +4029,10 @@ unsigned CalculateODRHashForSpecs(const Decl *Spec) {
 }
 } // namespace
 
-uint64_t ASTWriter::WriteSpecializationsLookupTable(
+void ASTWriter::GenerateSpecializationsLookupTable(
     const NamedDecl *D,
-    llvm::SmallVectorImpl<const NamedDecl *> &Specializations) {
+    llvm::SmallVectorImpl<const NamedDecl *> &Specializations,
+    llvm::SmallVectorImpl<char> &LookupTable) {
   assert(D->isFirstDecl());
 
   // Create the on-disk hash table representation.
@@ -4059,13 +4060,19 @@ uint64_t ASTWriter::WriteSpecializationsLookupTable(
   for (auto Iter : SpecializationMaps)
     Generator.insert(Iter.first, Trait.getData(Iter.second), Trait);
 
-  uint64_t Offset = Stream.GetCurrentBitNo();
-
   auto *Lookups =
       Chain ? Chain->getLoadedSpecializationsLookupTables(D) : nullptr;
-  llvm::SmallString<4096> LookupTable;
   Generator.emit(LookupTable, Trait, Lookups ? &Lookups->Table : nullptr);
+}
+
+uint64_t ASTWriter::WriteSpecializationsLookupTable(
+    const NamedDecl *D,
+    llvm::SmallVectorImpl<const NamedDecl *> &Specializations) {
+
+  llvm::SmallString<4096> LookupTable;
+  GenerateSpecializationsLookupTable(D, Specializations, LookupTable);
 
+  uint64_t Offset = Stream.GetCurrentBitNo();
   RecordData::value_type Record[] = {DECL_SPECIALIZATIONS};
   Stream.EmitRecordWithBlob(DeclSpecializationsAbbrev, Record, LookupTable);
 
@@ -5245,6 +5252,10 @@ ASTFileSignature ASTWriter::WriteASTCore(Sema &SemaRef, StringRef isysroot,
   WriteTypeDeclOffsets();
   if (!DeclUpdatesOffsetsRecord.empty())
     Stream.EmitRecord(DECL_UPDATE_OFFSETS, DeclUpdatesOffsetsRecord);
+
+  if (!SpecializationsUpdates.empty())
+    WriteSpecializationsUpdates();
+
   WriteFileDeclIDsMap();
   WriteSourceManagerBlock(Context.getSourceManager(), PP);
   WriteComments();
@@ -5397,6 +5408,26 @@ ASTFileSignature ASTWriter::WriteASTCore(Sema &SemaRef, StringRef isysroot,
   return backpatchSignature();
 }
 
+void ASTWriter::WriteSpecializationsUpdates() {
+  auto Abv = std::make_shared<llvm::BitCodeAbbrev>();
+  Abv->Add(llvm::BitCodeAbbrevOp(UPDATE_SPECIALIZATION));
+  Abv->Add(llvm::BitCodeAbbrevOp(llvm::BitCodeAbbrevOp::VBR, 6));
+  Abv->Add(llvm::BitCodeAbbrevOp(llvm::BitCodeAbbrevOp::Blob));
+  auto UpdateSpecializationAbbrev = Stream.EmitAbbrev(std::move(Abv));
+
+  for (auto &SpecializationUpdate : SpecializationsUpdates) {
+    const NamedDecl *D = SpecializationUpdate.first;
+
+    llvm::SmallString<4096> LookupTable;
+    GenerateSpecializationsLookupTable(D, SpecializationUpdate.second,
+                                       LookupTable);
+
+    // Write the lookup table
+    RecordData::value_type Record[] = {UPDATE_SPECIALIZATION, getDeclID(D)};
+    Stream.EmitRecordWithBlob(UpdateSpecializationAbbrev, Record, LookupTable);
+  }
+}
+
 void ASTWriter::WriteDeclUpdatesBlocks(RecordDataImpl &OffsetsRecord) {
   if (DeclUpdates.empty())
     return;
@@ -5425,7 +5456,7 @@ void ASTWriter::WriteDeclUpdatesBlocks(RecordDataImpl &OffsetsRecord) {
 
       switch (Kind) {
       case UPD_CXX_ADDED_IMPLICIT_MEMBER:
-      case UPD_CXX_ADDED_TEMPLATE_SPECIALIZATION:
+      case UPD_CXX_ADDED_TEMPLATE_PARTIAL_SPECIALIZATION:
       case UPD_CXX_ADDED_ANONYMOUS_NAMESPACE:
         assert(Update.getDecl() && "no decl to add?");
         Record.push_back(GetDeclRef(Update.getDecl()));
diff --git a/clang/lib/Serialization/ASTWriterDecl.cpp b/clang/lib/Serialization/ASTWriterDecl.cpp
index d508f56618f843..7cd6b8bd406f4f 100644
--- a/clang/lib/Serialization/ASTWriterDecl.cpp
+++ b/clang/lib/Serialization/ASTWriterDecl.cpp
@@ -271,8 +271,13 @@ namespace clang {
       if (Writer.getFirstLocalDecl(Specialization) != Specialization)
         return;
 
-      Writer.DeclUpdates[Template].push_back(ASTWriter::DeclUpdate(
-          UPD_CXX_ADDED_TEMPLATE_SPECIALIZATION, Specialization));
+      if (isa<ClassTemplatePartialSpecializationDecl,
+              VarTemplatePartialSpecializationDecl>(Specialization))
+        Writer.DeclUpdates[Template].push_back(ASTWriter::DeclUpdate(
+            UPD_CXX_ADDED_TEMPLATE_PARTIAL_SPECIALIZATION, Specialization));
+      else
+        Writer.SpecializationsUpdates[cast<NamedDecl>(Template)].push_back(
+            cast<NamedDecl>(Specialization));
     }
   };
 }
diff --git a/clang/test/Modules/cxx-templates.cpp b/clang/test/Modules/cxx-templates.cpp
index b7d5741e69af61..2d285c10ceec59 100644
--- a/clang/test/Modules/cxx-templates.cpp
+++ b/clang/test/Modules/cxx-templates.cpp
@@ -251,7 +251,7 @@ namespace Std {
 
 // CHECK-DUMP:      ClassTemplateDecl {{.*}} <{{.*[/\\]}}cxx-templates-common.h:1:1, {{.*}}>  col:{{.*}} in cxx_templates_common SomeTemplate
 // CHECK-DUMP:        ClassTemplateSpecializationDecl {{.*}} prev {{.*}} SomeTemplate
-// CHECK-DUMP-NEXT:     TemplateArgument type 'char[2]'
+// CHECK-DUMP-NEXT:     TemplateArgument type 'char[1]'
 // CHECK-DUMP:        ClassTemplateSpecializationDecl {{.*}} SomeTemplate definition
 // CHECK-DUMP-NEXT:     DefinitionData
 // CHECK-DUMP-NEXT:       DefaultConstructor
@@ -260,9 +260,9 @@ namespace Std {
 // CHECK-DUMP-NEXT:       CopyAssignment
 // CHECK-DUMP-NEXT:       MoveAssignment
 // CHECK-DUMP-NEXT:       Destructor
-// CHECK-DUMP-NEXT:     TemplateArgument type 'char[2]'
-// CHECK-DUMP:        ClassTemplateSpecializationDecl {{.*}} prev {{.*}} SomeTemplate
 // CHECK-DUMP-NEXT:     TemplateArgument type 'char[1]'
+// CHECK-DUMP:        ClassTemplateSpecializationDecl {{.*}} prev {{.*}} SomeTemplate
+// CHECK-DUMP-NEXT:     TemplateArgument type 'char[2]'
 // CHECK-DUMP:        ClassTemplateSpecializationDecl {{.*}} SomeTemplate definition
 // CHECK-DUMP-NEXT:     DefinitionData
 // CHECK-DUMP-NEXT:       DefaultConstructor
@@ -271,4 +271,5 @@ namespace Std {
 // CHECK-DUMP-NEXT:       CopyAssignment
 // CHECK-DUMP-NEXT:       MoveAssignment
 // CHECK-DUMP-NEXT:       Destructor
-// CHECK-DUMP-NEXT:     TemplateArgument type 'char[1]'
+// CHECK-DUMP-NEXT:     TemplateArgument type 'char[2]'
+
diff --git a/clang/unittests/Serialization/LoadSpecLazily.cpp b/clang/unittests/Serialization/LoadSpecLazily.cpp
index 03f3ff3f786555..5a174f2335a4b9 100644
--- a/clang/unittests/Serialization/LoadSpecLazily.cpp
+++ b/clang/unittests/Serialization/LoadSpecLazily.cpp
@@ -156,4 +156,38 @@ A<int> a;
       "test.cpp"));
 }
 
+TEST_F(LoadSpecLazilyTest, ChainedTest) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+  )cpp");
+
+  GenerateModuleInterface("N", R"cpp(
+export module N;
+export import M;
+export class ShouldNotBeLoaded {};
+
+export class Temp {
+   A<ShouldNotBeLoaded> AS;
+};
+  )cpp");
+
+  const char *test_file_contents = R"cpp(
+import N;
+A<int> a;
+  )cpp";
+  std::string DepArg = "-fprebuilt-module-path=" + TestDir.str().str();
+  EXPECT_TRUE(runToolOnCodeWithArgs(
+      std::make_unique<CheckLoadSpecLazilyAction>("ShouldNotBeLoaded"),
+      test_file_contents,
+      {
+          "-std=c++20",
+          DepArg.c_str(),
+          "-I",
+          TestDir.c_str(),
+      },
+      "test.cpp"));
+}
+
 } // namespace

>From 2d987e3ece979f60390ebac6232f657a0b1aafe3 Mon Sep 17 00:00:00 2001
From: Chuanqi Xu <yedeng.yd at linux.alibaba.com>
Date: Thu, 11 Jan 2024 16:55:42 +0800
Subject: [PATCH 3/4] Avoid calculating hashes for template argument as much as
 possible

Previously, we always tried to load the specializations with the
corresponding arguments before looking up the specialization. This
required hashing the template arguments.

This patch improves on that by trying to load the external
specializations only if we can't find a matching one locally.
---
 clang/include/clang/AST/DeclTemplate.h        |  7 ++-
 clang/include/clang/AST/ExternalASTSource.h   |  2 +-
 .../clang/Sema/MultiplexExternalSemaSource.h  |  2 +-
 clang/include/clang/Serialization/ASTReader.h |  2 +-
 clang/lib/AST/DeclTemplate.cpp                | 60 +++++++++++++------
 clang/lib/AST/ExternalASTSource.cpp           |  4 +-
 .../lib/Sema/MultiplexExternalSemaSource.cpp  |  6 +-
 clang/lib/Serialization/ASTReader.cpp         | 14 ++++-
 8 files changed, 68 insertions(+), 29 deletions(-)

diff --git a/clang/include/clang/AST/DeclTemplate.h b/clang/include/clang/AST/DeclTemplate.h
index 4699dd17bc182c..515c60e51e917e 100644
--- a/clang/include/clang/AST/DeclTemplate.h
+++ b/clang/include/clang/AST/DeclTemplate.h
@@ -799,7 +799,12 @@ class RedeclarableTemplateDecl : public TemplateDecl,
   findSpecializationImpl(llvm::FoldingSetVector<EntryType> &Specs,
                          void *&InsertPos, ProfileArguments &&...ProfileArgs);
 
-  void loadLazySpecializationsWithArgs(ArrayRef<TemplateArgument> TemplateArgs);
+  template <class EntryType, typename... ProfileArguments>
+  typename SpecEntryTraits<EntryType>::DeclType *
+  findLocalSpecialization(llvm::FoldingSetVector<EntryType> &Specs,
+                          void *&InsertPos, ProfileArguments &&... ProfileArgs);
+
+  bool loadLazySpecializationsWithArgs(ArrayRef<TemplateArgument> TemplateArgs);
 
   template <class Derived, class EntryType>
   void addSpecializationImpl(llvm::FoldingSetVector<EntryType> &Specs,
diff --git a/clang/include/clang/AST/ExternalASTSource.h b/clang/include/clang/AST/ExternalASTSource.h
index 75ee6f827080d4..036247d87bc955 100644
--- a/clang/include/clang/AST/ExternalASTSource.h
+++ b/clang/include/clang/AST/ExternalASTSource.h
@@ -152,7 +152,7 @@ class ExternalASTSource : public RefCountedBase<ExternalASTSource> {
 
   /// Load all the external specialzations for the Decl and the corresponding
   /// template arguments.
-  virtual void
+  virtual bool
   LoadExternalSpecializations(const Decl *D,
                               ArrayRef<TemplateArgument> TemplateArgs);
 
diff --git a/clang/include/clang/Sema/MultiplexExternalSemaSource.h b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
index 9345e96e8cb8c3..24df0f73d518af 100644
--- a/clang/include/clang/Sema/MultiplexExternalSemaSource.h
+++ b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
@@ -99,7 +99,7 @@ class MultiplexExternalSemaSource : public ExternalSemaSource {
 
   /// Load all the external specialzations for the Decl and the corresponding
   /// template args.
-  virtual void
+  virtual bool
   LoadExternalSpecializations(const Decl *D,
                               ArrayRef<TemplateArgument> TemplateArgs) override;
 
diff --git a/clang/include/clang/Serialization/ASTReader.h b/clang/include/clang/Serialization/ASTReader.h
index 10726b440de515..7c00ddfe024fa6 100644
--- a/clang/include/clang/Serialization/ASTReader.h
+++ b/clang/include/clang/Serialization/ASTReader.h
@@ -2004,7 +2004,7 @@ class ASTReader
   bool FindExternalVisibleDeclsByName(const DeclContext *DC,
                                       DeclarationName Name) override;
 
-  void
+  bool
   LoadExternalSpecializations(const Decl *D,
                               ArrayRef<TemplateArgument> TemplateArgs) override;
 
diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp
index 16b396a5e785d7..d62f8db4fde885 100644
--- a/clang/lib/AST/DeclTemplate.cpp
+++ b/clang/lib/AST/DeclTemplate.cpp
@@ -344,28 +344,63 @@ void RedeclarableTemplateDecl::loadExternalSpecializations() const {
   }
 }
 
-template<class EntryType, typename... ProfileArguments>
+template <class EntryType, typename... ProfileArguments>
 typename RedeclarableTemplateDecl::SpecEntryTraits<EntryType>::DeclType *
-RedeclarableTemplateDecl::findSpecializationImpl(
+RedeclarableTemplateDecl::findLocalSpecialization(
     llvm::FoldingSetVector<EntryType> &Specs, void *&InsertPos,
-    ProfileArguments&&... ProfileArgs) {
+    ProfileArguments &&... ProfileArgs) {
   using SETraits = SpecEntryTraits<EntryType>;
 
   llvm::FoldingSetNodeID ID;
   EntryType::Profile(ID, std::forward<ProfileArguments>(ProfileArgs)...,
                      getASTContext());
   EntryType *Entry = Specs.FindNodeOrInsertPos(ID, InsertPos);
+
   return Entry ? SETraits::getDecl(Entry)->getMostRecentDecl() : nullptr;
 }
 
-void RedeclarableTemplateDecl::loadLazySpecializationsWithArgs(
+template <class EntryType, typename... ProfileArguments>
+typename RedeclarableTemplateDecl::SpecEntryTraits<EntryType>::DeclType *
+RedeclarableTemplateDecl::findSpecializationImpl(
+    llvm::FoldingSetVector<EntryType> &Specs, void *&InsertPos,
+    ProfileArguments &&... ProfileArgs) {
+  if (auto *Ret = findLocalSpecialization(
+          Specs, InsertPos, std::forward<ProfileArguments>(ProfileArgs)...))
+    return Ret;
+
+  // If it is partial specialization, we are done.
+  if constexpr (std::is_same_v<EntryType,
+                               ClassTemplatePartialSpecializationDecl> ||
+                std::is_same_v<EntryType, VarTemplatePartialSpecializationDecl>)
+    return nullptr;
+  else {
+    // If we don't find the specialization, try to load it from the external
+    // sources.
+    static_assert(
+        !std::is_same_v<EntryType, ClassTemplatePartialSpecializationDecl> &&
+        !std::is_same_v<EntryType, VarTemplatePartialSpecializationDecl>);
+    static_assert(sizeof...(ProfileArguments) == 1);
+    using FirstArgument =
+        std::tuple_element_t<0, std::tuple<ProfileArguments...>>;
+    static_assert(std::is_same_v<std::remove_reference_t<FirstArgument>,
+                                 ArrayRef<TemplateArgument>>);
+    if (!loadLazySpecializationsWithArgs(
+            std::forward<ProfileArguments>(ProfileArgs)...))
+      return nullptr;
+
+    return findLocalSpecialization(
+        Specs, InsertPos, std::forward<ProfileArguments>(ProfileArgs)...);
+  }
+}
+
+bool RedeclarableTemplateDecl::loadLazySpecializationsWithArgs(
     ArrayRef<TemplateArgument> TemplateArgs) {
   auto *ExternalSource = getASTContext().getExternalSource();
   if (!ExternalSource)
-    return;
+    return false;
 
-  ExternalSource->LoadExternalSpecializations(this->getCanonicalDecl(),
-                                              TemplateArgs);
+  return ExternalSource->LoadExternalSpecializations(this->getCanonicalDecl(),
+                                                     TemplateArgs);
 }
 
 template<class Derived, class EntryType>
@@ -449,14 +484,11 @@ FunctionTemplateDecl::getSpecializations() const {
 FunctionDecl *
 FunctionTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                          void *&InsertPos) {
-  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getSpecializations(), InsertPos, Args);
 }
 
 void FunctionTemplateDecl::addSpecialization(
       FunctionTemplateSpecializationInfo *Info, void *InsertPos) {
-  using SETraits = SpecEntryTraits<FunctionTemplateSpecializationInfo>;
-  loadLazySpecializationsWithArgs(SETraits::getTemplateArgs(Info));
   addSpecializationImpl<FunctionTemplateDecl>(getSpecializations(), Info,
                                               InsertPos);
 }
@@ -539,14 +571,11 @@ ClassTemplateDecl::newCommon(ASTContext &C) const {
 ClassTemplateSpecializationDecl *
 ClassTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                       void *&InsertPos) {
-  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getSpecializations(), InsertPos, Args);
 }
 
 void ClassTemplateDecl::AddSpecialization(ClassTemplateSpecializationDecl *D,
                                           void *InsertPos) {
-  using SETraits = SpecEntryTraits<ClassTemplateSpecializationDecl>;
-  loadLazySpecializationsWithArgs(SETraits::getTemplateArgs(D));
   addSpecializationImpl<ClassTemplateDecl>(getSpecializations(), D, InsertPos);
 }
 
@@ -554,7 +583,6 @@ ClassTemplatePartialSpecializationDecl *
 ClassTemplateDecl::findPartialSpecialization(
     ArrayRef<TemplateArgument> Args,
     TemplateParameterList *TPL, void *&InsertPos) {
-  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getPartialSpecializations(), InsertPos, Args,
                                 TPL);
 }
@@ -1268,21 +1296,17 @@ VarTemplateDecl::newCommon(ASTContext &C) const {
 VarTemplateSpecializationDecl *
 VarTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                     void *&InsertPos) {
-  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getSpecializations(), InsertPos, Args);
 }
 
 void VarTemplateDecl::AddSpecialization(VarTemplateSpecializationDecl *D,
                                         void *InsertPos) {
-  using SETraits = SpecEntryTraits<VarTemplateSpecializationDecl>;
-  loadLazySpecializationsWithArgs(SETraits::getTemplateArgs(D));
   addSpecializationImpl<VarTemplateDecl>(getSpecializations(), D, InsertPos);
 }
 
 VarTemplatePartialSpecializationDecl *
 VarTemplateDecl::findPartialSpecialization(ArrayRef<TemplateArgument> Args,
      TemplateParameterList *TPL, void *&InsertPos) {
-  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getPartialSpecializations(), InsertPos, Args,
                                 TPL);
 }
diff --git a/clang/lib/AST/ExternalASTSource.cpp b/clang/lib/AST/ExternalASTSource.cpp
index bba73fcbd8fa23..05bdf0f1b22790 100644
--- a/clang/lib/AST/ExternalASTSource.cpp
+++ b/clang/lib/AST/ExternalASTSource.cpp
@@ -100,9 +100,9 @@ ExternalASTSource::FindExternalVisibleDeclsByName(const DeclContext *DC,
   return false;
 }
 
-void ExternalASTSource::LoadExternalSpecializations(
+bool ExternalASTSource::LoadExternalSpecializations(
     const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
-  return;
+  return false;
 }
 
 void ExternalASTSource::completeVisibleDeclsMap(const DeclContext *DC) {}
diff --git a/clang/lib/Sema/MultiplexExternalSemaSource.cpp b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
index b845383675ab36..7bdd54d5c986bf 100644
--- a/clang/lib/Sema/MultiplexExternalSemaSource.cpp
+++ b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
@@ -115,10 +115,12 @@ FindExternalVisibleDeclsByName(const DeclContext *DC, DeclarationName Name) {
   return AnyDeclsFound;
 }
 
-void MultiplexExternalSemaSource::LoadExternalSpecializations(
+bool MultiplexExternalSemaSource::LoadExternalSpecializations(
     const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  bool LoadedAnyDecls = false;
   for (size_t i = 0; i < Sources.size(); ++i)
-    Sources[i]->LoadExternalSpecializations(D, TemplateArgs);
+    LoadedAnyDecls |= Sources[i]->LoadExternalSpecializations(D, TemplateArgs);
+  return LoadedAnyDecls;
 }
 
 void MultiplexExternalSemaSource::completeVisibleDeclsMap(const DeclContext *DC){
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 461bf09d2d16d3..59132388e910a5 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -8088,19 +8088,27 @@ clang::GetTemplateArgsStableHash(ArrayRef<TemplateArgument> TemplateArgs) {
   return Hasher.CalculateHash();
 }
 
-void ASTReader::LoadExternalSpecializations(
+bool ASTReader::LoadExternalSpecializations(
     const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
   assert(D);
 
   auto It = SpecializationsLookups.find(D);
   if (It == SpecializationsLookups.end())
-    return;
+    return false;
 
   auto HashValue = GetTemplateArgsStableHash(TemplateArgs);
 
+  bool LoadedAnyDecls = false;
+
   Deserializing LookupResults(this);
-  for (DeclID ID : It->second.Table.find(HashValue))
+  for (DeclID ID : It->second.Table.find(HashValue)) {
+    if (GetExistingDecl(ID))
+      continue;
+    LoadedAnyDecls = true;
     GetDecl(ID);
+  }
+
+  return LoadedAnyDecls;
 }
 
 void ASTReader::completeVisibleDeclsMap(const DeclContext *DC) {

>From 70c788523d6676f0b14c0ffcc03dcdaacdb57766 Mon Sep 17 00:00:00 2001
From: Chuanqi Xu <yedeng.yd at linux.alibaba.com>
Date: Mon, 22 Jan 2024 17:44:55 +0800
Subject: [PATCH 4/4] Introduce -fno-load-external-specializations-lazily

---
 clang/include/clang/AST/ExternalASTSource.h   |  3 +
 clang/include/clang/Basic/LangOptions.def     |  1 +
 clang/include/clang/Driver/Options.td         |  8 ++
 .../clang/Sema/MultiplexExternalSemaSource.h  |  2 +
 clang/include/clang/Serialization/ASTReader.h |  2 +
 clang/lib/AST/DeclTemplate.cpp                |  4 +
 clang/lib/AST/ExternalASTSource.cpp           |  2 +
 clang/lib/Driver/ToolChains/Clang.cpp         |  4 +
 .../lib/Sema/MultiplexExternalSemaSource.cpp  |  5 ++
 clang/lib/Serialization/ASTReader.cpp         | 23 +++++-
 .../Serialization/LoadSpecLazily.cpp          | 79 ++++++++++++++++---
 11 files changed, 120 insertions(+), 13 deletions(-)

diff --git a/clang/include/clang/AST/ExternalASTSource.h b/clang/include/clang/AST/ExternalASTSource.h
index 036247d87bc955..e9bf592b10b8f0 100644
--- a/clang/include/clang/AST/ExternalASTSource.h
+++ b/clang/include/clang/AST/ExternalASTSource.h
@@ -156,6 +156,9 @@ class ExternalASTSource : public RefCountedBase<ExternalASTSource> {
   LoadExternalSpecializations(const Decl *D,
                               ArrayRef<TemplateArgument> TemplateArgs);
 
+  /// Load all the external specializations for the Decl D.
+  virtual void LoadAllExternalSpecializations(const Decl *D);
+
   /// Ensures that the table of all visible declarations inside this
   /// context is up to date.
   ///
diff --git a/clang/include/clang/Basic/LangOptions.def b/clang/include/clang/Basic/LangOptions.def
index f181da9aea6f97..2c020ac1941417 100644
--- a/clang/include/clang/Basic/LangOptions.def
+++ b/clang/include/clang/Basic/LangOptions.def
@@ -177,6 +177,7 @@ COMPATIBLE_LANGOPT(CPlusPlusModules, 1, 0, "C++ modules syntax")
 LANGOPT(BuiltinHeadersInSystemModules, 1, 0, "builtin headers belong to system modules, and _Builtin_ modules are ignored for cstdlib headers")
 BENIGN_ENUM_LANGOPT(CompilingModule, CompilingModuleKind, 3, CMK_None,
                     "compiling a module interface")
+COMPATIBLE_LANGOPT(LoadExternalSpecializationsLazily, 1, 1, "Load External Specializations Lazily")
 BENIGN_LANGOPT(CompilingPCH, 1, 0, "building a pch")
 BENIGN_LANGOPT(BuildingPCHWithObjectFile, 1, 0, "building a pch which has a corresponding object file")
 BENIGN_LANGOPT(CacheGeneratedPCH, 1, 0, "cache generated PCH files in memory")
diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index 199537404d0c63..a733812facfccb 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -2978,6 +2978,14 @@ defm prebuilt_implicit_modules : BoolFOption<"prebuilt-implicit-modules",
   PosFlag<SetTrue, [], [ClangOption], "Look up implicit modules in the prebuilt module path">,
   NegFlag<SetFalse>, BothFlags<[], [ClangOption, CC1Option]>>;
 
+defm load_external_specializations_lazily : BoolOption<"f", "load-external-specializations-lazily",
+  LangOpts<"LoadExternalSpecializationsLazily">, DefaultTrue,
+  PosFlag<SetTrue, [], []>,
+  NegFlag<SetFalse, [], [ClangOption, CC1Option],
+          "Load All External Specialization at once when required. "
+          "This is used when testing the functionality of LoadExternalSpecializationsLazily">>,
+  Group<f_clang_Group>;
+
 def fmodule_output_EQ : Joined<["-"], "fmodule-output=">,
   Flags<[NoXarchOption]>, Visibility<[ClangOption, CC1Option]>,
   HelpText<"Save intermediate module file results when compiling a standard C++ module unit.">;
diff --git a/clang/include/clang/Sema/MultiplexExternalSemaSource.h b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
index 24df0f73d518af..7ed04a8c265e61 100644
--- a/clang/include/clang/Sema/MultiplexExternalSemaSource.h
+++ b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
@@ -103,6 +103,8 @@ class MultiplexExternalSemaSource : public ExternalSemaSource {
   LoadExternalSpecializations(const Decl *D,
                               ArrayRef<TemplateArgument> TemplateArgs) override;
 
+  void LoadAllExternalSpecializations(const Decl *D) override;
+
   /// Ensures that the table of all visible declarations inside this
   /// context is up to date.
   void completeVisibleDeclsMap(const DeclContext *DC) override;
diff --git a/clang/include/clang/Serialization/ASTReader.h b/clang/include/clang/Serialization/ASTReader.h
index 7c00ddfe024fa6..d722d57be5511b 100644
--- a/clang/include/clang/Serialization/ASTReader.h
+++ b/clang/include/clang/Serialization/ASTReader.h
@@ -2007,6 +2007,8 @@ class ASTReader
   bool
   LoadExternalSpecializations(const Decl *D,
                               ArrayRef<TemplateArgument> TemplateArgs) override;
+  
+  void LoadAllExternalSpecializations(const Decl *D) override;
 
   /// Read all of the declarations lexically stored in a
   /// declaration context.
diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp
index d62f8db4fde885..4643fc39dcb5fc 100644
--- a/clang/lib/AST/DeclTemplate.cpp
+++ b/clang/lib/AST/DeclTemplate.cpp
@@ -342,6 +342,10 @@ void RedeclarableTemplateDecl::loadExternalSpecializations() const {
     for (uint32_t I = 0, N = *Specs++; I != N; ++I)
       (void)Context.getExternalSource()->GetExternalDecl(Specs[I]);
   }
+
+  if (!getASTContext().getLangOpts().LoadExternalSpecializationsLazily &&
+      getASTContext().getExternalSource())
+    getASTContext().getExternalSource()->LoadAllExternalSpecializations(this->getCanonicalDecl());
 }
 
 template <class EntryType, typename... ProfileArguments>
diff --git a/clang/lib/AST/ExternalASTSource.cpp b/clang/lib/AST/ExternalASTSource.cpp
index 05bdf0f1b22790..c0dab179dc8a52 100644
--- a/clang/lib/AST/ExternalASTSource.cpp
+++ b/clang/lib/AST/ExternalASTSource.cpp
@@ -105,6 +105,8 @@ bool ExternalASTSource::LoadExternalSpecializations(
   return false;
 }
 
+void ExternalASTSource::LoadAllExternalSpecializations(const Decl *D) {}
+
 void ExternalASTSource::completeVisibleDeclsMap(const DeclContext *DC) {}
 
 void ExternalASTSource::FindExternalLexicalDecls(
diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp
index fead2e884030e2..153e5770504c94 100644
--- a/clang/lib/Driver/ToolChains/Clang.cpp
+++ b/clang/lib/Driver/ToolChains/Clang.cpp
@@ -3930,6 +3930,10 @@ static bool RenderModulesOptions(Compilation &C, const Driver &D,
     Args.ClaimAllArgs(options::OPT_fmodules_disable_diagnostic_validation);
   }
 
+  if (Args.hasFlag(options::OPT_fno_load_external_specializations_lazily,
+                   options::OPT_fload_external_specializations_lazily, false))
+    CmdArgs.push_back("-fno-load-external-specializations-lazily");
+
   // Claim `-fmodule-output` and `-fmodule-output=` to avoid unused warnings.
   Args.ClaimAllArgs(options::OPT_fmodule_output);
   Args.ClaimAllArgs(options::OPT_fmodule_output_EQ);
diff --git a/clang/lib/Sema/MultiplexExternalSemaSource.cpp b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
index 7bdd54d5c986bf..28be3e9e265d94 100644
--- a/clang/lib/Sema/MultiplexExternalSemaSource.cpp
+++ b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
@@ -123,6 +123,11 @@ bool MultiplexExternalSemaSource::LoadExternalSpecializations(
   return LoadedAnyDecls;
 }
 
+void MultiplexExternalSemaSource::LoadAllExternalSpecializations(const Decl *D) {
+  for (size_t i = 0; i < Sources.size(); ++i)
+    Sources[i]->LoadAllExternalSpecializations(D);
+}
+
 void MultiplexExternalSemaSource::completeVisibleDeclsMap(const DeclContext *DC){
   for(size_t i = 0; i < Sources.size(); ++i)
     Sources[i]->completeVisibleDeclsMap(DC);
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 59132388e910a5..ffb6682808d9dd 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -8095,12 +8095,11 @@ bool ASTReader::LoadExternalSpecializations(
   auto It = SpecializationsLookups.find(D);
   if (It == SpecializationsLookups.end())
     return false;
-
+  
+  Deserializing LookupResults(this);
   auto HashValue = GetTemplateArgsStableHash(TemplateArgs);
 
   bool LoadedAnyDecls = false;
-
-  Deserializing LookupResults(this);
   for (DeclID ID : It->second.Table.find(HashValue)) {
     if (GetExistingDecl(ID))
       continue;
@@ -8111,6 +8110,24 @@ bool ASTReader::LoadExternalSpecializations(
   return LoadedAnyDecls;
 }
 
+void ASTReader::LoadAllExternalSpecializations(
+    const Decl *D) {
+  assert(D);
+
+  auto It = SpecializationsLookups.find(D);
+  if (It == SpecializationsLookups.end())
+    return;
+  
+  Deserializing LookupResults(this);
+
+  for (DeclID ID : It->second.Table.findAll())
+    GetDecl(ID);
+  
+  // Since we've loaded all the specializations, we can erase the entry for
+  // this declaration from the lookup table.
+  SpecializationsLookups.erase(It);
+}
+
 void ASTReader::completeVisibleDeclsMap(const DeclContext *DC) {
   if (!DC->hasExternalVisibleStorage())
     return;
diff --git a/clang/unittests/Serialization/LoadSpecLazily.cpp b/clang/unittests/Serialization/LoadSpecLazily.cpp
index 5a174f2335a4b9..2c0285aa0488e7 100644
--- a/clang/unittests/Serialization/LoadSpecLazily.cpp
+++ b/clang/unittests/Serialization/LoadSpecLazily.cpp
@@ -87,26 +87,47 @@ class LoadSpecLazilyTest : public ::testing::Test {
   }
 };
 
+enum class CheckingMode {
+  Forbidden,
+  Required
+};
+
 class DeclsReaderListener : public ASTDeserializationListener {
+  StringRef SpeficiedName;
+  CheckingMode Mode;
+
+  bool ReadedSpecifiedName = false;
 public:
   void DeclRead(serialization::DeclID ID, const Decl *D) override {
     auto *ND = dyn_cast<NamedDecl>(D);
     if (!ND)
       return;
 
-    EXPECT_FALSE(ND->getName().contains(ForbiddenName));
+    ReadedSpecifiedName |= ND->getName().contains(SpeficiedName);
+    if (Mode == CheckingMode::Forbidden) {
+      EXPECT_FALSE(ReadedSpecifiedName);
+    }
   }
 
-  DeclsReaderListener(StringRef ForbiddenName) : ForbiddenName(ForbiddenName) {}
+  DeclsReaderListener(StringRef SpeficiedName, CheckingMode Mode)
+                      : SpeficiedName(SpeficiedName),
+                        Mode(Mode) {}
+
+  ~DeclsReaderListener() {
+    if (Mode == CheckingMode::Required) {
+      EXPECT_TRUE(ReadedSpecifiedName);
+    }
+  }
 
-  StringRef ForbiddenName;
 };
 
 class LoadSpecLazilyConsumer : public ASTConsumer {
   DeclsReaderListener Listener;
+  CheckingMode Mode;
 
 public:
-  LoadSpecLazilyConsumer(StringRef ForbiddenName) : Listener(ForbiddenName) {}
+  LoadSpecLazilyConsumer(StringRef SpecifiedName, CheckingMode Mode)
+    : Listener(SpecifiedName, Mode) {}
 
   ASTDeserializationListener *GetASTDeserializationListener() override {
     return &Listener;
@@ -114,16 +135,17 @@ class LoadSpecLazilyConsumer : public ASTConsumer {
 };
 
 class CheckLoadSpecLazilyAction : public ASTFrontendAction {
-  StringRef ForbiddenName;
+  StringRef SpecifiedName;
+  CheckingMode Mode;
 
 public:
   std::unique_ptr<ASTConsumer>
   CreateASTConsumer(CompilerInstance &CI, StringRef /*Unused*/) override {
-    return std::make_unique<LoadSpecLazilyConsumer>(ForbiddenName);
+    return std::make_unique<LoadSpecLazilyConsumer>(SpecifiedName, Mode);
   }
 
-  CheckLoadSpecLazilyAction(StringRef ForbiddenName)
-      : ForbiddenName(ForbiddenName) {}
+  CheckLoadSpecLazilyAction(StringRef SpecifiedName, CheckingMode Mode)
+      : SpecifiedName(SpecifiedName), Mode(Mode) {}
 };
 
 TEST_F(LoadSpecLazilyTest, BasicTest) {
@@ -145,7 +167,7 @@ A<int> a;
   )cpp";
   std::string DepArg = "-fprebuilt-module-path=" + TestDir.str().str();
   EXPECT_TRUE(runToolOnCodeWithArgs(
-      std::make_unique<CheckLoadSpecLazilyAction>("ShouldNotBeLoaded"),
+      std::make_unique<CheckLoadSpecLazilyAction>("ShouldNotBeLoaded", CheckingMode::Forbidden),
       test_file_contents,
       {
           "-std=c++20",
@@ -179,10 +201,47 @@ A<int> a;
   )cpp";
   std::string DepArg = "-fprebuilt-module-path=" + TestDir.str().str();
   EXPECT_TRUE(runToolOnCodeWithArgs(
-      std::make_unique<CheckLoadSpecLazilyAction>("ShouldNotBeLoaded"),
+      std::make_unique<CheckLoadSpecLazilyAction>("ShouldNotBeLoaded", CheckingMode::Forbidden),
+      test_file_contents,
+      {
+          "-std=c++20",
+          DepArg.c_str(),
+          "-I",
+          TestDir.c_str(),
+      },
+      "test.cpp"));
+}
+
+// Check that the option '-fno-load-external-specializations-lazily' loads all
+// specializations as expected.
+TEST_F(LoadSpecLazilyTest, LoadAllTest) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+  )cpp");
+
+  GenerateModuleInterface("N", R"cpp(
+export module N;
+export import M;
+export class ShouldBeLoaded {};
+
+export class Temp {
+   A<ShouldBeLoaded> AS;
+};
+  )cpp");
+
+  const char *test_file_contents = R"cpp(
+import N;
+A<int> a;
+  )cpp";
+  std::string DepArg = "-fprebuilt-module-path=" + TestDir.str().str();
+  EXPECT_TRUE(runToolOnCodeWithArgs(
+      std::make_unique<CheckLoadSpecLazilyAction>("ShouldBeLoaded", CheckingMode::Required),
       test_file_contents,
       {
           "-std=c++20",
+          "-fno-load-external-specializations-lazily",
           DepArg.c_str(),
           "-I",
           TestDir.c_str(),



More information about the cfe-commits mailing list