[clang] b5bd192 - [Serialization] Support load lazy specialization lazily

Chuanqi Xu via cfe-commits cfe-commits at lists.llvm.org
Thu Dec 5 18:54:26 PST 2024


Author: Chuanqi Xu
Date: 2024-12-06T10:52:35+08:00
New Revision: b5bd19211118c6d43bc525a4e3fb65d2c750d61e

URL: https://github.com/llvm/llvm-project/commit/b5bd19211118c6d43bc525a4e3fb65d2c750d61e
DIFF: https://github.com/llvm/llvm-project/commit/b5bd19211118c6d43bc525a4e3fb65d2c750d61e.diff

LOG: [Serialization] Support load lazy specialization lazily

Currently, all specializations of a template (including instantiations,
explicit specializations, and partial specializations) are loaded at
once whenever we instantiate the template with new arguments, look up
an existing instantiation, or just complete the redecl chain.

In effect, every specialization of a template has to be loaded as soon
as the template declaration itself is loaded. This is costly: loading a
specialization also loads all of its template arguments, so we end up
deserializing many declarations that are never needed.

For example,

```
// M.cppm
export module M;
export template <class T>
class A {};

export class ShouldNotBeLoaded {};

export class Temp {
   A<ShouldNotBeLoaded> AS;
};

// use.cpp
import M;
A<int> a;
```

We have a specialization `A<ShouldNotBeLoaded>` in `M.cppm`, and we
instantiate the template `A` in `use.cpp`. Surprisingly, compiling
`use.cpp` then deserializes `ShouldNotBeLoaded` as well. This patch
tries to avoid that.

Given how heavily templates are used in C++, this is a performance pain
point.

This patch adds a MultiOnDiskHashTable for specializations to the
ASTReader, so that we only deserialize the specializations with the
same template arguments. We do this by using the ODRHash of the
template arguments as the key of the hash table.
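
The idea can be sketched outside of Clang as a small in-memory model.
This is only an illustration under stated assumptions, not the code in
this patch: `LazySpecializationTable`, `hashArgs`, `ArgList` and
`DeclID` are hypothetical names, template arguments are modeled as
plain strings, and an FNV-1a hash stands in for ODRHash. A lookup first
checks what is already loaded, then loads only the entries stored under
the hash of the requested arguments, then checks again.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-ins; none of these names come from the patch.
using ArgList = std::vector<std::string>; // template args spelled as text
using DeclID = std::uint32_t;             // 0 is reserved for "not found"

// Toy replacement for ODRHash over the template arguments (FNV-1a).
inline std::uint32_t hashArgs(const ArgList &Args) {
  std::uint32_t H = 2166136261u;
  for (const std::string &A : Args) {
    for (unsigned char C : A) {
      H ^= C;
      H *= 16777619u;
    }
    H ^= 0xFFu; // separator so that {"ab"} and {"a", "b"} hash differently
    H *= 16777619u;
  }
  return H;
}

class LazySpecializationTable {
public:
  // Writer side: record a specialization under the hash of its arguments.
  void addSerialized(DeclID ID, ArgList Args) {
    assert(ID != 0 && "0 is reserved for 'not found'");
    OnDiskIndex[hashArgs(Args)].push_back(ID);
    Serialized.emplace(ID, std::move(Args));
  }

  // Reader side: find locally, otherwise load by hash and retry.
  DeclID find(const ArgList &Args) {
    if (DeclID Found = findLocally(Args))
      return Found;
    if (!loadByArgs(Args))
      return 0;
    return findLocally(Args);
  }

  bool isLoaded(DeclID ID) const { return Loaded.count(ID) != 0; }

private:
  // "Deserialize" only the entries whose key equals the hash of Args.
  // A hash collision may load an unrelated entry; that is harmless
  // because findLocally() still compares the real arguments.
  bool loadByArgs(const ArgList &Args) {
    auto It = OnDiskIndex.find(hashArgs(Args));
    if (It == OnDiskIndex.end())
      return false;
    bool LoadedNew = false;
    for (DeclID ID : It->second)
      LoadedNew |= Loaded.emplace(ID, Serialized.at(ID)).second;
    return LoadedNew;
  }

  DeclID findLocally(const ArgList &Args) const {
    for (const auto &Entry : Loaded)
      if (Entry.second == Args)
        return Entry.first;
    return 0;
  }

  std::unordered_map<std::uint32_t, std::vector<DeclID>> OnDiskIndex;
  std::unordered_map<DeclID, ArgList> Serialized; // still "on disk"
  std::unordered_map<DeclID, ArgList> Loaded;     // already deserialized
};
```

In this model, registering `A<ShouldNotBeLoaded>` and `A<int>` and then
calling `find({"int"})` loads only the entries that share the hash of
`int`, so the other specialization stays unloaded unless the two hashes
happen to collide.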

To review this patch, I think `ASTReaderDecl::AddLazySpecializations`
may be a good entry point.

The patch was reviewed in
https://github.com/llvm/llvm-project/pull/83237, but that PR is a
stacked PR, and I feel the intention of the stacked PRs got lost during
the review process. So I think it is better to land the commits as a
single commit instead of merging them on the PR page. A single commit
is also easier for us to cherry-pick and revert.

Added: 
    clang/lib/Serialization/TemplateArgumentHasher.cpp
    clang/lib/Serialization/TemplateArgumentHasher.h
    clang/test/Modules/recursive-instantiations.cppm
    clang/unittests/Serialization/LoadSpecLazilyTest.cpp

Modified: 
    clang/include/clang/AST/DeclTemplate.h
    clang/include/clang/AST/ExternalASTSource.h
    clang/include/clang/Sema/MultiplexExternalSemaSource.h
    clang/include/clang/Serialization/ASTBitCodes.h
    clang/include/clang/Serialization/ASTReader.h
    clang/include/clang/Serialization/ASTWriter.h
    clang/lib/AST/DeclTemplate.cpp
    clang/lib/AST/ExternalASTSource.cpp
    clang/lib/AST/ODRHash.cpp
    clang/lib/Sema/MultiplexExternalSemaSource.cpp
    clang/lib/Serialization/ASTCommon.h
    clang/lib/Serialization/ASTReader.cpp
    clang/lib/Serialization/ASTReaderDecl.cpp
    clang/lib/Serialization/ASTReaderInternals.h
    clang/lib/Serialization/ASTWriter.cpp
    clang/lib/Serialization/ASTWriterDecl.cpp
    clang/lib/Serialization/CMakeLists.txt
    clang/test/Modules/odr_hash.cpp
    clang/test/OpenMP/target_parallel_ast_print.cpp
    clang/test/OpenMP/target_teams_ast_print.cpp
    clang/test/OpenMP/task_ast_print.cpp
    clang/test/OpenMP/teams_ast_print.cpp
    clang/unittests/Serialization/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/AST/DeclTemplate.h b/clang/include/clang/AST/DeclTemplate.h
index e4bf54c3d77b7f..dd92d40b804232 100644
--- a/clang/include/clang/AST/DeclTemplate.h
+++ b/clang/include/clang/AST/DeclTemplate.h
@@ -735,6 +735,7 @@ class RedeclarableTemplateDecl : public TemplateDecl,
   }
 
   void anchor() override;
+
 protected:
   template <typename EntryType> struct SpecEntryTraits {
     using DeclType = EntryType;
@@ -775,13 +776,22 @@ class RedeclarableTemplateDecl : public TemplateDecl,
     return SpecIterator<EntryType>(isEnd ? Specs.end() : Specs.begin());
   }
 
-  void loadLazySpecializationsImpl() const;
+  void loadLazySpecializationsImpl(bool OnlyPartial = false) const;
+
+  bool loadLazySpecializationsImpl(llvm::ArrayRef<TemplateArgument> Args,
+                                   TemplateParameterList *TPL = nullptr) const;
 
   template <class EntryType, typename ...ProfileArguments>
   typename SpecEntryTraits<EntryType>::DeclType*
   findSpecializationImpl(llvm::FoldingSetVector<EntryType> &Specs,
                          void *&InsertPos, ProfileArguments &&...ProfileArgs);
 
+  template <class EntryType, typename... ProfileArguments>
+  typename SpecEntryTraits<EntryType>::DeclType *
+  findSpecializationLocally(llvm::FoldingSetVector<EntryType> &Specs,
+                            void *&InsertPos,
+                            ProfileArguments &&...ProfileArgs);
+
   template <class Derived, class EntryType>
   void addSpecializationImpl(llvm::FoldingSetVector<EntryType> &Specs,
                              EntryType *Entry, void *InsertPos);
@@ -796,13 +806,6 @@ class RedeclarableTemplateDecl : public TemplateDecl,
     /// was explicitly specialized.
     llvm::PointerIntPair<RedeclarableTemplateDecl *, 1, bool>
         InstantiatedFromMember;
-
-    /// If non-null, points to an array of specializations (including
-    /// partial specializations) known only by their external declaration IDs.
-    ///
-    /// The first value in the array is the number of specializations/partial
-    /// specializations that follow.
-    GlobalDeclID *LazySpecializations = nullptr;
   };
 
   /// Pointer to the common data shared by all declarations of this
@@ -2283,7 +2286,7 @@ class ClassTemplateDecl : public RedeclarableTemplateDecl {
   friend class TemplateDeclInstantiator;
 
   /// Load any lazily-loaded specializations from the external source.
-  void LoadLazySpecializations() const;
+  void LoadLazySpecializations(bool OnlyPartial = false) const;
 
   /// Get the underlying class declarations of the template.
   CXXRecordDecl *getTemplatedDecl() const {
@@ -3033,7 +3036,7 @@ class VarTemplateDecl : public RedeclarableTemplateDecl {
   friend class ASTDeclWriter;
 
   /// Load any lazily-loaded specializations from the external source.
-  void LoadLazySpecializations() const;
+  void LoadLazySpecializations(bool OnlyPartial = false) const;
 
   /// Get the underlying variable declarations of the template.
   VarDecl *getTemplatedDecl() const {

diff  --git a/clang/include/clang/AST/ExternalASTSource.h b/clang/include/clang/AST/ExternalASTSource.h
index 582ed7c65f58ca..9f968ba05b4466 100644
--- a/clang/include/clang/AST/ExternalASTSource.h
+++ b/clang/include/clang/AST/ExternalASTSource.h
@@ -152,6 +152,21 @@ class ExternalASTSource : public RefCountedBase<ExternalASTSource> {
   virtual bool
   FindExternalVisibleDeclsByName(const DeclContext *DC, DeclarationName Name);
 
+  /// Load all the external specializations for the Decl \param D if \param
+  /// OnlyPartial is false. Otherwise, load all the external **partial**
+  /// specializations for the \param D.
+  ///
+  /// Return true if any new specializations get loaded. Return false otherwise.
+  virtual bool LoadExternalSpecializations(const Decl *D, bool OnlyPartial);
+
+  /// Load all the specializations for the Decl \param D with the same template
+  /// args specified by \param TemplateArgs.
+  ///
+  /// Return true if any new specializations get loaded. Return false otherwise.
+  virtual bool
+  LoadExternalSpecializations(const Decl *D,
+                              ArrayRef<TemplateArgument> TemplateArgs);
+
   /// Ensures that the table of all visible declarations inside this
   /// context is up to date.
   ///

diff  --git a/clang/include/clang/Sema/MultiplexExternalSemaSource.h b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
index 3d1906d8699265..0c92c52854c9e7 100644
--- a/clang/include/clang/Sema/MultiplexExternalSemaSource.h
+++ b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
@@ -97,6 +97,12 @@ class MultiplexExternalSemaSource : public ExternalSemaSource {
   bool FindExternalVisibleDeclsByName(const DeclContext *DC,
                                       DeclarationName Name) override;
 
+  bool LoadExternalSpecializations(const Decl *D, bool OnlyPartial) override;
+
+  bool
+  LoadExternalSpecializations(const Decl *D,
+                              ArrayRef<TemplateArgument> TemplateArgs) override;
+
   /// Ensures that the table of all visible declarations inside this
   /// context is up to date.
   void completeVisibleDeclsMap(const DeclContext *DC) override;

diff  --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index fd834c14ce790f..af0e08d800bf28 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -733,6 +733,13 @@ enum ASTRecordTypes {
   /// Record code for Sema's vector of functions/blocks with effects to
   /// be verified.
   DECLS_WITH_EFFECTS_TO_VERIFY = 72,
+
+  /// Record code for updated specialization
+  UPDATE_SPECIALIZATION = 73,
+
+  CXX_ADDED_TEMPLATE_SPECIALIZATION = 74,
+
+  CXX_ADDED_TEMPLATE_PARTIAL_SPECIALIZATION = 75,
 };
 
 /// Record types used within a source manager block.
@@ -1502,6 +1509,12 @@ enum DeclCode {
   /// An ImplicitConceptSpecializationDecl record.
   DECL_IMPLICIT_CONCEPT_SPECIALIZATION,
 
+  // A decls specilization record.
+  DECL_SPECIALIZATIONS,
+
+  // A decls specilization record.
+  DECL_PARTIAL_SPECIALIZATIONS,
+
   DECL_LAST = DECL_IMPLICIT_CONCEPT_SPECIALIZATION
 };
 

diff  --git a/clang/include/clang/Serialization/ASTReader.h b/clang/include/clang/Serialization/ASTReader.h
index f739fe688c110d..f91052be5e1291 100644
--- a/clang/include/clang/Serialization/ASTReader.h
+++ b/clang/include/clang/Serialization/ASTReader.h
@@ -354,6 +354,9 @@ class ASTIdentifierLookupTrait;
 /// The on-disk hash table(s) used for DeclContext name lookup.
 struct DeclContextLookupTable;
 
+/// The on-disk hash table(s) used for specialization decls.
+struct LazySpecializationInfoLookupTable;
+
 } // namespace reader
 
 } // namespace serialization
@@ -632,20 +635,40 @@ class ASTReader
   llvm::DenseMap<const DeclContext *,
                  serialization::reader::DeclContextLookupTable> Lookups;
 
+  using SpecLookupTableTy =
+      llvm::DenseMap<const Decl *,
+                     serialization::reader::LazySpecializationInfoLookupTable>;
+  /// Map from decls to specialized decls.
+  SpecLookupTableTy SpecializationsLookups;
+  /// Split partial specialization from specialization to speed up lookups.
+  SpecLookupTableTy PartialSpecializationsLookups;
+
+  bool LoadExternalSpecializationsImpl(SpecLookupTableTy &SpecLookups,
+                                       const Decl *D);
+  bool LoadExternalSpecializationsImpl(SpecLookupTableTy &SpecLookups,
+                                       const Decl *D,
+                                       ArrayRef<TemplateArgument> TemplateArgs);
+
   // Updates for visible decls can occur for other contexts than just the
   // TU, and when we read those update records, the actual context may not
   // be available yet, so have this pending map using the ID as a key. It
-  // will be realized when the context is actually loaded.
-  struct PendingVisibleUpdate {
+  // will be realized when the data is actually loaded.
+  struct UpdateData {
     ModuleFile *Mod;
     const unsigned char *Data;
   };
-  using DeclContextVisibleUpdates = SmallVector<PendingVisibleUpdate, 1>;
+  using DeclContextVisibleUpdates = SmallVector<UpdateData, 1>;
 
   /// Updates to the visible declarations of declaration contexts that
   /// haven't been loaded yet.
   llvm::DenseMap<GlobalDeclID, DeclContextVisibleUpdates> PendingVisibleUpdates;
 
+  using SpecializationsUpdate = SmallVector<UpdateData, 1>;
+  using SpecializationsUpdateMap =
+      llvm::DenseMap<GlobalDeclID, SpecializationsUpdate>;
+  SpecializationsUpdateMap PendingSpecializationsUpdates;
+  SpecializationsUpdateMap PendingPartialSpecializationsUpdates;
+
   /// The set of C++ or Objective-C classes that have forward
   /// declarations that have not yet been linked to their definitions.
   llvm::SmallPtrSet<Decl *, 4> PendingDefinitions;
@@ -678,6 +701,11 @@ class ASTReader
                                      llvm::BitstreamCursor &Cursor,
                                      uint64_t Offset, GlobalDeclID ID);
 
+  bool ReadSpecializations(ModuleFile &M, llvm::BitstreamCursor &Cursor,
+                           uint64_t Offset, Decl *D, bool IsPartial);
+  void AddSpecializations(const Decl *D, const unsigned char *Data,
+                          ModuleFile &M, bool IsPartial);
+
   /// A vector containing identifiers that have already been
   /// loaded.
   ///
@@ -1419,6 +1447,14 @@ class ASTReader
   const serialization::reader::DeclContextLookupTable *
   getLoadedLookupTables(DeclContext *Primary) const;
 
+  /// Get the loaded specializations lookup tables for \p D,
+  /// if any.
+  serialization::reader::LazySpecializationInfoLookupTable *
+  getLoadedSpecializationsLookupTables(const Decl *D, bool IsPartial);
+
+  /// If we have any unloaded specialization for \p D
+  bool haveUnloadedSpecializations(const Decl *D) const;
+
 private:
   struct ImportedModule {
     ModuleFile *Mod;
@@ -2076,6 +2112,12 @@ class ASTReader
                                       unsigned BlockID,
                                       uint64_t *StartOfBlockOffset = nullptr);
 
+  bool LoadExternalSpecializations(const Decl *D, bool OnlyPartial) override;
+
+  bool
+  LoadExternalSpecializations(const Decl *D,
+                              ArrayRef<TemplateArgument> TemplateArgs) override;
+
   /// Finds all the visible declarations with a given name.
   /// The current implementation of this method just loads the entire
   /// lookup table as unmaterialized references.

diff  --git a/clang/include/clang/Serialization/ASTWriter.h b/clang/include/clang/Serialization/ASTWriter.h
index e418fdea44a0a9..d98d23decbdc0d 100644
--- a/clang/include/clang/Serialization/ASTWriter.h
+++ b/clang/include/clang/Serialization/ASTWriter.h
@@ -423,6 +423,13 @@ class ASTWriter : public ASTDeserializationListener,
   /// Only meaningful for reduced BMI.
   DeclUpdateMap DeclUpdatesFromGMF;
 
+  /// Mapping from decl templates and its new specialization in the
+  /// current TU.
+  using SpecializationUpdateMap =
+      llvm::MapVector<const NamedDecl *, SmallVector<const Decl *>>;
+  SpecializationUpdateMap SpecializationsUpdates;
+  SpecializationUpdateMap PartialSpecializationsUpdates;
+
   using FirstLatestDeclMap = llvm::DenseMap<Decl *, Decl *>;
 
   /// Map of first declarations from a chained PCH that point to the
@@ -575,6 +582,12 @@ class ASTWriter : public ASTDeserializationListener,
 
   bool isLookupResultExternal(StoredDeclsList &Result, DeclContext *DC);
 
+  void GenerateSpecializationInfoLookupTable(
+      const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations,
+      llvm::SmallVectorImpl<char> &LookupTable, bool IsPartial);
+  uint64_t WriteSpecializationInfoLookupTable(
+      const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations,
+      bool IsPartial);
   void GenerateNameLookupTable(ASTContext &Context, const DeclContext *DC,
                                llvm::SmallVectorImpl<char> &LookupTable);
   uint64_t WriteDeclContextLexicalBlock(ASTContext &Context,
@@ -590,6 +603,7 @@ class ASTWriter : public ASTDeserializationListener,
   void WriteDeclAndTypes(ASTContext &Context);
   void PrepareWritingSpecialDecls(Sema &SemaRef);
   void WriteSpecialDeclRecords(Sema &SemaRef);
+  void WriteSpecializationsUpdates(bool IsPartial);
   void WriteDeclUpdatesBlocks(ASTContext &Context,
                               RecordDataImpl &OffsetsRecord);
   void WriteDeclContextVisibleUpdate(ASTContext &Context,
@@ -619,6 +633,9 @@ class ASTWriter : public ASTDeserializationListener,
   unsigned DeclEnumAbbrev = 0;
   unsigned DeclObjCIvarAbbrev = 0;
   unsigned DeclCXXMethodAbbrev = 0;
+  unsigned DeclSpecializationsAbbrev = 0;
+  unsigned DeclPartialSpecializationsAbbrev = 0;
+
   unsigned DeclDependentNonTemplateCXXMethodAbbrev = 0;
   unsigned DeclTemplateCXXMethodAbbrev = 0;
   unsigned DeclMemberSpecializedCXXMethodAbbrev = 0;

diff  --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp
index 1da3f26bf23cd5..40ee3753c24227 100644
--- a/clang/lib/AST/DeclTemplate.cpp
+++ b/clang/lib/AST/DeclTemplate.cpp
@@ -16,7 +16,9 @@
 #include "clang/AST/DeclCXX.h"
 #include "clang/AST/DeclarationName.h"
 #include "clang/AST/Expr.h"
+#include "clang/AST/ExprCXX.h"
 #include "clang/AST/ExternalASTSource.h"
+#include "clang/AST/ODRHash.h"
 #include "clang/AST/TemplateBase.h"
 #include "clang/AST/TemplateName.h"
 #include "clang/AST/Type.h"
@@ -348,26 +350,39 @@ RedeclarableTemplateDecl::CommonBase *RedeclarableTemplateDecl::getCommonPtr() c
   return Common;
 }
 
-void RedeclarableTemplateDecl::loadLazySpecializationsImpl() const {
-  // Grab the most recent declaration to ensure we've loaded any lazy
-  // redeclarations of this template.
-  CommonBase *CommonBasePtr = getMostRecentDecl()->getCommonPtr();
-  if (CommonBasePtr->LazySpecializations) {
-    ASTContext &Context = getASTContext();
-    GlobalDeclID *Specs = CommonBasePtr->LazySpecializations;
-    CommonBasePtr->LazySpecializations = nullptr;
-    unsigned SpecSize = (*Specs++).getRawValue();
-    for (unsigned I = 0; I != SpecSize; ++I)
-      (void)Context.getExternalSource()->GetExternalDecl(Specs[I]);
-  }
+void RedeclarableTemplateDecl::loadLazySpecializationsImpl(
+    bool OnlyPartial /*=false*/) const {
+  auto *ExternalSource = getASTContext().getExternalSource();
+  if (!ExternalSource)
+    return;
+
+  ExternalSource->LoadExternalSpecializations(this->getCanonicalDecl(),
+                                              OnlyPartial);
+  return;
 }
 
-template<class EntryType, typename... ProfileArguments>
+bool RedeclarableTemplateDecl::loadLazySpecializationsImpl(
+    ArrayRef<TemplateArgument> Args, TemplateParameterList *TPL) const {
+  auto *ExternalSource = getASTContext().getExternalSource();
+  if (!ExternalSource)
+    return false;
+
+  // If TPL is not null, it implies that we're loading specializations for
+  // partial templates. We need to load all specializations in such cases.
+  if (TPL)
+    return ExternalSource->LoadExternalSpecializations(this->getCanonicalDecl(),
+                                                       /*OnlyPartial=*/false);
+
+  return ExternalSource->LoadExternalSpecializations(this->getCanonicalDecl(),
+                                                     Args);
+}
+
+template <class EntryType, typename... ProfileArguments>
 typename RedeclarableTemplateDecl::SpecEntryTraits<EntryType>::DeclType *
-RedeclarableTemplateDecl::findSpecializationImpl(
+RedeclarableTemplateDecl::findSpecializationLocally(
     llvm::FoldingSetVector<EntryType> &Specs, void *&InsertPos,
-    ProfileArguments&&... ProfileArgs) {
-  using SETraits = SpecEntryTraits<EntryType>;
+    ProfileArguments &&...ProfileArgs) {
+  using SETraits = RedeclarableTemplateDecl::SpecEntryTraits<EntryType>;
 
   llvm::FoldingSetNodeID ID;
   EntryType::Profile(ID, std::forward<ProfileArguments>(ProfileArgs)...,
@@ -376,6 +391,24 @@ RedeclarableTemplateDecl::findSpecializationImpl(
   return Entry ? SETraits::getDecl(Entry)->getMostRecentDecl() : nullptr;
 }
 
+template <class EntryType, typename... ProfileArguments>
+typename RedeclarableTemplateDecl::SpecEntryTraits<EntryType>::DeclType *
+RedeclarableTemplateDecl::findSpecializationImpl(
+    llvm::FoldingSetVector<EntryType> &Specs, void *&InsertPos,
+    ProfileArguments &&...ProfileArgs) {
+
+  if (auto *Found = findSpecializationLocally(
+          Specs, InsertPos, std::forward<ProfileArguments>(ProfileArgs)...))
+    return Found;
+
+  if (!loadLazySpecializationsImpl(
+          std::forward<ProfileArguments>(ProfileArgs)...))
+    return nullptr;
+
+  return findSpecializationLocally(
+      Specs, InsertPos, std::forward<ProfileArguments>(ProfileArgs)...);
+}
+
 template<class Derived, class EntryType>
 void RedeclarableTemplateDecl::addSpecializationImpl(
     llvm::FoldingSetVector<EntryType> &Specializations, EntryType *Entry,
@@ -384,10 +417,14 @@ void RedeclarableTemplateDecl::addSpecializationImpl(
 
   if (InsertPos) {
 #ifndef NDEBUG
+    auto Args = SETraits::getTemplateArgs(Entry);
+    // Due to hash collisions, it can happen that we load another template
+    // specialization with the same hash. This is fine, as long as the next
+    // call to findSpecializationImpl does not find a matching Decl for the
+    // template arguments.
+    loadLazySpecializationsImpl(Args);
     void *CorrectInsertPos;
-    assert(!findSpecializationImpl(Specializations,
-                                   CorrectInsertPos,
-                                   SETraits::getTemplateArgs(Entry)) &&
+    assert(!findSpecializationImpl(Specializations, CorrectInsertPos, Args) &&
            InsertPos == CorrectInsertPos &&
            "given incorrect InsertPos for specialization");
 #endif
@@ -445,12 +482,14 @@ FunctionTemplateDecl::getSpecializations() const {
 FunctionDecl *
 FunctionTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                          void *&InsertPos) {
-  return findSpecializationImpl(getSpecializations(), InsertPos, Args);
+  auto *Common = getCommonPtr();
+  return findSpecializationImpl(Common->Specializations, InsertPos, Args);
 }
 
 void FunctionTemplateDecl::addSpecialization(
       FunctionTemplateSpecializationInfo *Info, void *InsertPos) {
-  addSpecializationImpl<FunctionTemplateDecl>(getSpecializations(), Info,
+  auto *Common = getCommonPtr();
+  addSpecializationImpl<FunctionTemplateDecl>(Common->Specializations, Info,
                                               InsertPos);
 }
 
@@ -510,8 +549,9 @@ ClassTemplateDecl *ClassTemplateDecl::CreateDeserialized(ASTContext &C,
                                        DeclarationName(), nullptr, nullptr);
 }
 
-void ClassTemplateDecl::LoadLazySpecializations() const {
-  loadLazySpecializationsImpl();
+void ClassTemplateDecl::LoadLazySpecializations(
+    bool OnlyPartial /*=false*/) const {
+  loadLazySpecializationsImpl(OnlyPartial);
 }
 
 llvm::FoldingSetVector<ClassTemplateSpecializationDecl> &
@@ -522,7 +562,7 @@ ClassTemplateDecl::getSpecializations() const {
 
 llvm::FoldingSetVector<ClassTemplatePartialSpecializationDecl> &
 ClassTemplateDecl::getPartialSpecializations() const {
-  LoadLazySpecializations();
+  LoadLazySpecializations(/*PartialOnly = */ true);
   return getCommonPtr()->PartialSpecializations;
 }
 
@@ -536,12 +576,15 @@ ClassTemplateDecl::newCommon(ASTContext &C) const {
 ClassTemplateSpecializationDecl *
 ClassTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                       void *&InsertPos) {
-  return findSpecializationImpl(getSpecializations(), InsertPos, Args);
+  auto *Common = getCommonPtr();
+  return findSpecializationImpl(Common->Specializations, InsertPos, Args);
 }
 
 void ClassTemplateDecl::AddSpecialization(ClassTemplateSpecializationDecl *D,
                                           void *InsertPos) {
-  addSpecializationImpl<ClassTemplateDecl>(getSpecializations(), D, InsertPos);
+  auto *Common = getCommonPtr();
+  addSpecializationImpl<ClassTemplateDecl>(Common->Specializations, D,
+                                           InsertPos);
 }
 
 ClassTemplatePartialSpecializationDecl *
@@ -1259,8 +1302,9 @@ VarTemplateDecl *VarTemplateDecl::CreateDeserialized(ASTContext &C,
                                      DeclarationName(), nullptr, nullptr);
 }
 
-void VarTemplateDecl::LoadLazySpecializations() const {
-  loadLazySpecializationsImpl();
+void VarTemplateDecl::LoadLazySpecializations(
+    bool OnlyPartial /*=false*/) const {
+  loadLazySpecializationsImpl(OnlyPartial);
 }
 
 llvm::FoldingSetVector<VarTemplateSpecializationDecl> &
@@ -1271,7 +1315,7 @@ VarTemplateDecl::getSpecializations() const {
 
 llvm::FoldingSetVector<VarTemplatePartialSpecializationDecl> &
 VarTemplateDecl::getPartialSpecializations() const {
-  LoadLazySpecializations();
+  LoadLazySpecializations(/*PartialOnly = */ true);
   return getCommonPtr()->PartialSpecializations;
 }
 
@@ -1285,12 +1329,14 @@ VarTemplateDecl::newCommon(ASTContext &C) const {
 VarTemplateSpecializationDecl *
 VarTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                     void *&InsertPos) {
-  return findSpecializationImpl(getSpecializations(), InsertPos, Args);
+  auto *Common = getCommonPtr();
+  return findSpecializationImpl(Common->Specializations, InsertPos, Args);
 }
 
 void VarTemplateDecl::AddSpecialization(VarTemplateSpecializationDecl *D,
                                         void *InsertPos) {
-  addSpecializationImpl<VarTemplateDecl>(getSpecializations(), D, InsertPos);
+  auto *Common = getCommonPtr();
+  addSpecializationImpl<VarTemplateDecl>(Common->Specializations, D, InsertPos);
 }
 
 VarTemplatePartialSpecializationDecl *

diff  --git a/clang/lib/AST/ExternalASTSource.cpp b/clang/lib/AST/ExternalASTSource.cpp
index 7a14cc7d50ed05..543846c0093af8 100644
--- a/clang/lib/AST/ExternalASTSource.cpp
+++ b/clang/lib/AST/ExternalASTSource.cpp
@@ -96,6 +96,15 @@ ExternalASTSource::FindExternalVisibleDeclsByName(const DeclContext *DC,
   return false;
 }
 
+bool ExternalASTSource::LoadExternalSpecializations(const Decl *D, bool) {
+  return false;
+}
+
+bool ExternalASTSource::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument>) {
+  return false;
+}
+
 void ExternalASTSource::completeVisibleDeclsMap(const DeclContext *DC) {}
 
 void ExternalASTSource::FindExternalLexicalDecls(

diff  --git a/clang/lib/AST/ODRHash.cpp b/clang/lib/AST/ODRHash.cpp
index 645ca6f0e7b715..7c5c287e6c15ba 100644
--- a/clang/lib/AST/ODRHash.cpp
+++ b/clang/lib/AST/ODRHash.cpp
@@ -818,15 +818,20 @@ void ODRHash::AddDecl(const Decl *D) {
 
   AddDeclarationName(ND->getDeclName());
 
-  const auto *Specialization =
-            dyn_cast<ClassTemplateSpecializationDecl>(D);
-  AddBoolean(Specialization);
-  if (Specialization) {
-    const TemplateArgumentList &List = Specialization->getTemplateArgs();
-    ID.AddInteger(List.size());
-    for (const TemplateArgument &TA : List.asArray())
-      AddTemplateArgument(TA);
-  }
+  // If this was a specialization we should take into account its template
+  // arguments. This helps to reduce collisions coming when visiting template
+  // specialization types (eg. when processing type template arguments).
+  ArrayRef<TemplateArgument> Args;
+  if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(D))
+    Args = CTSD->getTemplateArgs().asArray();
+  else if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(D))
+    Args = VTSD->getTemplateArgs().asArray();
+  else if (auto *FD = dyn_cast<FunctionDecl>(D))
+    if (FD->getTemplateSpecializationArgs())
+      Args = FD->getTemplateSpecializationArgs()->asArray();
+
+  for (auto &TA : Args)
+    AddTemplateArgument(TA);
 }
 
 namespace {

diff  --git a/clang/lib/Sema/MultiplexExternalSemaSource.cpp b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
index cd44483b5cbe04..54944267b4868a 100644
--- a/clang/lib/Sema/MultiplexExternalSemaSource.cpp
+++ b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
@@ -115,6 +115,23 @@ FindExternalVisibleDeclsByName(const DeclContext *DC, DeclarationName Name) {
   return AnyDeclsFound;
 }
 
+bool MultiplexExternalSemaSource::LoadExternalSpecializations(
+    const Decl *D, bool OnlyPartial) {
+  bool Loaded = false;
+  for (size_t i = 0; i < Sources.size(); ++i)
+    Loaded |= Sources[i]->LoadExternalSpecializations(D, OnlyPartial);
+  return Loaded;
+}
+
+bool MultiplexExternalSemaSource::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  bool AnyNewSpecsLoaded = false;
+  for (size_t i = 0; i < Sources.size(); ++i)
+    AnyNewSpecsLoaded |=
+        Sources[i]->LoadExternalSpecializations(D, TemplateArgs);
+  return AnyNewSpecsLoaded;
+}
+
 void MultiplexExternalSemaSource::completeVisibleDeclsMap(const DeclContext *DC){
   for(size_t i = 0; i < Sources.size(); ++i)
     Sources[i]->completeVisibleDeclsMap(DC);

diff  --git a/clang/lib/Serialization/ASTCommon.h b/clang/lib/Serialization/ASTCommon.h
index 2a765eafe08951..7c9ec884ea049d 100644
--- a/clang/lib/Serialization/ASTCommon.h
+++ b/clang/lib/Serialization/ASTCommon.h
@@ -24,7 +24,6 @@ namespace serialization {
 
 enum DeclUpdateKind {
   UPD_CXX_ADDED_IMPLICIT_MEMBER,
-  UPD_CXX_ADDED_TEMPLATE_SPECIALIZATION,
   UPD_CXX_ADDED_ANONYMOUS_NAMESPACE,
   UPD_CXX_ADDED_FUNCTION_DEFINITION,
   UPD_CXX_ADDED_VAR_DEFINITION,

diff  --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index ec85fad3389a1c..490c690189c8ac 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -12,6 +12,7 @@
 
 #include "ASTCommon.h"
 #include "ASTReaderInternals.h"
+#include "TemplateArgumentHasher.h"
 #include "clang/AST/ASTConsumer.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/ASTMutationListener.h"
@@ -1295,6 +1296,42 @@ void ASTDeclContextNameLookupTrait::ReadDataInto(internal_key_type,
   }
 }
 
+ModuleFile *
+LazySpecializationInfoLookupTrait::ReadFileRef(const unsigned char *&d) {
+  using namespace llvm::support;
+
+  uint32_t ModuleFileID =
+      endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+  return Reader.getLocalModuleFile(F, ModuleFileID);
+}
+
+LazySpecializationInfoLookupTrait::internal_key_type
+LazySpecializationInfoLookupTrait::ReadKey(const unsigned char *d, unsigned) {
+  using namespace llvm::support;
+  return endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+}
+
+std::pair<unsigned, unsigned>
+LazySpecializationInfoLookupTrait::ReadKeyDataLength(const unsigned char *&d) {
+  return readULEBKeyDataLength(d);
+}
+
+void LazySpecializationInfoLookupTrait::ReadDataInto(internal_key_type,
+                                                     const unsigned char *d,
+                                                     unsigned DataLen,
+                                                     data_type_builder &Val) {
+  using namespace llvm::support;
+
+  for (unsigned NumDecls =
+           DataLen / sizeof(serialization::reader::LazySpecializationInfo);
+       NumDecls; --NumDecls) {
+    LocalDeclID LocalID = LocalDeclID::get(
+        Reader, F,
+        endian::readNext<DeclID, llvm::endianness::little, unaligned>(d));
+    Val.insert(Reader.getGlobalDeclID(F, LocalID));
+  }
+}
+
 bool ASTReader::ReadLexicalDeclContextStorage(ModuleFile &M,
                                               BitstreamCursor &Cursor,
                                               uint64_t Offset,
@@ -1379,7 +1416,52 @@ bool ASTReader::ReadVisibleDeclContextStorage(ModuleFile &M,
   // We can't safely determine the primary context yet, so delay attaching the
   // lookup table until we're done with recursive deserialization.
   auto *Data = (const unsigned char*)Blob.data();
-  PendingVisibleUpdates[ID].push_back(PendingVisibleUpdate{&M, Data});
+  PendingVisibleUpdates[ID].push_back(UpdateData{&M, Data});
+  return false;
+}
+
+void ASTReader::AddSpecializations(const Decl *D, const unsigned char *Data,
+                                   ModuleFile &M, bool IsPartial) {
+  D = D->getCanonicalDecl();
+  auto &SpecLookups =
+      IsPartial ? PartialSpecializationsLookups : SpecializationsLookups;
+  SpecLookups[D].Table.add(&M, Data,
+                           reader::LazySpecializationInfoLookupTrait(*this, M));
+}
+
+bool ASTReader::ReadSpecializations(ModuleFile &M, BitstreamCursor &Cursor,
+                                    uint64_t Offset, Decl *D, bool IsPartial) {
+  assert(Offset != 0);
+
+  SavedStreamPosition SavedPosition(Cursor);
+  if (llvm::Error Err = Cursor.JumpToBit(Offset)) {
+    Error(std::move(Err));
+    return true;
+  }
+
+  RecordData Record;
+  StringRef Blob;
+  Expected<unsigned> MaybeCode = Cursor.ReadCode();
+  if (!MaybeCode) {
+    Error(MaybeCode.takeError());
+    return true;
+  }
+  unsigned Code = MaybeCode.get();
+
+  Expected<unsigned> MaybeRecCode = Cursor.readRecord(Code, Record, &Blob);
+  if (!MaybeRecCode) {
+    Error(MaybeRecCode.takeError());
+    return true;
+  }
+  unsigned RecCode = MaybeRecCode.get();
+  if (RecCode != DECL_SPECIALIZATIONS &&
+      RecCode != DECL_PARTIAL_SPECIALIZATIONS) {
+    Error("Expected decl specs block");
+    return true;
+  }
+
+  auto *Data = (const unsigned char *)Blob.data();
+  AddSpecializations(D, Data, M, IsPartial);
   return false;
 }
 
@@ -3458,7 +3540,33 @@ llvm::Error ASTReader::ReadASTBlock(ModuleFile &F,
       unsigned Idx = 0;
       GlobalDeclID ID = ReadDeclID(F, Record, Idx);
       auto *Data = (const unsigned char*)Blob.data();
-      PendingVisibleUpdates[ID].push_back(PendingVisibleUpdate{&F, Data});
+      PendingVisibleUpdates[ID].push_back(UpdateData{&F, Data});
+      // If we've already loaded the decl, perform the updates when we finish
+      // loading this block.
+      if (Decl *D = GetExistingDecl(ID))
+        PendingUpdateRecords.push_back(
+            PendingUpdateRecord(ID, D, /*JustLoaded=*/false));
+      break;
+    }
+
+    case CXX_ADDED_TEMPLATE_SPECIALIZATION: {
+      unsigned Idx = 0;
+      GlobalDeclID ID = ReadDeclID(F, Record, Idx);
+      auto *Data = (const unsigned char *)Blob.data();
+      PendingSpecializationsUpdates[ID].push_back(UpdateData{&F, Data});
+      // If we've already loaded the decl, perform the updates when we finish
+      // loading this block.
+      if (Decl *D = GetExistingDecl(ID))
+        PendingUpdateRecords.push_back(
+            PendingUpdateRecord(ID, D, /*JustLoaded=*/false));
+      break;
+    }
+
+    case CXX_ADDED_TEMPLATE_PARTIAL_SPECIALIZATION: {
+      unsigned Idx = 0;
+      GlobalDeclID ID = ReadDeclID(F, Record, Idx);
+      auto *Data = (const unsigned char *)Blob.data();
+      PendingPartialSpecializationsUpdates[ID].push_back(UpdateData{&F, Data});
       // If we've already loaded the decl, perform the updates when we finish
       // loading this block.
       if (Decl *D = GetExistingDecl(ID))
@@ -7654,13 +7762,28 @@ void ASTReader::CompleteRedeclChain(const Decl *D) {
     }
   }
 
-  if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(D))
-    CTSD->getSpecializedTemplate()->LoadLazySpecializations();
-  if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(D))
-    VTSD->getSpecializedTemplate()->LoadLazySpecializations();
-  if (auto *FD = dyn_cast<FunctionDecl>(D)) {
-    if (auto *Template = FD->getPrimaryTemplate())
-      Template->LoadLazySpecializations();
+  RedeclarableTemplateDecl *Template = nullptr;
+  ArrayRef<TemplateArgument> Args;
+  if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(D)) {
+    Template = CTSD->getSpecializedTemplate();
+    Args = CTSD->getTemplateArgs().asArray();
+  } else if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(D)) {
+    Template = VTSD->getSpecializedTemplate();
+    Args = VTSD->getTemplateArgs().asArray();
+  } else if (auto *FD = dyn_cast<FunctionDecl>(D)) {
+    if (auto *Tmplt = FD->getPrimaryTemplate()) {
+      Template = Tmplt;
+      Args = FD->getTemplateSpecializationArgs()->asArray();
+    }
+  }
+
+  if (Template) {
+    // For partial specializations, load all the specializations for safety.
+    if (isa<ClassTemplatePartialSpecializationDecl,
+            VarTemplatePartialSpecializationDecl>(D))
+      Template->loadLazySpecializationsImpl();
+    else
+      Template->loadLazySpecializationsImpl(Args);
   }
 }
 
@@ -8042,6 +8165,86 @@ Stmt *ASTReader::GetExternalDeclStmt(uint64_t Offset) {
   return ReadStmtFromStream(*Loc.F);
 }
 
+bool ASTReader::LoadExternalSpecializationsImpl(SpecLookupTableTy &SpecLookups,
+                                                const Decl *D) {
+  assert(D);
+
+  auto It = SpecLookups.find(D);
+  if (It == SpecLookups.end())
+    return false;
+
+  // GetDecl may invalidate the iterator into SpecLookups, so we collect the
+  // DeclIDs ahead of time.
+  llvm::SmallVector<serialization::reader::LazySpecializationInfo, 8> Infos =
+      It->second.Table.findAll();
+
+  // Since we've loaded all the specializations, we can erase it from
+  // the lookup table.
+  SpecLookups.erase(It);
+
+  bool NewSpecsFound = false;
+  Deserializing LookupResults(this);
+  for (auto &Info : Infos) {
+    if (GetExistingDecl(Info))
+      continue;
+    NewSpecsFound = true;
+    GetDecl(Info);
+  }
+
+  return NewSpecsFound;
+}
+
+bool ASTReader::LoadExternalSpecializations(const Decl *D, bool OnlyPartial) {
+  assert(D);
+
+  bool NewSpecsFound =
+      LoadExternalSpecializationsImpl(PartialSpecializationsLookups, D);
+  if (OnlyPartial)
+    return NewSpecsFound;
+
+  NewSpecsFound |= LoadExternalSpecializationsImpl(SpecializationsLookups, D);
+  return NewSpecsFound;
+}
+
+bool ASTReader::LoadExternalSpecializationsImpl(
+    SpecLookupTableTy &SpecLookups, const Decl *D,
+    ArrayRef<TemplateArgument> TemplateArgs) {
+  assert(D);
+
+  auto It = SpecLookups.find(D);
+  if (It == SpecLookups.end())
+    return false;
+
+  Deserializing LookupResults(this);
+  auto HashValue = StableHashForTemplateArguments(TemplateArgs);
+
+  // GetDecl may invalidate the iterator into SpecLookups, so copy the IDs first.
+  llvm::SmallVector<serialization::reader::LazySpecializationInfo, 8> Infos =
+      It->second.Table.find(HashValue);
+
+  bool NewSpecsFound = false;
+  for (auto &Info : Infos) {
+    if (GetExistingDecl(Info))
+      continue;
+    NewSpecsFound = true;
+    GetDecl(Info);
+  }
+
+  return NewSpecsFound;
+}
+
+bool ASTReader::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  assert(D);
+
+  bool NewDeclsFound = LoadExternalSpecializationsImpl(
+      PartialSpecializationsLookups, D, TemplateArgs);
+  NewDeclsFound |=
+      LoadExternalSpecializationsImpl(SpecializationsLookups, D, TemplateArgs);
+
+  return NewDeclsFound;
+}
+
 void ASTReader::FindExternalLexicalDecls(
     const DeclContext *DC, llvm::function_ref<bool(Decl::Kind)> IsKindWeWant,
     SmallVectorImpl<Decl *> &Decls) {
@@ -8221,6 +8424,22 @@ ASTReader::getLoadedLookupTables(DeclContext *Primary) const {
   return I == Lookups.end() ? nullptr : &I->second;
 }
 
+serialization::reader::LazySpecializationInfoLookupTable *
+ASTReader::getLoadedSpecializationsLookupTables(const Decl *D, bool IsPartial) {
+  assert(D->isCanonicalDecl());
+  auto &LookupTable =
+      IsPartial ? PartialSpecializationsLookups : SpecializationsLookups;
+  auto I = LookupTable.find(D);
+  return I == LookupTable.end() ? nullptr : &I->second;
+}
+
+bool ASTReader::haveUnloadedSpecializations(const Decl *D) const {
+  assert(D->isCanonicalDecl());
+  return (PartialSpecializationsLookups.find(D) !=
+          PartialSpecializationsLookups.end()) ||
+         (SpecializationsLookups.find(D) != SpecializationsLookups.end());
+}
+
 /// Under non-PCH compilation the consumer receives the objc methods
 /// before receiving the implementation, and codegen depends on this.
 /// We simulate this by deserializing and passing to consumer the methods of the
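
Editor's sketch (not part of the commit): the two LoadExternalSpecializationsImpl
overloads above differ in how much they pull in. The overload taking template
arguments looks up a single hash bucket and keeps the table; the overload
without arguments takes everything and then drops the table. The toy model
below illustrates only that contrast. ToySpecTable, ToyDeclID and all member
names are hypothetical, the hash is assumed to be computed elsewhere, and a
std::set stands in for "already deserialized" declarations.

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a declaration ID; 0 is treated as invalid here.
using ToyDeclID = std::uint32_t;

// Toy model of one template's specialization table, keyed by a precomputed
// hash of the template arguments.
class ToySpecTable {
  std::unordered_map<unsigned, std::vector<ToyDeclID>> Buckets;
  std::set<ToyDeclID> Loaded; // IDs that have been "deserialized"

public:
  void add(unsigned ArgsHash, ToyDeclID ID) {
    assert(ID != 0 && "0 is reserved as an invalid ID in this sketch");
    Buckets[ArgsHash].push_back(ID);
  }

  // Load only the bucket matching ArgsHash. Returns true if anything new
  // was loaded. The table is kept for later lookups with other hashes.
  bool load(unsigned ArgsHash) {
    auto It = Buckets.find(ArgsHash);
    if (It == Buckets.end())
      return false;
    // Copy the IDs first, mirroring the diff's caution about iterators.
    std::vector<ToyDeclID> IDs = It->second;
    bool NewFound = false;
    for (ToyDeclID ID : IDs)
      NewFound |= Loaded.insert(ID).second;
    return NewFound;
  }

  // Load every bucket, then drop the table since nothing is left unloaded.
  bool loadAll() {
    bool NewFound = false;
    for (const auto &Bucket : Buckets)
      for (ToyDeclID ID : Bucket.second)
        NewFound |= Loaded.insert(ID).second;
    Buckets.clear();
    return NewFound;
  }

  bool isLoaded(ToyDeclID ID) const { return Loaded.count(ID) != 0; }
  bool hasUnloaded() const { return !Buckets.empty(); }
};
```

With two entries under different hashes, load() on one hash leaves the other
entry untouched, which is the behaviour the log message asks for in the
A<int> versus A<ShouldNotBeLoaded> example.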

diff  --git a/clang/lib/Serialization/ASTReaderDecl.cpp b/clang/lib/Serialization/ASTReaderDecl.cpp
index 6ece3ba7af9f4b..0644ff4dfe6827 100644
--- a/clang/lib/Serialization/ASTReaderDecl.cpp
+++ b/clang/lib/Serialization/ASTReaderDecl.cpp
@@ -187,11 +187,6 @@ class ASTDeclReader : public DeclVisitor<ASTDeclReader, void> {
 
   std::string readString() { return Record.readString(); }
 
-  void readDeclIDList(SmallVectorImpl<GlobalDeclID> &IDs) {
-    for (unsigned I = 0, Size = Record.readInt(); I != Size; ++I)
-      IDs.push_back(readDeclID());
-  }
-
   Decl *readDecl() { return Record.readDecl(); }
 
   template <typename T> T *readDeclAs() { return Record.readDeclAs<T>(); }
@@ -284,30 +279,6 @@ class ASTDeclReader : public DeclVisitor<ASTDeclReader, void> {
       : Reader(Reader), MergeImpl(Reader), Record(Record), Loc(Loc),
         ThisDeclID(thisDeclID), ThisDeclLoc(ThisDeclLoc) {}
 
-  template <typename T>
-  static void AddLazySpecializations(T *D, SmallVectorImpl<GlobalDeclID> &IDs) {
-    if (IDs.empty())
-      return;
-
-    // FIXME: We should avoid this pattern of getting the ASTContext.
-    ASTContext &C = D->getASTContext();
-
-    auto *&LazySpecializations = D->getCommonPtr()->LazySpecializations;
-
-    if (auto &Old = LazySpecializations) {
-      IDs.insert(IDs.end(), Old + 1, Old + 1 + Old[0].getRawValue());
-      llvm::sort(IDs);
-      IDs.erase(std::unique(IDs.begin(), IDs.end()), IDs.end());
-    }
-
-    auto *Result = new (C) GlobalDeclID[1 + IDs.size()];
-    *Result = GlobalDeclID(IDs.size());
-
-    std::copy(IDs.begin(), IDs.end(), Result + 1);
-
-    LazySpecializations = Result;
-  }
-
   template <typename DeclT>
   static Decl *getMostRecentDeclImpl(Redeclarable<DeclT> *D);
   static Decl *getMostRecentDeclImpl(...);
@@ -332,10 +303,13 @@ class ASTDeclReader : public DeclVisitor<ASTDeclReader, void> {
   static void markIncompleteDeclChainImpl(Redeclarable<DeclT> *D);
   static void markIncompleteDeclChainImpl(...);
 
+  void ReadSpecializations(ModuleFile &M, Decl *D,
+                           llvm::BitstreamCursor &DeclsCursor, bool IsPartial);
+
   void ReadFunctionDefinition(FunctionDecl *FD);
   void Visit(Decl *D);
 
-  void UpdateDecl(Decl *D, SmallVectorImpl<GlobalDeclID> &);
+  void UpdateDecl(Decl *D);
 
   static void setNextObjCCategory(ObjCCategoryDecl *Cat,
                                   ObjCCategoryDecl *Next) {
@@ -2418,6 +2392,16 @@ void ASTDeclReader::VisitImplicitConceptSpecializationDecl(
 void ASTDeclReader::VisitRequiresExprBodyDecl(RequiresExprBodyDecl *D) {
 }
 
+void ASTDeclReader::ReadSpecializations(ModuleFile &M, Decl *D,
+                                        llvm::BitstreamCursor &DeclsCursor,
+                                        bool IsPartial) {
+  uint64_t Offset = ReadLocalOffset();
+  bool Failed =
+      Reader.ReadSpecializations(M, DeclsCursor, Offset, D, IsPartial);
+  (void)Failed;
+  assert(!Failed);
+}
+
 RedeclarableResult
 ASTDeclReader::VisitRedeclarableTemplateDecl(RedeclarableTemplateDecl *D) {
   RedeclarableResult Redecl = VisitRedeclarable(D);
@@ -2456,9 +2440,8 @@ void ASTDeclReader::VisitClassTemplateDecl(ClassTemplateDecl *D) {
   if (ThisDeclID == Redecl.getFirstID()) {
     // This ClassTemplateDecl owns a CommonPtr; read it to keep track of all of
     // the specializations.
-    SmallVector<GlobalDeclID, 32> SpecIDs;
-    readDeclIDList(SpecIDs);
-    ASTDeclReader::AddLazySpecializations(D, SpecIDs);
+    ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor, /*IsPartial=*/false);
+    ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor, /*IsPartial=*/true);
   }
 
   if (D->getTemplatedDecl()->TemplateOrInstantiation) {
@@ -2484,9 +2467,8 @@ void ASTDeclReader::VisitVarTemplateDecl(VarTemplateDecl *D) {
   if (ThisDeclID == Redecl.getFirstID()) {
     // This VarTemplateDecl owns a CommonPtr; read it to keep track of all of
     // the specializations.
-    SmallVector<GlobalDeclID, 32> SpecIDs;
-    readDeclIDList(SpecIDs);
-    ASTDeclReader::AddLazySpecializations(D, SpecIDs);
+    ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor, /*IsPartial=*/false);
+    ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor, /*IsPartial=*/true);
   }
 }
 
@@ -2585,9 +2567,7 @@ void ASTDeclReader::VisitFunctionTemplateDecl(FunctionTemplateDecl *D) {
 
   if (ThisDeclID == Redecl.getFirstID()) {
     // This FunctionTemplateDecl owns a CommonPtr; read it.
-    SmallVector<GlobalDeclID, 32> SpecIDs;
-    readDeclIDList(SpecIDs);
-    ASTDeclReader::AddLazySpecializations(D, SpecIDs);
+    ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor, /*IsPartial=*/false);
   }
 }
 
@@ -3877,6 +3857,8 @@ Decl *ASTReader::ReadDeclRecord(GlobalDeclID ID) {
   switch ((DeclCode)MaybeDeclCode.get()) {
   case DECL_CONTEXT_LEXICAL:
   case DECL_CONTEXT_VISIBLE:
+  case DECL_SPECIALIZATIONS:
+  case DECL_PARTIAL_SPECIALIZATIONS:
     llvm_unreachable("Record cannot be de-serialized with readDeclRecord");
   case DECL_TYPEDEF:
     D = TypedefDecl::CreateDeserialized(Context, ID);
@@ -4286,8 +4268,6 @@ void ASTReader::loadDeclUpdateRecords(PendingUpdateRecord &Record) {
   ProcessingUpdatesRAIIObj ProcessingUpdates(*this);
   DeclUpdateOffsetsMap::iterator UpdI = DeclUpdateOffsets.find(ID);
 
-  SmallVector<GlobalDeclID, 8> PendingLazySpecializationIDs;
-
   if (UpdI != DeclUpdateOffsets.end()) {
     auto UpdateOffsets = std::move(UpdI->second);
     DeclUpdateOffsets.erase(UpdI);
@@ -4324,7 +4304,7 @@ void ASTReader::loadDeclUpdateRecords(PendingUpdateRecord &Record) {
 
       ASTDeclReader Reader(*this, Record, RecordLocation(F, Offset), ID,
                            SourceLocation());
-      Reader.UpdateDecl(D, PendingLazySpecializationIDs);
+      Reader.UpdateDecl(D);
 
       // We might have made this declaration interesting. If so, remember that
       // we need to hand it off to the consumer.
@@ -4334,17 +4314,6 @@ void ASTReader::loadDeclUpdateRecords(PendingUpdateRecord &Record) {
       }
     }
   }
-  // Add the lazy specializations to the template.
-  assert((PendingLazySpecializationIDs.empty() || isa<ClassTemplateDecl>(D) ||
-          isa<FunctionTemplateDecl, VarTemplateDecl>(D)) &&
-         "Must not have pending specializations");
-  if (auto *CTD = dyn_cast<ClassTemplateDecl>(D))
-    ASTDeclReader::AddLazySpecializations(CTD, PendingLazySpecializationIDs);
-  else if (auto *FTD = dyn_cast<FunctionTemplateDecl>(D))
-    ASTDeclReader::AddLazySpecializations(FTD, PendingLazySpecializationIDs);
-  else if (auto *VTD = dyn_cast<VarTemplateDecl>(D))
-    ASTDeclReader::AddLazySpecializations(VTD, PendingLazySpecializationIDs);
-  PendingLazySpecializationIDs.clear();
 
   // Load the pending visible updates for this decl context, if it has any.
   auto I = PendingVisibleUpdates.find(ID);
@@ -4369,6 +4338,26 @@ void ASTReader::loadDeclUpdateRecords(PendingUpdateRecord &Record) {
       FunctionToLambdasMap.erase(IT);
     }
   }
+
+  // Load the pending specializations update for this decl, if it has any.
+  if (auto I = PendingSpecializationsUpdates.find(ID);
+      I != PendingSpecializationsUpdates.end()) {
+    auto SpecializationUpdates = std::move(I->second);
+    PendingSpecializationsUpdates.erase(I);
+
+    for (const auto &Update : SpecializationUpdates)
+      AddSpecializations(D, Update.Data, *Update.Mod, /*IsPartial=*/false);
+  }
+
+  // Load the pending partial specializations update for this decl, if it has any.
+  if (auto I = PendingPartialSpecializationsUpdates.find(ID);
+      I != PendingPartialSpecializationsUpdates.end()) {
+    auto SpecializationUpdates = std::move(I->second);
+    PendingPartialSpecializationsUpdates.erase(I);
+
+    for (const auto &Update : SpecializationUpdates)
+      AddSpecializations(D, Update.Data, *Update.Mod, /*IsPartial=*/true);
+  }
 }
 
 void ASTReader::loadPendingDeclChain(Decl *FirstLocal, uint64_t LocalOffset) {
@@ -4561,9 +4550,7 @@ static void forAllLaterRedecls(DeclT *D, Fn F) {
   }
 }
 
-void ASTDeclReader::UpdateDecl(
-    Decl *D,
-    llvm::SmallVectorImpl<GlobalDeclID> &PendingLazySpecializationIDs) {
+void ASTDeclReader::UpdateDecl(Decl *D) {
   while (Record.getIdx() < Record.size()) {
     switch ((DeclUpdateKind)Record.readInt()) {
     case UPD_CXX_ADDED_IMPLICIT_MEMBER: {
@@ -4574,11 +4561,6 @@ void ASTDeclReader::UpdateDecl(
       break;
     }
 
-    case UPD_CXX_ADDED_TEMPLATE_SPECIALIZATION:
-      // It will be added to the template's lazy specialization set.
-      PendingLazySpecializationIDs.push_back(readDeclID());
-      break;
-
     case UPD_CXX_ADDED_ANONYMOUS_NAMESPACE: {
       auto *Anon = readDeclAs<NamespaceDecl>();
 

diff  --git a/clang/lib/Serialization/ASTReaderInternals.h b/clang/lib/Serialization/ASTReaderInternals.h
index 4f7e6f4b2741b7..be0d22d1f4094f 100644
--- a/clang/lib/Serialization/ASTReaderInternals.h
+++ b/clang/lib/Serialization/ASTReaderInternals.h
@@ -119,6 +119,88 @@ struct DeclContextLookupTable {
   MultiOnDiskHashTable<ASTDeclContextNameLookupTrait> Table;
 };
 
+using LazySpecializationInfo = GlobalDeclID;
+
+/// Class that performs lookup to specialized decls.
+class LazySpecializationInfoLookupTrait {
+  ASTReader &Reader;
+  ModuleFile &F;
+
+public:
+  // Maximum number of lookup tables we allow before condensing the tables.
+  static const int MaxTables = 4;
+
+  /// The lookup result is a list of global declaration IDs.
+  using data_type = SmallVector<LazySpecializationInfo, 4>;
+
+  struct data_type_builder {
+    data_type &Data;
+    llvm::DenseSet<LazySpecializationInfo> Found;
+
+    data_type_builder(data_type &D) : Data(D) {}
+
+    void insert(LazySpecializationInfo Info) {
+      // Just use a linear scan unless we have more than a few IDs.
+      if (Found.empty() && !Data.empty()) {
+        if (Data.size() <= 4) {
+          for (auto I : Data)
+            if (I == Info)
+              return;
+          Data.push_back(Info);
+          return;
+        }
+
+        // Switch to tracking found IDs in the set.
+        Found.insert(Data.begin(), Data.end());
+      }
+
+      if (Found.insert(Info).second)
+        Data.push_back(Info);
+    }
+  };
+  using hash_value_type = unsigned;
+  using offset_type = unsigned;
+  using file_type = ModuleFile *;
+
+  using external_key_type = unsigned;
+  using internal_key_type = unsigned;
+
+  explicit LazySpecializationInfoLookupTrait(ASTReader &Reader, ModuleFile &F)
+      : Reader(Reader), F(F) {}
+
+  static bool EqualKey(const internal_key_type &a, const internal_key_type &b) {
+    return a == b;
+  }
+
+  static hash_value_type ComputeHash(const internal_key_type &Key) {
+    return Key;
+  }
+
+  static internal_key_type GetInternalKey(const external_key_type &Name) {
+    return Name;
+  }
+
+  static std::pair<unsigned, unsigned>
+  ReadKeyDataLength(const unsigned char *&d);
+
+  internal_key_type ReadKey(const unsigned char *d, unsigned);
+
+  void ReadDataInto(internal_key_type, const unsigned char *d, unsigned DataLen,
+                    data_type_builder &Val);
+
+  static void MergeDataInto(const data_type &From, data_type_builder &To) {
+    To.Data.reserve(To.Data.size() + From.size());
+    for (LazySpecializationInfo Info : From)
+      To.insert(Info);
+  }
+
+  file_type ReadFileRef(const unsigned char *&d);
+};
+
+struct LazySpecializationInfoLookupTable {
+  MultiOnDiskHashTable<LazySpecializationInfoLookupTrait> Table;
+};
+
 /// Base class for the trait describing the on-disk hash table for the
 /// identifiers in an AST file.
 ///
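
Editor's sketch (not part of the commit): data_type_builder::insert above
deduplicates IDs with a linear scan while the result list is small and
switches to a set once it grows. The standalone version below shows that
strategy using standard containers instead of SmallVector and DenseSet.
ToyBuilder, toyMerge and SmallLimit are hypothetical names, and the scan here
walks the existing data, which is an assumption about the intended behaviour.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_set>
#include <vector>

// Toy counterpart of data_type_builder: appends IDs to Data without
// introducing duplicates.
struct ToyBuilder {
  std::vector<std::uint32_t> &Data;
  std::unordered_set<std::uint32_t> Found;
  static constexpr std::size_t SmallLimit = 4;

  explicit ToyBuilder(std::vector<std::uint32_t> &D) : Data(D) {}

  void insert(std::uint32_t ID) {
    // The set is not populated yet but Data already holds some IDs.
    if (Found.empty() && !Data.empty()) {
      if (Data.size() <= SmallLimit) {
        // Few IDs: a linear scan is cheaper than building a set.
        for (std::uint32_t Existing : Data)
          if (Existing == ID)
            return;
        Data.push_back(ID);
        return;
      }
      // Too many IDs for scanning: start tracking them in the set.
      Found.insert(Data.begin(), Data.end());
    }
    if (Found.insert(ID).second)
      Data.push_back(ID);
  }
};

// Merge helper with the same shape as MergeDataInto in the diff.
inline void toyMerge(const std::vector<std::uint32_t> &From, ToyBuilder &To) {
  assert(&From != &To.Data && "cannot merge a vector into itself");
  To.Data.reserve(To.Data.size() + From.size());
  for (std::uint32_t ID : From)
    To.insert(ID);
}
```

Note that when Data starts out empty, the first insert already goes through
the set, so the scan branch is only reached when the builder wraps a list
that was filled beforehand.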

diff  --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index f8158a654c45aa..83fbb705e48c7c 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -13,6 +13,7 @@
 #include "ASTCommon.h"
 #include "ASTReaderInternals.h"
 #include "MultiOnDiskHashTable.h"
+#include "TemplateArgumentHasher.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/ASTUnresolvedSet.h"
 #include "clang/AST/AbstractTypeWriter.h"
@@ -4167,6 +4168,175 @@ class ASTDeclContextNameLookupTrait {
 
 } // namespace
 
+namespace {
+class LazySpecializationInfoLookupTrait {
+  ASTWriter &Writer;
+  llvm::SmallVector<serialization::reader::LazySpecializationInfo, 64> Specs;
+
+public:
+  using key_type = unsigned;
+  using key_type_ref = key_type;
+
+  /// A start and end index into Specs, representing a sequence of decls.
+  using data_type = std::pair<unsigned, unsigned>;
+  using data_type_ref = const data_type &;
+
+  using hash_value_type = unsigned;
+  using offset_type = unsigned;
+
+  explicit LazySpecializationInfoLookupTrait(ASTWriter &Writer)
+      : Writer(Writer) {}
+
+  template <typename Col, typename Col2>
+  data_type getData(Col &&C, Col2 &ExistingInfo) {
+    unsigned Start = Specs.size();
+    for (auto *D : C) {
+      NamedDecl *ND = getDeclForLocalLookup(Writer.getLangOpts(),
+                                            const_cast<NamedDecl *>(D));
+      Specs.push_back(GlobalDeclID(Writer.GetDeclRef(ND).getRawValue()));
+    }
+    for (const serialization::reader::LazySpecializationInfo &Info :
+         ExistingInfo)
+      Specs.push_back(Info);
+    return std::make_pair(Start, Specs.size());
+  }
+
+  data_type ImportData(
+      const reader::LazySpecializationInfoLookupTrait::data_type &FromReader) {
+    unsigned Start = Specs.size();
+    for (auto ID : FromReader)
+      Specs.push_back(ID);
+    return std::make_pair(Start, Specs.size());
+  }
+
+  static bool EqualKey(key_type_ref a, key_type_ref b) { return a == b; }
+
+  hash_value_type ComputeHash(key_type Name) { return Name; }
+
+  void EmitFileRef(raw_ostream &Out, ModuleFile *F) const {
+    assert(Writer.hasChain() &&
+           "have reference to loaded module file but no chain?");
+
+    using namespace llvm::support;
+    endian::write<uint32_t>(Out, Writer.getChain()->getModuleFileID(F),
+                            llvm::endianness::little);
+  }
+
+  std::pair<unsigned, unsigned> EmitKeyDataLength(raw_ostream &Out,
+                                                  key_type HashValue,
+                                                  data_type_ref Lookup) {
+    // 4 bytes for each slot.
+    unsigned KeyLen = 4;
+    unsigned DataLen = sizeof(serialization::reader::LazySpecializationInfo) *
+                       (Lookup.second - Lookup.first);
+
+    return emitULEBKeyDataLength(KeyLen, DataLen, Out);
+  }
+
+  void EmitKey(raw_ostream &Out, key_type HashValue, unsigned) {
+    using namespace llvm::support;
+
+    endian::Writer LE(Out, llvm::endianness::little);
+    LE.write<uint32_t>(HashValue);
+  }
+
+  void EmitData(raw_ostream &Out, key_type_ref, data_type Lookup,
+                unsigned DataLen) {
+    using namespace llvm::support;
+
+    endian::Writer LE(Out, llvm::endianness::little);
+    uint64_t Start = Out.tell();
+    (void)Start;
+    for (unsigned I = Lookup.first, N = Lookup.second; I != N; ++I) {
+      LE.write<DeclID>(Specs[I].getRawValue());
+    }
+    assert(Out.tell() - Start == DataLen && "Data length is wrong");
+  }
+};
+
+unsigned CalculateODRHashForSpecs(const Decl *Spec) {
+  ArrayRef<TemplateArgument> Args;
+  if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(Spec))
+    Args = CTSD->getTemplateArgs().asArray();
+  else if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(Spec))
+    Args = VTSD->getTemplateArgs().asArray();
+  else if (auto *FD = dyn_cast<FunctionDecl>(Spec))
+    Args = FD->getTemplateSpecializationArgs()->asArray();
+  else
+    llvm_unreachable("New Specialization Kind?");
+
+  return StableHashForTemplateArguments(Args);
+}
+} // namespace
+
+void ASTWriter::GenerateSpecializationInfoLookupTable(
+    const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations,
+    llvm::SmallVectorImpl<char> &LookupTable, bool IsPartial) {
+  assert(D->isFirstDecl());
+
+  // Create the on-disk hash table representation.
+  MultiOnDiskHashTableGenerator<reader::LazySpecializationInfoLookupTrait,
+                                LazySpecializationInfoLookupTrait>
+      Generator;
+  LazySpecializationInfoLookupTrait Trait(*this);
+
+  llvm::DenseMap<unsigned, llvm::SmallVector<const NamedDecl *, 4>>
+      SpecializationMaps;
+
+  for (auto *Specialization : Specializations) {
+    unsigned HashedValue = CalculateODRHashForSpecs(Specialization);
+
+    auto Iter = SpecializationMaps.find(HashedValue);
+    if (Iter == SpecializationMaps.end())
+      Iter = SpecializationMaps
+                 .try_emplace(HashedValue,
+                              llvm::SmallVector<const NamedDecl *, 4>())
+                 .first;
+
+    Iter->second.push_back(cast<NamedDecl>(Specialization));
+  }
+
+  auto *Lookups =
+      Chain ? Chain->getLoadedSpecializationsLookupTables(D, IsPartial)
+            : nullptr;
+
+  for (auto &[HashValue, Specs] : SpecializationMaps) {
+    SmallVector<serialization::reader::LazySpecializationInfo, 16>
+        ExisitingSpecs;
+    // We have to merge the lookup table manually here. We can't depend on the
+    // merge mechanism offered by
+    // clang::serialization::MultiOnDiskHashTableGenerator since that generator
+    // assumes that we'll get the same value with the same key.
+    // And also underlying llvm::OnDiskChainedHashTableGenerator assumes that we
+    // won't insert the values with the same key twice. So we have to merge the
+    // lookup table here manually.
+    if (Lookups)
+      ExisitingSpecs = Lookups->Table.find(HashValue);
+
+    Generator.insert(HashValue, Trait.getData(Specs, ExisitingSpecs), Trait);
+  }
+
+  Generator.emit(LookupTable, Trait, Lookups ? &Lookups->Table : nullptr);
+}
+
+uint64_t ASTWriter::WriteSpecializationInfoLookupTable(
+    const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations,
+    bool IsPartial) {
+
+  llvm::SmallString<4096> LookupTable;
+  GenerateSpecializationInfoLookupTable(D, Specializations, LookupTable,
+                                        IsPartial);
+
+  uint64_t Offset = Stream.GetCurrentBitNo();
+  RecordData::value_type Record[] = {IsPartial ? DECL_PARTIAL_SPECIALIZATIONS
+                                               : DECL_SPECIALIZATIONS};
+  Stream.EmitRecordWithBlob(IsPartial ? DeclPartialSpecializationsAbbrev
+                                      : DeclSpecializationsAbbrev,
+                            Record, LookupTable);
+
+  return Offset;
+}
+
 bool ASTWriter::isLookupResultExternal(StoredDeclsList &Result,
                                        DeclContext *DC) {
   return Result.hasExternalDecls() &&
@@ -5748,7 +5918,7 @@ void ASTWriter::WriteDeclAndTypes(ASTContext &Context) {
   // Keep writing types, declarations, and declaration update records
   // until we've emitted all of them.
   RecordData DeclUpdatesOffsetsRecord;
-  Stream.EnterSubblock(DECLTYPES_BLOCK_ID, /*bits for abbreviations*/5);
+  Stream.EnterSubblock(DECLTYPES_BLOCK_ID, /*bits for abbreviations*/ 6);
   DeclTypesBlockStartOffset = Stream.GetCurrentBitNo();
   WriteTypeAbbrevs();
   WriteDeclAbbrevs();
@@ -5822,6 +5992,16 @@ void ASTWriter::WriteDeclAndTypes(ASTContext &Context) {
                       FunctionToLambdaMapAbbrev);
   }
 
+  if (!SpecializationsUpdates.empty()) {
+    WriteSpecializationsUpdates(/*IsPartial=*/false);
+    SpecializationsUpdates.clear();
+  }
+
+  if (!PartialSpecializationsUpdates.empty()) {
+    WriteSpecializationsUpdates(/*IsPartial=*/true);
+    PartialSpecializationsUpdates.clear();
+  }
+
   const TranslationUnitDecl *TU = Context.getTranslationUnitDecl();
   // Create a lexical update block containing all of the declarations in the
   // translation unit that do not come from other AST files.
@@ -5865,6 +6045,31 @@ void ASTWriter::WriteDeclAndTypes(ASTContext &Context) {
     WriteDeclContextVisibleUpdate(Context, DC);
 }
 
+void ASTWriter::WriteSpecializationsUpdates(bool IsPartial) {
+  auto RecordType = IsPartial ? CXX_ADDED_TEMPLATE_PARTIAL_SPECIALIZATION
+                              : CXX_ADDED_TEMPLATE_SPECIALIZATION;
+
+  auto Abv = std::make_shared<llvm::BitCodeAbbrev>();
+  Abv->Add(llvm::BitCodeAbbrevOp(RecordType));
+  Abv->Add(llvm::BitCodeAbbrevOp(llvm::BitCodeAbbrevOp::VBR, 6));
+  Abv->Add(llvm::BitCodeAbbrevOp(llvm::BitCodeAbbrevOp::Blob));
+  auto UpdateSpecializationAbbrev = Stream.EmitAbbrev(std::move(Abv));
+
+  auto &SpecUpdates =
+      IsPartial ? PartialSpecializationsUpdates : SpecializationsUpdates;
+  for (auto &SpecializationUpdate : SpecUpdates) {
+    const NamedDecl *D = SpecializationUpdate.first;
+
+    llvm::SmallString<4096> LookupTable;
+    GenerateSpecializationInfoLookupTable(D, SpecializationUpdate.second,
+                                          LookupTable, IsPartial);
+
+    // Write the lookup table
+    RecordData::value_type Record[] = {RecordType, getDeclID(D).getRawValue()};
+    Stream.EmitRecordWithBlob(UpdateSpecializationAbbrev, Record, LookupTable);
+  }
+}
+
 void ASTWriter::WriteDeclUpdatesBlocks(ASTContext &Context,
                                        RecordDataImpl &OffsetsRecord) {
   if (DeclUpdates.empty())
@@ -5894,12 +6099,10 @@ void ASTWriter::WriteDeclUpdatesBlocks(ASTContext &Context,
 
       switch (Kind) {
       case UPD_CXX_ADDED_IMPLICIT_MEMBER:
-      case UPD_CXX_ADDED_TEMPLATE_SPECIALIZATION:
       case UPD_CXX_ADDED_ANONYMOUS_NAMESPACE:
         assert(Update.getDecl() && "no decl to add?");
         Record.AddDeclRef(Update.getDecl());
         break;
-
       case UPD_CXX_ADDED_FUNCTION_DEFINITION:
       case UPD_CXX_ADDED_VAR_DEFINITION:
         break;

diff  --git a/clang/lib/Serialization/ASTWriterDecl.cpp b/clang/lib/Serialization/ASTWriterDecl.cpp
index ad357e30d57529..b3119607a14043 100644
--- a/clang/lib/Serialization/ASTWriterDecl.cpp
+++ b/clang/lib/Serialization/ASTWriterDecl.cpp
@@ -177,11 +177,12 @@ namespace clang {
       Record.AddSourceLocation(typeParams->getRAngleLoc());
     }
 
-    /// Add to the record the first declaration from each module file that
-    /// provides a declaration of D. The intent is to provide a sufficient
-    /// set such that reloading this set will load all current redeclarations.
-    void AddFirstDeclFromEachModule(const Decl *D, bool IncludeLocal) {
-      llvm::MapVector<ModuleFile*, const Decl*> Firsts;
+    /// Collect the first declaration from each module file that provides a
+    /// declaration of D.
+    void CollectFirstDeclFromEachModule(
+        const Decl *D, bool IncludeLocal,
+        llvm::MapVector<ModuleFile *, const Decl *> &Firsts) {
+
       // FIXME: We can skip entries that we know are implied by others.
       for (const Decl *R = D->getMostRecentDecl(); R; R = R->getPreviousDecl()) {
         if (R->isFromASTFile())
@@ -189,10 +190,41 @@ namespace clang {
         else if (IncludeLocal)
           Firsts[nullptr] = R;
       }
+    }
+
+    /// Add to the record the first declaration from each module file that
+    /// provides a declaration of D. The intent is to provide a sufficient
+    /// set such that reloading this set will load all current redeclarations.
+    void AddFirstDeclFromEachModule(const Decl *D, bool IncludeLocal) {
+      llvm::MapVector<ModuleFile *, const Decl *> Firsts;
+      CollectFirstDeclFromEachModule(D, IncludeLocal, Firsts);
+
       for (const auto &F : Firsts)
         Record.AddDeclRef(F.second);
     }
 
+    /// Add to the record the first template specialization from each module
+    /// file that provides a declaration of D. We store the DeclId and an
+    /// ODRHash of the template arguments of D which should provide enough
+    /// information to load D only if the template instantiator needs it.
+    void AddFirstSpecializationDeclFromEachModule(
+        const Decl *D, llvm::SmallVectorImpl<const Decl *> &SpecsInMap,
+        llvm::SmallVectorImpl<const Decl *> &PartialSpecsInMap) {
+      assert((isa<ClassTemplateSpecializationDecl>(D) ||
+              isa<VarTemplateSpecializationDecl>(D) || isa<FunctionDecl>(D)) &&
+             "Must not be called with other decls");
+      llvm::MapVector<ModuleFile *, const Decl *> Firsts;
+      CollectFirstDeclFromEachModule(D, /*IncludeLocal*/ true, Firsts);
+
+      for (const auto &F : Firsts) {
+        if (isa<ClassTemplatePartialSpecializationDecl,
+                VarTemplatePartialSpecializationDecl>(F.second))
+          PartialSpecsInMap.push_back(F.second);
+        else
+          SpecsInMap.push_back(F.second);
+      }
+    }
+
     /// Get the specialization decl from an entry in the specialization list.
     template <typename EntryType>
     typename RedeclarableTemplateDecl::SpecEntryTraits<EntryType>::DeclType *
@@ -205,8 +237,9 @@ namespace clang {
     decltype(T::PartialSpecializations) &getPartialSpecializations(T *Common) {
       return Common->PartialSpecializations;
     }
-    ArrayRef<Decl> getPartialSpecializations(FunctionTemplateDecl::Common *) {
-      return {};
+    MutableArrayRef<FunctionTemplateSpecializationInfo>
+    getPartialSpecializations(FunctionTemplateDecl::Common *) {
+      return std::nullopt;
     }
 
     template<typename DeclTy>
@@ -217,37 +250,37 @@ namespace clang {
       // our chained AST reader, we can just write out the DeclIDs. Otherwise,
       // we need to resolve them to actual declarations.
       if (Writer.Chain != Record.getASTContext().getExternalSource() &&
-          Common->LazySpecializations) {
+          Writer.Chain && Writer.Chain->haveUnloadedSpecializations(D)) {
         D->LoadLazySpecializations();
-        assert(!Common->LazySpecializations);
+        assert(!Writer.Chain->haveUnloadedSpecializations(D));
       }
 
-      ArrayRef<GlobalDeclID> LazySpecializations;
-      if (auto *LS = Common->LazySpecializations)
-        LazySpecializations = llvm::ArrayRef(LS + 1, LS[0].getRawValue());
-
-      // Add a slot to the record for the number of specializations.
-      unsigned I = Record.size();
-      Record.push_back(0);
-
-      // AddFirstDeclFromEachModule might trigger deserialization, invalidating
-      // *Specializations iterators.
-      llvm::SmallVector<const Decl*, 16> Specs;
+      // AddFirstSpecializationDeclFromEachModule might trigger deserialization,
+      // invalidating *Specializations iterators.
+      llvm::SmallVector<const Decl *, 16> AllSpecs;
       for (auto &Entry : Common->Specializations)
-        Specs.push_back(getSpecializationDecl(Entry));
+        AllSpecs.push_back(getSpecializationDecl(Entry));
       for (auto &Entry : getPartialSpecializations(Common))
-        Specs.push_back(getSpecializationDecl(Entry));
+        AllSpecs.push_back(getSpecializationDecl(Entry));
 
-      for (auto *D : Specs) {
+      llvm::SmallVector<const Decl *, 16> Specs;
+      llvm::SmallVector<const Decl *, 16> PartialSpecs;
+      for (auto *D : AllSpecs) {
         assert(D->isCanonicalDecl() && "non-canonical decl in set");
-        AddFirstDeclFromEachModule(D, /*IncludeLocal*/true);
+        AddFirstSpecializationDeclFromEachModule(D, Specs, PartialSpecs);
+      }
+
+      Record.AddOffset(Writer.WriteSpecializationInfoLookupTable(
+          D, Specs, /*IsPartial=*/false));
+
+      // Function Template Decl doesn't have partial decls.
+      if (isa<FunctionTemplateDecl>(D)) {
+        assert(PartialSpecs.empty());
+        return;
       }
-      Record.append(
-          DeclIDIterator<GlobalDeclID, DeclID>(LazySpecializations.begin()),
-          DeclIDIterator<GlobalDeclID, DeclID>(LazySpecializations.end()));
 
-      // Update the size entry we added earlier.
-      Record[I] = Record.size() - I - 1;
+      Record.AddOffset(Writer.WriteSpecializationInfoLookupTable(
+          D, PartialSpecs, /*IsPartial=*/true));
     }
 
     /// Ensure that this template specialization is associated with the specified
@@ -268,8 +301,13 @@ namespace clang {
       if (Writer.getFirstLocalDecl(Specialization) != Specialization)
         return;
 
-      Writer.DeclUpdates[Template].push_back(ASTWriter::DeclUpdate(
-          UPD_CXX_ADDED_TEMPLATE_SPECIALIZATION, Specialization));
+      if (isa<ClassTemplatePartialSpecializationDecl,
+              VarTemplatePartialSpecializationDecl>(Specialization))
+        Writer.PartialSpecializationsUpdates[cast<NamedDecl>(Template)]
+            .push_back(cast<NamedDecl>(Specialization));
+      else
+        Writer.SpecializationsUpdates[cast<NamedDecl>(Template)].push_back(
+            cast<NamedDecl>(Specialization));
     }
   };
 }
@@ -2774,6 +2812,16 @@ void ASTWriter::WriteDeclAbbrevs() {
   Abv->Add(BitCodeAbbrevOp(serialization::DECL_CONTEXT_VISIBLE));
   Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Blob));
   DeclContextVisibleLookupAbbrev = Stream.EmitAbbrev(std::move(Abv));
+
+  Abv = std::make_shared<BitCodeAbbrev>();
+  Abv->Add(BitCodeAbbrevOp(serialization::DECL_SPECIALIZATIONS));
+  Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Blob));
+  DeclSpecializationsAbbrev = Stream.EmitAbbrev(std::move(Abv));
+
+  Abv = std::make_shared<BitCodeAbbrev>();
+  Abv->Add(BitCodeAbbrevOp(serialization::DECL_PARTIAL_SPECIALIZATIONS));
+  Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Blob));
+  DeclPartialSpecializationsAbbrev = Stream.EmitAbbrev(std::move(Abv));
 }
 
 /// isRequiredDecl - Check if this is a "required" Decl, which must be seen by
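
For readers following the ASTWriterDecl.cpp hunks above, the redeclaration walk in CollectFirstDeclFromEachModule can be pictured with a standalone sketch. This is not the Clang code: `ToyDecl`, `ModuleId`, `FirstsMap` and `collectFirstPerModule` are hypothetical names, and a vector of pairs stands in for `llvm::MapVector`. The assignment for declarations that come from an AST file falls between two hunks and is not shown in the diff, so the overwrite behaviour below is an assumption of the sketch.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-ins; none of these names come from Clang.
using ModuleId = int; // 0 plays the role of the "local" (nullptr) key
struct ToyDecl {
  ModuleId Owner;          // module file that provides this redeclaration
  const ToyDecl *Previous; // older redeclaration, or nullptr
};

// Insertion-ordered map, imitating llvm::MapVector with a vector of pairs.
using FirstsMap = std::vector<std::pair<ModuleId, const ToyDecl *>>;

// Walk from the most recent redeclaration to the oldest. Older entries
// overwrite newer ones, so each module ends up mapped to its oldest
// redeclaration in the chain.
inline void collectFirstPerModule(const ToyDecl *MostRecent, bool IncludeLocal,
                                  FirstsMap &Firsts) {
  for (const ToyDecl *R = MostRecent; R; R = R->Previous) {
    if (R->Owner == 0 && !IncludeLocal)
      continue;
    bool Found = false;
    for (auto &Entry : Firsts) {
      if (Entry.first == R->Owner) {
        Entry.second = R;
        Found = true;
        break;
      }
    }
    if (!Found)
      Firsts.emplace_back(R->Owner, R);
  }
}

// Usage: a chain of four redeclarations spread over two modules plus a
// local one.
inline void demoCollect() {
  ToyDecl D1{1, nullptr}; // oldest, from module 1
  ToyDecl D2{1, &D1};
  ToyDecl D3{2, &D2};
  ToyDecl D4{0, &D3}; // most recent, local
  FirstsMap Firsts;
  collectFirstPerModule(&D4, /*IncludeLocal=*/false, Firsts);
  assert(Firsts.size() == 2);
  assert(Firsts[0].second == &D3); // module 2 is met first in the walk
  assert(Firsts[1].second == &D1); // module 1 keeps its oldest redeclaration
}
```

Splitting the collection step from the "add to record" step is what lets the new AddFirstSpecializationDeclFromEachModule reuse the walk and sort the results into full and partial specializations instead of writing them straight to the record.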

diff --git a/clang/lib/Serialization/CMakeLists.txt b/clang/lib/Serialization/CMakeLists.txt
index 99c47c15a2f479..b1fc0345047f24 100644
--- a/clang/lib/Serialization/CMakeLists.txt
+++ b/clang/lib/Serialization/CMakeLists.txt
@@ -23,6 +23,7 @@ add_clang_library(clangSerialization
   ModuleManager.cpp
   PCHContainerOperations.cpp
   ObjectFilePCHContainerReader.cpp
+  TemplateArgumentHasher.cpp
 
   ADDITIONAL_HEADERS
   ASTCommon.h

diff --git a/clang/lib/Serialization/TemplateArgumentHasher.cpp b/clang/lib/Serialization/TemplateArgumentHasher.cpp
new file mode 100644
index 00000000000000..598f098f526d0f
--- /dev/null
+++ b/clang/lib/Serialization/TemplateArgumentHasher.cpp
@@ -0,0 +1,409 @@
+//===- TemplateArgumentHasher.cpp - Hash Template Arguments -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TemplateArgumentHasher.h"
+#include "clang/AST/APValue.h"
+#include "clang/AST/Decl.h"
+#include "clang/AST/DeclCXX.h"
+#include "clang/AST/DeclTemplate.h"
+#include "clang/AST/DeclarationName.h"
+#include "clang/AST/TypeVisitor.h"
+#include "clang/Basic/IdentifierTable.h"
+#include "llvm/ADT/FoldingSet.h"
+
+using namespace clang;
+
+namespace {
+
+class TemplateArgumentHasher {
+  // If we bail out during the process of calculating hash values for
+  // template arguments for any reason. We're allowed to do it since
+  // TemplateArgumentHasher are only required to give the same hash value
+  // for the same template arguments, but not required to give different
+  // hash value for different template arguments.
+  //
+  // So in the worst case, it is still a valid implementation to give all
+  // inputs the same BailedOutValue as output.
+  bool BailedOut = false;
+  static constexpr unsigned BailedOutValue = 0x12345678;
+
+  llvm::FoldingSetNodeID ID;
+
+public:
+  TemplateArgumentHasher() = default;
+
+  void AddTemplateArgument(TemplateArgument TA);
+
+  void AddInteger(unsigned V) { ID.AddInteger(V); }
+
+  unsigned getValue() {
+    if (BailedOut)
+      return BailedOutValue;
+
+    return ID.computeStableHash();
+  }
+
+  void setBailedOut() { BailedOut = true; }
+
+  void AddType(const Type *T);
+  void AddQualType(QualType T);
+  void AddDecl(const Decl *D);
+  void AddStructuralValue(const APValue &);
+  void AddTemplateName(TemplateName Name);
+  void AddDeclarationName(DeclarationName Name);
+  void AddIdentifierInfo(const IdentifierInfo *II);
+};
+
+void TemplateArgumentHasher::AddTemplateArgument(TemplateArgument TA) {
+  const auto Kind = TA.getKind();
+  AddInteger(Kind);
+
+  switch (Kind) {
+  case TemplateArgument::Null:
+    llvm_unreachable("Expected valid TemplateArgument");
+  case TemplateArgument::Type:
+    AddQualType(TA.getAsType());
+    break;
+  case TemplateArgument::Declaration:
+    AddDecl(TA.getAsDecl());
+    break;
+  case TemplateArgument::NullPtr:
+    ID.AddPointer(nullptr);
+    break;
+  case TemplateArgument::Integral: {
+    // There are integrals (e.g.: _BitInt(128)) that cannot be represented as
+    // any builtin integral type, so we use the hash of APSInt instead.
+    TA.getAsIntegral().Profile(ID);
+    break;
+  }
+  case TemplateArgument::StructuralValue:
+    AddQualType(TA.getStructuralValueType());
+    AddStructuralValue(TA.getAsStructuralValue());
+    break;
+  case TemplateArgument::Template:
+  case TemplateArgument::TemplateExpansion:
+    AddTemplateName(TA.getAsTemplateOrTemplatePattern());
+    break;
+  case TemplateArgument::Expression:
+    // If we meet expression in template argument, it implies
+    // that the template is still dependent. It is meaningless
+    // to get a stable hash for the template. Bail out simply.
+    BailedOut = true;
+    break;
+  case TemplateArgument::Pack:
+    AddInteger(TA.pack_size());
+    for (auto SubTA : TA.pack_elements()) {
+      AddTemplateArgument(SubTA);
+    }
+    break;
+  }
+}
+
+void TemplateArgumentHasher::AddStructuralValue(const APValue &Value) {
+  auto Kind = Value.getKind();
+  AddInteger(Kind);
+
+  // 'APValue::Profile' uses pointer values to make hash for LValue and
+  // MemberPointer, but they differ from one compiler invocation to another.
+  // It may be difficult to handle such cases. Bail out simply.
+
+  if (Kind == APValue::LValue || Kind == APValue::MemberPointer) {
+    BailedOut = true;
+    return;
+  }
+
+  Value.Profile(ID);
+}
+
+void TemplateArgumentHasher::AddTemplateName(TemplateName Name) {
+  switch (Name.getKind()) {
+  case TemplateName::Template:
+    AddDecl(Name.getAsTemplateDecl());
+    break;
+  case TemplateName::QualifiedTemplate: {
+    QualifiedTemplateName *QTN = Name.getAsQualifiedTemplateName();
+    AddTemplateName(QTN->getUnderlyingTemplate());
+    break;
+  }
+  case TemplateName::OverloadedTemplate:
+  case TemplateName::AssumedTemplate:
+  case TemplateName::DependentTemplate:
+  case TemplateName::SubstTemplateTemplateParm:
+  case TemplateName::SubstTemplateTemplateParmPack:
+    BailedOut = true;
+    break;
+  case TemplateName::UsingTemplate: {
+    UsingShadowDecl *USD = Name.getAsUsingShadowDecl();
+    if (USD)
+      AddDecl(USD->getTargetDecl());
+    else
+      BailedOut = true;
+    break;
+  }
+  case TemplateName::DeducedTemplate:
+    AddTemplateName(Name.getAsDeducedTemplateName()->getUnderlying());
+    break;
+  }
+}
+
+void TemplateArgumentHasher::AddIdentifierInfo(const IdentifierInfo *II) {
+  assert(II && "Expecting non-null pointer.");
+  ID.AddString(II->getName());
+}
+
+void TemplateArgumentHasher::AddDeclarationName(DeclarationName Name) {
+  if (Name.isEmpty())
+    return;
+
+  switch (Name.getNameKind()) {
+  case DeclarationName::Identifier:
+    AddIdentifierInfo(Name.getAsIdentifierInfo());
+    break;
+  case DeclarationName::ObjCZeroArgSelector:
+  case DeclarationName::ObjCOneArgSelector:
+  case DeclarationName::ObjCMultiArgSelector:
+    BailedOut = true;
+    break;
+  case DeclarationName::CXXConstructorName:
+  case DeclarationName::CXXDestructorName:
+    AddQualType(Name.getCXXNameType());
+    break;
+  case DeclarationName::CXXOperatorName:
+    AddInteger(Name.getCXXOverloadedOperator());
+    break;
+  case DeclarationName::CXXLiteralOperatorName:
+    AddIdentifierInfo(Name.getCXXLiteralIdentifier());
+    break;
+  case DeclarationName::CXXConversionFunctionName:
+    AddQualType(Name.getCXXNameType());
+    break;
+  case DeclarationName::CXXUsingDirective:
+    break;
+  case DeclarationName::CXXDeductionGuideName: {
+    if (auto *Template = Name.getCXXDeductionGuideTemplate())
+      AddDecl(Template);
+  }
+  }
+}
+
+void TemplateArgumentHasher::AddDecl(const Decl *D) {
+  const NamedDecl *ND = dyn_cast<NamedDecl>(D);
+  if (!ND) {
+    BailedOut = true;
+    return;
+  }
+
+  AddDeclarationName(ND->getDeclName());
+}
+
+void TemplateArgumentHasher::AddQualType(QualType T) {
+  if (T.isNull()) {
+    BailedOut = true;
+    return;
+  }
+  SplitQualType split = T.split();
+  AddInteger(split.Quals.getAsOpaqueValue());
+  AddType(split.Ty);
+}
+
+// Process a Type pointer.  Add* methods call back into TemplateArgumentHasher
+// while Visit* methods process the relevant parts of the Type.
+// Any unhandled type will make the hash computation bail out.
+class TypeVisitorHelper : public TypeVisitor<TypeVisitorHelper> {
+  typedef TypeVisitor<TypeVisitorHelper> Inherited;
+  llvm::FoldingSetNodeID &ID;
+  TemplateArgumentHasher &Hash;
+
+public:
+  TypeVisitorHelper(llvm::FoldingSetNodeID &ID, TemplateArgumentHasher &Hash)
+      : ID(ID), Hash(Hash) {}
+
+  void AddDecl(const Decl *D) {
+    if (D)
+      Hash.AddDecl(D);
+    else
+      Hash.AddInteger(0);
+  }
+
+  void AddQualType(QualType T) { Hash.AddQualType(T); }
+
+  void AddType(const Type *T) {
+    if (T)
+      Hash.AddType(T);
+    else
+      Hash.AddInteger(0);
+  }
+
+  void VisitQualifiers(Qualifiers Quals) {
+    Hash.AddInteger(Quals.getAsOpaqueValue());
+  }
+
+  void Visit(const Type *T) { Inherited::Visit(T); }
+
+  // Unhandled types. Bail out simply.
+  void VisitType(const Type *T) { Hash.setBailedOut(); }
+
+  void VisitAdjustedType(const AdjustedType *T) {
+    AddQualType(T->getOriginalType());
+  }
+
+  void VisitDecayedType(const DecayedType *T) {
+    // getDecayedType and getPointeeType are derived from getAdjustedType
+    // and don't need to be separately processed.
+    VisitAdjustedType(T);
+  }
+
+  void VisitArrayType(const ArrayType *T) {
+    AddQualType(T->getElementType());
+    Hash.AddInteger(llvm::to_underlying(T->getSizeModifier()));
+    VisitQualifiers(T->getIndexTypeQualifiers());
+  }
+  void VisitConstantArrayType(const ConstantArrayType *T) {
+    T->getSize().Profile(ID);
+    VisitArrayType(T);
+  }
+
+  void VisitAttributedType(const AttributedType *T) {
+    Hash.AddInteger(T->getAttrKind());
+    AddQualType(T->getModifiedType());
+  }
+
+  void VisitBuiltinType(const BuiltinType *T) { Hash.AddInteger(T->getKind()); }
+
+  void VisitComplexType(const ComplexType *T) {
+    AddQualType(T->getElementType());
+  }
+
+  void VisitDecltypeType(const DecltypeType *T) {
+    AddQualType(T->getUnderlyingType());
+  }
+
+  void VisitDeducedType(const DeducedType *T) {
+    AddQualType(T->getDeducedType());
+  }
+
+  void VisitAutoType(const AutoType *T) { VisitDeducedType(T); }
+
+  void VisitDeducedTemplateSpecializationType(
+      const DeducedTemplateSpecializationType *T) {
+    Hash.AddTemplateName(T->getTemplateName());
+    VisitDeducedType(T);
+  }
+
+  void VisitFunctionType(const FunctionType *T) {
+    AddQualType(T->getReturnType());
+    T->getExtInfo().Profile(ID);
+    Hash.AddInteger(T->isConst());
+    Hash.AddInteger(T->isVolatile());
+    Hash.AddInteger(T->isRestrict());
+  }
+
+  void VisitFunctionNoProtoType(const FunctionNoProtoType *T) {
+    VisitFunctionType(T);
+  }
+
+  void VisitFunctionProtoType(const FunctionProtoType *T) {
+    Hash.AddInteger(T->getNumParams());
+    for (auto ParamType : T->getParamTypes())
+      AddQualType(ParamType);
+
+    VisitFunctionType(T);
+  }
+
+  void VisitMemberPointerType(const MemberPointerType *T) {
+    AddQualType(T->getPointeeType());
+    AddType(T->getClass());
+  }
+
+  void VisitPackExpansionType(const PackExpansionType *T) {
+    AddQualType(T->getPattern());
+  }
+
+  void VisitParenType(const ParenType *T) { AddQualType(T->getInnerType()); }
+
+  void VisitPointerType(const PointerType *T) {
+    AddQualType(T->getPointeeType());
+  }
+
+  void VisitReferenceType(const ReferenceType *T) {
+    AddQualType(T->getPointeeTypeAsWritten());
+  }
+
+  void VisitLValueReferenceType(const LValueReferenceType *T) {
+    VisitReferenceType(T);
+  }
+
+  void VisitRValueReferenceType(const RValueReferenceType *T) {
+    VisitReferenceType(T);
+  }
+
+  void
+  VisitSubstTemplateTypeParmPackType(const SubstTemplateTypeParmPackType *T) {
+    AddDecl(T->getAssociatedDecl());
+    Hash.AddTemplateArgument(T->getArgumentPack());
+  }
+
+  void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *T) {
+    AddDecl(T->getAssociatedDecl());
+    AddQualType(T->getReplacementType());
+  }
+
+  void VisitTagType(const TagType *T) { AddDecl(T->getDecl()); }
+
+  void VisitRecordType(const RecordType *T) { VisitTagType(T); }
+  void VisitEnumType(const EnumType *T) { VisitTagType(T); }
+
+  void VisitTemplateSpecializationType(const TemplateSpecializationType *T) {
+    Hash.AddInteger(T->template_arguments().size());
+    for (const auto &TA : T->template_arguments()) {
+      Hash.AddTemplateArgument(TA);
+    }
+    Hash.AddTemplateName(T->getTemplateName());
+  }
+
+  void VisitTemplateTypeParmType(const TemplateTypeParmType *T) {
+    Hash.AddInteger(T->getDepth());
+    Hash.AddInteger(T->getIndex());
+    Hash.AddInteger(T->isParameterPack());
+  }
+
+  void VisitTypedefType(const TypedefType *T) { AddDecl(T->getDecl()); }
+
+  void VisitElaboratedType(const ElaboratedType *T) {
+    AddQualType(T->getNamedType());
+  }
+
+  void VisitUnaryTransformType(const UnaryTransformType *T) {
+    AddQualType(T->getUnderlyingType());
+    AddQualType(T->getBaseType());
+  }
+
+  void VisitVectorType(const VectorType *T) {
+    AddQualType(T->getElementType());
+    Hash.AddInteger(T->getNumElements());
+    Hash.AddInteger(llvm::to_underlying(T->getVectorKind()));
+  }
+
+  void VisitExtVectorType(const ExtVectorType *T) { VisitVectorType(T); }
+};
+
+void TemplateArgumentHasher::AddType(const Type *T) {
+  assert(T && "Expecting non-null pointer.");
+  TypeVisitorHelper(ID, *this).Visit(T);
+}
+
+} // namespace
+
+unsigned clang::serialization::StableHashForTemplateArguments(
+    llvm::ArrayRef<TemplateArgument> Args) {
+  TemplateArgumentHasher Hasher;
+  Hasher.AddInteger(Args.size());
+  for (TemplateArgument Arg : Args)
+    Hasher.AddTemplateArgument(Arg);
+  return Hasher.getValue();
+}
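
The bail-out rule described in the class comment of TemplateArgumentHasher.cpp can be illustrated outside of Clang. The sketch below is a toy under stated assumptions: `ToyHasher`, `ToyArg` and `hashArguments` are hypothetical names, FNV-1a is swapped in for `llvm::FoldingSetNodeID::computeStableHash`, and an "unstable" argument models cases such as expressions or lvalue pointers. The only property it tries to show is that equal inputs always produce equal hashes, even when the result is the shared sentinel.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical toy hasher; FNV-1a replaces llvm::FoldingSetNodeID here.
class ToyHasher {
  bool BailedOut = false;
  std::uint32_t State = 2166136261u; // FNV-1a offset basis

public:
  static constexpr std::uint32_t BailedOutValue = 0x12345678;

  // Mix the four bytes of V into the running state.
  void addInteger(std::uint32_t V) {
    for (int Shift = 0; Shift < 32; Shift += 8) {
      State ^= (V >> Shift) & 0xFFu;
      State *= 16777619u; // FNV-1a prime
    }
  }

  void setBailedOut() { BailedOut = true; }

  // Once bailed out, every input maps to the same sentinel value.
  std::uint32_t getValue() const { return BailedOut ? BailedOutValue : State; }
};

// Hypothetical argument model: Stable == false means "cannot be hashed in a
// way that is reproducible across compiler invocations".
struct ToyArg {
  bool Stable;
  std::uint32_t Value;
};

inline std::uint32_t hashArguments(const std::vector<ToyArg> &Args) {
  ToyHasher H;
  H.addInteger(static_cast<std::uint32_t>(Args.size()));
  for (const ToyArg &A : Args) {
    if (A.Stable)
      H.addInteger(A.Value);
    else
      H.setBailedOut();
  }
  return H.getValue();
}

// Usage: equal inputs agree, and any unstable input yields the sentinel.
inline void demoBailOut() {
  assert(hashArguments({{true, 1}, {true, 2}}) ==
         hashArguments({{true, 1}, {true, 2}}));
  assert(hashArguments({{true, 1}, {false, 0}}) == ToyHasher::BailedOutValue);
}
```

Collapsing every hard case onto one value costs lookup precision for those cases only; it never makes a lookup miss a matching entry, which is the property the hash table key needs.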

diff --git a/clang/lib/Serialization/TemplateArgumentHasher.h b/clang/lib/Serialization/TemplateArgumentHasher.h
new file mode 100644
index 00000000000000..f23f1318afbbf4
--- /dev/null
+++ b/clang/lib/Serialization/TemplateArgumentHasher.h
@@ -0,0 +1,34 @@
+//===- TemplateArgumentHasher.h - Hash Template Arguments -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/TemplateBase.h"
+
+namespace clang {
+namespace serialization {
+
+/// Calculate a stable hash value for template arguments. We guarantee that
+/// the same template arguments must have the same hashed values. But we don't
+/// guarantee that the template arguments with the same hashed value are the
+/// same template arguments.
+///
+/// ODR hashing may not be the best mechanism to hash the template
+/// arguments. ODR hashing is (or perhaps, should be) about determining whether
+/// two things are spelled the same way and have the same meaning (as required
+/// by the C++ ODR), whereas what we want here is whether they have the same
+/// meaning regardless of spelling. Maybe we can get away with reusing ODR
+/// hashing anyway, on the basis that any canonical, non-dependent template
+/// argument should have the same (invented) spelling in every translation
+/// unit, but it is not sure that's true in all cases. There may still be cases
+/// where the canonical type includes some aspect of "whatever we saw first",
+/// in which case the ODR hash can differ across translation units for
+/// non-dependent, canonical template arguments that are spelled differently
+/// but have the same meaning. But it is not easy to raise examples.
+/// but have the same meaning. But it is not easy to raise examples.
+unsigned StableHashForTemplateArguments(llvm::ArrayRef<TemplateArgument> Args);
+
+} // namespace serialization
+} // namespace clang
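
The one-directional guarantee documented for StableHashForTemplateArguments is enough for a hash-keyed lazy lookup, and a small model may make that concrete. Everything in the sketch is hypothetical (`ToySpec`, `LazySpecTable`, `toyHash`); it is not the MultiOnDiskHashTable added by this patch. Argument lists are modelled as strings and the hash is deliberately weak so that collisions occur.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical model: a "specialization" is just its argument list as text.
struct ToySpec {
  std::string Args;
  bool Loaded;
};

class LazySpecTable {
  std::map<unsigned, std::vector<ToySpec>> Buckets;
  unsigned NumLoaded = 0;

public:
  // Deliberately weak toy hash (string length) so that collisions occur.
  static unsigned toyHash(const std::string &Args) {
    return static_cast<unsigned>(Args.size());
  }

  void add(const std::string &Args) {
    Buckets[toyHash(Args)].push_back(ToySpec{Args, false});
  }

  // Load only the bucket matching the requested arguments, then compare
  // exactly, because equal hashes do not imply equal arguments.
  bool find(const std::string &Args) {
    auto It = Buckets.find(toyHash(Args));
    if (It == Buckets.end())
      return false;
    bool Found = false;
    for (ToySpec &S : It->second) {
      if (!S.Loaded) {
        S.Loaded = true;
        ++NumLoaded;
      }
      Found |= (S.Args == Args);
    }
    return Found;
  }

  unsigned numLoaded() const { return NumLoaded; }
};

// Usage: looking up one argument list leaves unrelated buckets unloaded.
inline void demoLazyLookup() {
  LazySpecTable T;
  T.add("int");               // hash 3
  T.add("A<B>");              // hash 4
  T.add("ShouldNotBeLoaded"); // hash 17
  assert(T.find("int"));
  assert(T.numLoaded() == 1); // other buckets untouched
  assert(!T.find("long"));    // collides with "A<B>" but is not equal
  assert(T.numLoaded() == 2); // the colliding bucket was loaded anyway
}
```

A collision only causes extra loading, as with "long" and "A<B>" here; correctness comes from the exact comparison that follows the bucket lookup.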

diff --git a/clang/test/Modules/odr_hash.cpp b/clang/test/Modules/odr_hash.cpp
index f1de6b3d433ed7..4de0e50dbc0eb7 100644
--- a/clang/test/Modules/odr_hash.cpp
+++ b/clang/test/Modules/odr_hash.cpp
@@ -3084,8 +3084,8 @@ struct S5 {
 };
 #else
 S5 s5;
-// expected-error@second.h:* {{'PointersAndReferences::S5::x' from module 'SecondModule' is not present in definition of 'PointersAndReferences::S5' in module 'FirstModule'}}
-// expected-note@first.h:* {{declaration of 'x' does not match}}
+// expected-error@first.h:* {{'PointersAndReferences::S5::x' from module 'FirstModule' is not present in definition of 'PointersAndReferences::S5' in module 'SecondModule'}}
+// expected-note@second.h:* {{declaration of 'x' does not match}}
 #endif
 
 #if defined(FIRST)

diff --git a/clang/test/Modules/recursive-instantiations.cppm b/clang/test/Modules/recursive-instantiations.cppm
new file mode 100644
index 00000000000000..d5854b0e647e37
--- /dev/null
+++ b/clang/test/Modules/recursive-instantiations.cppm
@@ -0,0 +1,40 @@
+// RUN: rm -rf %t
+// RUN: mkdir -p %t
+// RUN: split-file %s %t
+//
+// RUN: %clang_cc1 -std=c++20 %t/type_traits.cppm -emit-module-interface -o %t/type_traits.pcm
+// RUN: %clang_cc1 -std=c++20 %t/test.cpp -fprebuilt-module-path=%t -verify
+
+//--- type_traits.cppm
+export module type_traits;
+
+export template <typename T>
+constexpr bool is_pod_v = __is_pod(T);
+
+//--- test.cpp
+// expected-no-diagnostics
+import type_traits;
+// Base is either void or wrapper<T>.
+template <class Base> struct wrapper : Base {};
+template <> struct wrapper<void> {};
+
+// wrap<0>::type<T> is wrapper<T>, wrap<1>::type<T> is wrapper<wrapper<T>>,
+// and so on.
+template <int N>
+struct wrap {
+  template <class Base>
+  using type = wrapper<typename wrap<N-1>::template type<Base>>;
+};
+
+template <>
+struct wrap<0> {
+  template <class Base>
+  using type = wrapper<Base>;
+};
+
+inline constexpr int kMaxRank = 40;
+template <int N, class Base = void>
+using rank = typename wrap<N>::template type<Base>;
+using rank_selector_t = rank<kMaxRank>;
+
+static_assert(is_pod_v<rank_selector_t>, "Must be POD");

diff --git a/clang/test/OpenMP/target_parallel_ast_print.cpp b/clang/test/OpenMP/target_parallel_ast_print.cpp
index 7e27ac7b92ca4a..3ee98bc525c1bd 100644
--- a/clang/test/OpenMP/target_parallel_ast_print.cpp
+++ b/clang/test/OpenMP/target_parallel_ast_print.cpp
@@ -38,10 +38,6 @@ struct S {
 // CHECK:        static int TS;
 // CHECK-NEXT:   #pragma omp threadprivate(S<int>::TS)
 // CHECK-NEXT: }
-// CHECK:      template<> struct S<char> {
-// CHECK:        static char TS;
-// CHECK-NEXT:   #pragma omp threadprivate(S<char>::TS)
-// CHECK-NEXT: }
 
 template <typename T, int C>
 T tmain(T argc, T *argv) {

diff --git a/clang/test/OpenMP/target_teams_ast_print.cpp b/clang/test/OpenMP/target_teams_ast_print.cpp
index 8338f2a68f9228..cc47ae92efac0f 100644
--- a/clang/test/OpenMP/target_teams_ast_print.cpp
+++ b/clang/test/OpenMP/target_teams_ast_print.cpp
@@ -40,10 +40,6 @@ struct S {
 // CHECK:        static int TS;
 // CHECK-NEXT:   #pragma omp threadprivate(S<int>::TS)
 // CHECK-NEXT: }
-// CHECK:      template<> struct S<long> {
-// CHECK:        static long TS;
-// CHECK-NEXT:   #pragma omp threadprivate(S<long>::TS)
-// CHECK-NEXT: }
 
 template <typename T, int C>
 T tmain(T argc, T *argv) {

diff --git a/clang/test/OpenMP/task_ast_print.cpp b/clang/test/OpenMP/task_ast_print.cpp
index 2a6b8908a1e2dc..30fb7ab75cc87a 100644
--- a/clang/test/OpenMP/task_ast_print.cpp
+++ b/clang/test/OpenMP/task_ast_print.cpp
@@ -87,10 +87,6 @@ struct S {
 // CHECK:        static int TS;
 // CHECK-NEXT:   #pragma omp threadprivate(S<int>::TS)
 // CHECK-NEXT: }
-// CHECK:      template<> struct S<long> {
-// CHECK:        static long TS;
-// CHECK-NEXT:   #pragma omp threadprivate(S<long>::TS)
-// CHECK-NEXT: }
 
 template <typename T, int C>
 T tmain(T argc, T *argv) {

diff --git a/clang/test/OpenMP/teams_ast_print.cpp b/clang/test/OpenMP/teams_ast_print.cpp
index 0087f71ac9f742..597a9b2bdbdc5d 100644
--- a/clang/test/OpenMP/teams_ast_print.cpp
+++ b/clang/test/OpenMP/teams_ast_print.cpp
@@ -27,10 +27,6 @@ struct S {
 // CHECK:        static int TS;
 // CHECK-NEXT:   #pragma omp threadprivate(S<int>::TS)
 // CHECK-NEXT: }
-// CHECK:      template<> struct S<long> {
-// CHECK:        static long TS;
-// CHECK-NEXT:   #pragma omp threadprivate(S<long>::TS)
-// CHECK-NEXT: }
 
 template <typename T, int C>
 T tmain(T argc, T *argv) {

diff --git a/clang/unittests/Serialization/CMakeLists.txt b/clang/unittests/Serialization/CMakeLists.txt
index e7eebd0cb98239..e7005b5d511eba 100644
--- a/clang/unittests/Serialization/CMakeLists.txt
+++ b/clang/unittests/Serialization/CMakeLists.txt
@@ -11,6 +11,7 @@ add_clang_unittest(SerializationTests
   ModuleCacheTest.cpp
   NoCommentsTest.cpp
   PreambleInNamedModulesTest.cpp
+  LoadSpecLazilyTest.cpp
   SourceLocationEncodingTest.cpp
   VarDeclConstantInitTest.cpp
   )

diff --git a/clang/unittests/Serialization/LoadSpecLazilyTest.cpp b/clang/unittests/Serialization/LoadSpecLazilyTest.cpp
new file mode 100644
index 00000000000000..0e452652a940d5
--- /dev/null
+++ b/clang/unittests/Serialization/LoadSpecLazilyTest.cpp
@@ -0,0 +1,262 @@
+//== unittests/Serialization/LoadSpecLazily.cpp ----------------------========//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/FrontendAction.h"
+#include "clang/Frontend/FrontendActions.h"
+#include "clang/Parse/ParseAST.h"
+#include "clang/Serialization/ASTDeserializationListener.h"
+#include "clang/Tooling/Tooling.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace clang;
+using namespace clang::tooling;
+
+namespace {
+
+class LoadSpecLazilyTest : public ::testing::Test {
+  void SetUp() override {
+    ASSERT_FALSE(
+        sys::fs::createUniqueDirectory("load-spec-lazily-test", TestDir));
+  }
+
+  void TearDown() override { sys::fs::remove_directories(TestDir); }
+
+public:
+  SmallString<256> TestDir;
+
+  void addFile(StringRef Path, StringRef Contents) {
+    ASSERT_FALSE(sys::path::is_absolute(Path));
+
+    SmallString<256> AbsPath(TestDir);
+    sys::path::append(AbsPath, Path);
+
+    ASSERT_FALSE(
+        sys::fs::create_directories(llvm::sys::path::parent_path(AbsPath)));
+
+    std::error_code EC;
+    llvm::raw_fd_ostream OS(AbsPath, EC);
+    ASSERT_FALSE(EC);
+    OS << Contents;
+  }
+
+  std::string GenerateModuleInterface(StringRef ModuleName,
+                                      StringRef Contents) {
+    std::string FileName = llvm::Twine(ModuleName + ".cppm").str();
+    addFile(FileName, Contents);
+
+    IntrusiveRefCntPtr<llvm::vfs::FileSystem> VFS =
+        llvm::vfs::createPhysicalFileSystem();
+    IntrusiveRefCntPtr<DiagnosticsEngine> Diags =
+        CompilerInstance::createDiagnostics(*VFS, new DiagnosticOptions());
+    CreateInvocationOptions CIOpts;
+    CIOpts.Diags = Diags;
+    CIOpts.VFS = VFS;
+
+    std::string CacheBMIPath =
+        llvm::Twine(TestDir + "/" + ModuleName + ".pcm").str();
+    std::string PrebuiltModulePath =
+        "-fprebuilt-module-path=" + TestDir.str().str();
+    const char *Args[] = {"clang++",
+                          "-std=c++20",
+                          "--precompile",
+                          PrebuiltModulePath.c_str(),
+                          "-working-directory",
+                          TestDir.c_str(),
+                          "-I",
+                          TestDir.c_str(),
+                          FileName.c_str(),
+                          "-o",
+                          CacheBMIPath.c_str()};
+    std::shared_ptr<CompilerInvocation> Invocation =
+        createInvocation(Args, CIOpts);
+    EXPECT_TRUE(Invocation);
+
+    CompilerInstance Instance;
+    Instance.setDiagnostics(Diags.get());
+    Instance.setInvocation(Invocation);
+    Instance.getFrontendOpts().OutputFile = CacheBMIPath;
+    GenerateModuleInterfaceAction Action;
+    EXPECT_TRUE(Instance.ExecuteAction(Action));
+    EXPECT_FALSE(Diags->hasErrorOccurred());
+
+    return CacheBMIPath;
+  }
+};
+
+enum class CheckingMode { Forbidden, Required };
+
+class DeclsReaderListener : public ASTDeserializationListener {
+  StringRef SpeficiedName;
+  CheckingMode Mode;
+
+  bool ReadedSpecifiedName = false;
+
+public:
+  void DeclRead(GlobalDeclID ID, const Decl *D) override {
+    auto *ND = dyn_cast<NamedDecl>(D);
+    if (!ND)
+      return;
+
+    ReadedSpecifiedName |= ND->getName().contains(SpeficiedName);
+    if (Mode == CheckingMode::Forbidden) {
+      EXPECT_FALSE(ReadedSpecifiedName);
+    }
+  }
+
+  DeclsReaderListener(StringRef SpeficiedName, CheckingMode Mode)
+      : SpeficiedName(SpeficiedName), Mode(Mode) {}
+
+  ~DeclsReaderListener() {
+    if (Mode == CheckingMode::Required) {
+      EXPECT_TRUE(ReadedSpecifiedName);
+    }
+  }
+};
+
+class LoadSpecLazilyConsumer : public ASTConsumer {
+  DeclsReaderListener Listener;
+
+public:
+  LoadSpecLazilyConsumer(StringRef SpecifiedName, CheckingMode Mode)
+      : Listener(SpecifiedName, Mode) {}
+
+  ASTDeserializationListener *GetASTDeserializationListener() override {
+    return &Listener;
+  }
+};
+
+class CheckLoadSpecLazilyAction : public ASTFrontendAction {
+  StringRef SpecifiedName;
+  CheckingMode Mode;
+
+public:
+  std::unique_ptr<ASTConsumer>
+  CreateASTConsumer(CompilerInstance &CI, StringRef /*Unused*/) override {
+    return std::make_unique<LoadSpecLazilyConsumer>(SpecifiedName, Mode);
+  }
+
+  CheckLoadSpecLazilyAction(StringRef SpecifiedName, CheckingMode Mode)
+      : SpecifiedName(SpecifiedName), Mode(Mode) {}
+};
+
+TEST_F(LoadSpecLazilyTest, BasicTest) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+export class ShouldNotBeLoaded {};
+export class Temp {
+   A<ShouldNotBeLoaded> AS;
+};
+  )cpp");
+
+  const char *test_file_contents = R"cpp(
+import M;
+A<int> a;
+  )cpp";
+  std::string DepArg = "-fprebuilt-module-path=" + TestDir.str().str();
+  EXPECT_TRUE(
+      runToolOnCodeWithArgs(std::make_unique<CheckLoadSpecLazilyAction>(
+                                "ShouldNotBeLoaded", CheckingMode::Forbidden),
+                            test_file_contents,
+                            {
+                                "-std=c++20",
+                                DepArg.c_str(),
+                                "-I",
+                                TestDir.c_str(),
+                            },
+                            "test.cpp"));
+}
+
+TEST_F(LoadSpecLazilyTest, ChainedTest) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+  )cpp");
+
+  GenerateModuleInterface("N", R"cpp(
+export module N;
+export import M;
+export class ShouldNotBeLoaded {};
+export class Temp {
+   A<ShouldNotBeLoaded> AS;
+};
+  )cpp");
+
+  const char *test_file_contents = R"cpp(
+import N;
+A<int> a;
+  )cpp";
+  std::string DepArg = "-fprebuilt-module-path=" + TestDir.str().str();
+  EXPECT_TRUE(
+      runToolOnCodeWithArgs(std::make_unique<CheckLoadSpecLazilyAction>(
+                                "ShouldNotBeLoaded", CheckingMode::Forbidden),
+                            test_file_contents,
+                            {
+                                "-std=c++20",
+                                DepArg.c_str(),
+                                "-I",
+                                TestDir.c_str(),
+                            },
+                            "test.cpp"));
+}
+
+/// Test that we won't crash due to we may invalidate the lazy specialization
+/// lookup table during the loading process.
+TEST_F(LoadSpecLazilyTest, ChainedTest2) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+
+export class B {};
+
+export class C {
+  A<B> D;
+};
+  )cpp");
+
+  GenerateModuleInterface("N", R"cpp(
+export module N;
+export import M;
+export class MayBeLoaded {};
+
+export class Temp {
+   A<MayBeLoaded> AS;
+};
+
+export class ExportedClass {};
+
+export template<> class A<ExportedClass> {
+   A<MayBeLoaded> AS;
+   A<B>           AB;
+};
+  )cpp");
+
+  const char *test_file_contents = R"cpp(
+import N;
+Temp T;
+A<ExportedClass> a;
+  )cpp";
+  std::string DepArg = "-fprebuilt-module-path=" + TestDir.str().str();
+  EXPECT_TRUE(runToolOnCodeWithArgs(std::make_unique<CheckLoadSpecLazilyAction>(
+                                        "MayBeLoaded", CheckingMode::Required),
+                                    test_file_contents,
+                                    {
+                                        "-std=c++20",
+                                        DepArg.c_str(),
+                                        "-I",
+                                        TestDir.c_str(),
+                                    },
+                                    "test.cpp"));
+}
+
+} // namespace
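
The two checking modes used by the unit test above can be summarised with a toy that has no Clang dependency. `NameWatcher` and `ToyMode` are hypothetical names. Unlike DeclsReaderListener, which checks the Forbidden case on every DeclRead call and the Required case in its destructor, the sketch exposes a single `passed()` query; that simplification is an assumption made for brevity.

```cpp
#include <cassert>
#include <string>
#include <utility>

enum class ToyMode { Forbidden, Required };

// Hypothetical watcher: remembers whether any observed name contains Needle.
class NameWatcher {
  std::string Needle;
  ToyMode Mode;
  bool Seen = false;

public:
  NameWatcher(std::string Needle, ToyMode Mode)
      : Needle(std::move(Needle)), Mode(Mode) {}

  // Called once per "deserialized" declaration name.
  void onDeclRead(const std::string &Name) {
    Seen |= Name.find(Needle) != std::string::npos;
  }

  // Forbidden passes while the name was never seen; Required passes once
  // the name has been seen at least once.
  bool passed() const { return Mode == ToyMode::Forbidden ? !Seen : Seen; }
};

// Usage mirroring the tests: one name must stay unloaded, another must load.
inline void demoWatcher() {
  NameWatcher NoLoad("ShouldNotBeLoaded", ToyMode::Forbidden);
  NoLoad.onDeclRead("A");
  assert(NoLoad.passed());
  NoLoad.onDeclRead("ShouldNotBeLoaded");
  assert(!NoLoad.passed());

  NameWatcher MustLoad("MayBeLoaded", ToyMode::Required);
  assert(!MustLoad.passed());
  MustLoad.onDeclRead("MayBeLoaded");
  assert(MustLoad.passed());
}
```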


        

