[llvm-branch-commits] [clang] [Serialization] Introduce OnDiskHashTable for specializations (PR #83233)
Chuanqi Xu via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Nov 8 01:20:07 PST 2024
https://github.com/ChuanqiXu9 updated https://github.com/llvm/llvm-project/pull/83233
>From 11726437efb760c9f2aba9b2258337b2b8eb4bb6 Mon Sep 17 00:00:00 2001
From: Chuanqi Xu <yedeng.yd at linux.alibaba.com>
Date: Fri, 8 Nov 2024 17:19:33 +0800
Subject: [PATCH] [Serialization] Introduce OnDiskHashTable for specializations
---
clang/include/clang/AST/ExternalASTSource.h | 11 +
.../clang/Sema/MultiplexExternalSemaSource.h | 6 +
.../include/clang/Serialization/ASTBitCodes.h | 6 +
clang/include/clang/Serialization/ASTReader.h | 34 ++-
clang/include/clang/Serialization/ASTWriter.h | 14 +
clang/lib/AST/DeclTemplate.cpp | 17 ++
clang/lib/AST/ExternalASTSource.cpp | 5 +
.../lib/Sema/MultiplexExternalSemaSource.cpp | 12 +
clang/lib/Serialization/ASTReader.cpp | 145 +++++++++-
clang/lib/Serialization/ASTReaderDecl.cpp | 27 ++
clang/lib/Serialization/ASTReaderInternals.h | 124 +++++++++
clang/lib/Serialization/ASTWriter.cpp | 174 +++++++++++-
clang/lib/Serialization/ASTWriterDecl.cpp | 32 ++-
clang/unittests/Serialization/CMakeLists.txt | 1 +
.../Serialization/LoadSpecLazilyTest.cpp | 260 ++++++++++++++++++
15 files changed, 855 insertions(+), 13 deletions(-)
create mode 100644 clang/unittests/Serialization/LoadSpecLazilyTest.cpp
diff --git a/clang/include/clang/AST/ExternalASTSource.h b/clang/include/clang/AST/ExternalASTSource.h
index 582ed7c65f58ca..5f4f9a9a8d681e 100644
--- a/clang/include/clang/AST/ExternalASTSource.h
+++ b/clang/include/clang/AST/ExternalASTSource.h
@@ -152,6 +152,17 @@ class ExternalASTSource : public RefCountedBase<ExternalASTSource> {
virtual bool
FindExternalVisibleDeclsByName(const DeclContext *DC, DeclarationName Name);
+ /// Load all the external specializations for the Decl \p D if
+ /// \p OnlyPartial is false. Otherwise, load only the external **partial**
+ /// specializations for \p D.
+ virtual void LoadExternalSpecializations(const Decl *D, bool OnlyPartial);
+
+ /// Load all the specializations for the Decl \p D whose template
+ /// arguments match those specified by \p TemplateArgs.
+ virtual void
+ LoadExternalSpecializations(const Decl *D,
+ ArrayRef<TemplateArgument> TemplateArgs);
+
/// Ensures that the table of all visible declarations inside this
/// context is up to date.
///
diff --git a/clang/include/clang/Sema/MultiplexExternalSemaSource.h b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
index 3d1906d8699265..78bbbaf2d7b5c6 100644
--- a/clang/include/clang/Sema/MultiplexExternalSemaSource.h
+++ b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
@@ -97,6 +97,12 @@ class MultiplexExternalSemaSource : public ExternalSemaSource {
bool FindExternalVisibleDeclsByName(const DeclContext *DC,
DeclarationName Name) override;
+ void LoadExternalSpecializations(const Decl *D, bool OnlyPartial) override;
+
+ void
+ LoadExternalSpecializations(const Decl *D,
+ ArrayRef<TemplateArgument> TemplateArgs) override;
+
/// Ensures that the table of all visible declarations inside this
/// context is up to date.
void completeVisibleDeclsMap(const DeclContext *DC) override;
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index 3b14a0b8203315..cb3ed6c1ecbb7c 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -734,6 +734,9 @@ enum ASTRecordTypes {
/// Record code for Sema's vector of functions/blocks with effects to
/// be verified.
DECLS_WITH_EFFECTS_TO_VERIFY = 72,
+
+ /// Record code for updated specialization
+ UPDATE_SPECIALIZATION = 73,
};
/// Record types used within a source manager block.
@@ -1500,6 +1503,9 @@ enum DeclCode {
/// A HLSLBufferDecl record.
DECL_HLSL_BUFFER,
+ /// A record containing the specializations of a decl.
+ DECL_SPECIALIZATIONS,
+
/// An ImplicitConceptSpecializationDecl record.
DECL_IMPLICIT_CONCEPT_SPECIALIZATION,
diff --git a/clang/include/clang/Serialization/ASTReader.h b/clang/include/clang/Serialization/ASTReader.h
index 9c274adc59a207..6306d4f08e81fa 100644
--- a/clang/include/clang/Serialization/ASTReader.h
+++ b/clang/include/clang/Serialization/ASTReader.h
@@ -354,6 +354,9 @@ class ASTIdentifierLookupTrait;
/// The on-disk hash table(s) used for DeclContext name lookup.
struct DeclContextLookupTable;
+/// The on-disk hash table(s) used for specialization decls.
+struct LazySpecializationInfoLookupTable;
+
} // namespace reader
} // namespace serialization
@@ -632,20 +635,29 @@ class ASTReader
llvm::DenseMap<const DeclContext *,
serialization::reader::DeclContextLookupTable> Lookups;
+ /// Map from decls to specialized decls.
+ llvm::DenseMap<const Decl *,
+ serialization::reader::LazySpecializationInfoLookupTable>
+ SpecializationsLookups;
+
// Updates for visible decls can occur for other contexts than just the
// TU, and when we read those update records, the actual context may not
// be available yet, so have this pending map using the ID as a key. It
- // will be realized when the context is actually loaded.
- struct PendingVisibleUpdate {
+ // will be realized when the data is actually loaded.
+ struct UpdateData {
ModuleFile *Mod;
const unsigned char *Data;
};
- using DeclContextVisibleUpdates = SmallVector<PendingVisibleUpdate, 1>;
+ using DeclContextVisibleUpdates = SmallVector<UpdateData, 1>;
/// Updates to the visible declarations of declaration contexts that
/// haven't been loaded yet.
llvm::DenseMap<GlobalDeclID, DeclContextVisibleUpdates> PendingVisibleUpdates;
+ using SpecializationsUpdate = SmallVector<UpdateData, 1>;
+ llvm::DenseMap<GlobalDeclID, SpecializationsUpdate>
+ PendingSpecializationsUpdates;
+
/// The set of C++ or Objective-C classes that have forward
/// declarations that have not yet been linked to their definitions.
llvm::SmallPtrSet<Decl *, 4> PendingDefinitions;
@@ -678,6 +690,11 @@ class ASTReader
llvm::BitstreamCursor &Cursor,
uint64_t Offset, GlobalDeclID ID);
+ bool ReadSpecializations(ModuleFile &M, llvm::BitstreamCursor &Cursor,
+ uint64_t Offset, Decl *D);
+ void AddSpecializations(const Decl *D, const unsigned char *Data,
+ ModuleFile &M);
+
/// A vector containing identifiers that have already been
/// loaded.
///
@@ -1419,6 +1436,11 @@ class ASTReader
const serialization::reader::DeclContextLookupTable *
getLoadedLookupTables(DeclContext *Primary) const;
+ /// Get the loaded specializations lookup tables for \p D,
+ /// if any.
+ serialization::reader::LazySpecializationInfoLookupTable *
+ getLoadedSpecializationsLookupTables(const Decl *D);
+
private:
struct ImportedModule {
ModuleFile *Mod;
@@ -2076,6 +2098,12 @@ class ASTReader
unsigned BlockID,
uint64_t *StartOfBlockOffset = nullptr);
+ void LoadExternalSpecializations(const Decl *D, bool OnlyPartial) override;
+
+ void
+ LoadExternalSpecializations(const Decl *D,
+ ArrayRef<TemplateArgument> TemplateArgs) override;
+
/// Finds all the visible declarations with a given name.
/// The current implementation of this method just loads the entire
/// lookup table as unmaterialized references.
diff --git a/clang/include/clang/Serialization/ASTWriter.h b/clang/include/clang/Serialization/ASTWriter.h
index dc9fcd3c33726e..9da22f96130a8b 100644
--- a/clang/include/clang/Serialization/ASTWriter.h
+++ b/clang/include/clang/Serialization/ASTWriter.h
@@ -423,6 +423,12 @@ class ASTWriter : public ASTDeserializationListener,
/// Only meaningful for reduced BMI.
DeclUpdateMap DeclUpdatesFromGMF;
+ /// Mapping from template decls to their new specializations in the
+ /// current TU.
+ using SpecializationUpdateMap =
+ llvm::MapVector<const NamedDecl *, SmallVector<const Decl *>>;
+ SpecializationUpdateMap SpecializationsUpdates;
+
using FirstLatestDeclMap = llvm::DenseMap<Decl *, Decl *>;
/// Map of first declarations from a chained PCH that point to the
@@ -572,6 +578,11 @@ class ASTWriter : public ASTDeserializationListener,
bool isLookupResultExternal(StoredDeclsList &Result, DeclContext *DC);
+ void GenerateSpecializationInfoLookupTable(
+ const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations,
+ llvm::SmallVectorImpl<char> &LookupTable);
+ uint64_t WriteSpecializationInfoLookupTable(
+ const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations);
void GenerateNameLookupTable(ASTContext &Context, const DeclContext *DC,
llvm::SmallVectorImpl<char> &LookupTable);
uint64_t WriteDeclContextLexicalBlock(ASTContext &Context,
@@ -587,6 +598,7 @@ class ASTWriter : public ASTDeserializationListener,
void WriteDeclAndTypes(ASTContext &Context);
void PrepareWritingSpecialDecls(Sema &SemaRef);
void WriteSpecialDeclRecords(Sema &SemaRef);
+ void WriteSpecializationsUpdates();
void WriteDeclUpdatesBlocks(ASTContext &Context,
RecordDataImpl &OffsetsRecord);
void WriteDeclContextVisibleUpdate(ASTContext &Context,
@@ -616,6 +628,8 @@ class ASTWriter : public ASTDeserializationListener,
unsigned DeclEnumAbbrev = 0;
unsigned DeclObjCIvarAbbrev = 0;
unsigned DeclCXXMethodAbbrev = 0;
+ unsigned DeclSpecializationsAbbrev = 0;
+
unsigned DeclDependentNonTemplateCXXMethodAbbrev = 0;
unsigned DeclTemplateCXXMethodAbbrev = 0;
unsigned DeclMemberSpecializedCXXMethodAbbrev = 0;
diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp
index d0e16c30881be2..a73af9c7785320 100644
--- a/clang/lib/AST/DeclTemplate.cpp
+++ b/clang/lib/AST/DeclTemplate.cpp
@@ -355,6 +355,14 @@ RedeclarableTemplateDecl::CommonBase *RedeclarableTemplateDecl::getCommonPtr() c
void RedeclarableTemplateDecl::loadLazySpecializationsImpl(
bool OnlyPartial /*=false*/) const {
+ auto *ExternalSource = getASTContext().getExternalSource();
+ if (!ExternalSource)
+ return;
+
+ ExternalSource->LoadExternalSpecializations(this->getCanonicalDecl(),
+ OnlyPartial);
+ return;
+
// Grab the most recent declaration to ensure we've loaded any lazy
// redeclarations of this template.
CommonBase *CommonBasePtr = getMostRecentDecl()->getCommonPtr();
@@ -374,6 +382,8 @@ void RedeclarableTemplateDecl::loadLazySpecializationsImpl(
Decl *RedeclarableTemplateDecl::loadLazySpecializationImpl(
LazySpecializationInfo &LazySpecInfo) const {
+ llvm_unreachable("We don't use LazySpecializationInfo any more");
+
GlobalDeclID ID = LazySpecInfo.DeclID;
assert(ID.isValid() && "Loading already loaded specialization!");
// Note that we loaded the specialization.
@@ -384,6 +394,13 @@ Decl *RedeclarableTemplateDecl::loadLazySpecializationImpl(
void RedeclarableTemplateDecl::loadLazySpecializationsImpl(
ArrayRef<TemplateArgument> Args, TemplateParameterList *TPL) const {
+ auto *ExternalSource = getASTContext().getExternalSource();
+ if (!ExternalSource)
+ return;
+
+ ExternalSource->LoadExternalSpecializations(this->getCanonicalDecl(), Args);
+ return;
+
CommonBase *CommonBasePtr = getMostRecentDecl()->getCommonPtr();
if (auto *Specs = CommonBasePtr->LazySpecializations) {
unsigned Hash = TemplateArgumentList::ComputeODRHash(Args);
diff --git a/clang/lib/AST/ExternalASTSource.cpp b/clang/lib/AST/ExternalASTSource.cpp
index a5b6f80bde694c..122014bfeb2321 100644
--- a/clang/lib/AST/ExternalASTSource.cpp
+++ b/clang/lib/AST/ExternalASTSource.cpp
@@ -98,6 +98,11 @@ ExternalASTSource::FindExternalVisibleDeclsByName(const DeclContext *DC,
return false;
}
+// Base-class implementation of the "load all / only partial specializations"
+// hook: it does nothing.
+void ExternalASTSource::LoadExternalSpecializations(const Decl * /*D*/,
+                                                    bool /*OnlyPartial*/) {}
+
+// Base-class implementation of the "load specializations matching these
+// template arguments" hook: it does nothing.
+void ExternalASTSource::LoadExternalSpecializations(
+    const Decl * /*D*/, ArrayRef<TemplateArgument> /*TemplateArgs*/) {}
+
void ExternalASTSource::completeVisibleDeclsMap(const DeclContext *DC) {}
void ExternalASTSource::FindExternalLexicalDecls(
diff --git a/clang/lib/Sema/MultiplexExternalSemaSource.cpp b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
index cd44483b5cbe04..f39463712ce81f 100644
--- a/clang/lib/Sema/MultiplexExternalSemaSource.cpp
+++ b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
@@ -115,6 +115,18 @@ FindExternalVisibleDeclsByName(const DeclContext *DC, DeclarationName Name) {
return AnyDeclsFound;
}
+// Forward the request to every wrapped source, in index order.
+void MultiplexExternalSemaSource::LoadExternalSpecializations(
+    const Decl *D, bool OnlyPartial) {
+  for (size_t Idx = 0; Idx != Sources.size(); ++Idx) {
+    auto &Source = Sources[Idx];
+    Source->LoadExternalSpecializations(D, OnlyPartial);
+  }
+}
+
+// Forward the argument-specific request to every wrapped source, in index
+// order.
+void MultiplexExternalSemaSource::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  for (size_t Idx = 0; Idx != Sources.size(); ++Idx) {
+    auto &Source = Sources[Idx];
+    Source->LoadExternalSpecializations(D, TemplateArgs);
+  }
+}
+
void MultiplexExternalSemaSource::completeVisibleDeclsMap(const DeclContext *DC){
for(size_t i = 0; i < Sources.size(); ++i)
Sources[i]->completeVisibleDeclsMap(DC);
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 8bba98ef33302b..75d81a25dd3ac3 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -1283,6 +1283,43 @@ void ASTDeclContextNameLookupTrait::ReadDataInto(internal_key_type,
}
}
+// Read a module-file reference from \p d (advancing \p d past it) and resolve
+// it through the reader relative to the module file F.
+ModuleFile *
+LazySpecializationInfoLookupTrait::ReadFileRef(const unsigned char *&d) {
+  using namespace llvm::support;
+
+  // The reference is stored as an unaligned little-endian 32-bit ID.
+  return Reader.getLocalModuleFile(
+      F, endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d));
+}
+
+// Read the key stored at \p d: an unaligned little-endian 32-bit value. The
+// key length argument is ignored.
+LazySpecializationInfoLookupTrait::internal_key_type
+LazySpecializationInfoLookupTrait::ReadKey(const unsigned char *d, unsigned) {
+  using namespace llvm::support;
+  const uint32_t StoredKey =
+      endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+  return StoredKey;
+}
+
+// Read the (key length, data length) pair at \p d by delegating to
+// readULEBKeyDataLength.
+std::pair<unsigned, unsigned>
+LazySpecializationInfoLookupTrait::ReadKeyDataLength(const unsigned char *&d) {
+  std::pair<unsigned, unsigned> Lengths = readULEBKeyDataLength(d);
+  return Lengths;
+}
+
+// Decode the entries stored in a \p DataLen byte payload at \p d and insert
+// each one into \p Val. Every entry is a local DeclID followed by an
+// IsPartial flag; the local ID is translated to a global ID before insertion.
+void LazySpecializationInfoLookupTrait::ReadDataInto(internal_key_type,
+                                                     const unsigned char *d,
+                                                     unsigned DataLen,
+                                                     data_type_builder &Val) {
+  using namespace llvm::support;
+
+  const unsigned NumEntries =
+      DataLen / serialization::reader::LazySpecializationInfo::Length;
+  for (unsigned Idx = 0; Idx != NumEntries; ++Idx) {
+    const DeclID RawID =
+        endian::readNext<DeclID, llvm::endianness::little, unaligned>(d);
+    LocalDeclID LocalID = LocalDeclID::get(Reader, F, RawID);
+    const bool IsPartial =
+        endian::readNext<bool, llvm::endianness::little, unaligned>(d);
+    Val.insert({Reader.getGlobalDeclID(F, LocalID), IsPartial});
+  }
+}
+
bool ASTReader::ReadLexicalDeclContextStorage(ModuleFile &M,
BitstreamCursor &Cursor,
uint64_t Offset,
@@ -1367,7 +1404,49 @@ bool ASTReader::ReadVisibleDeclContextStorage(ModuleFile &M,
// We can't safely determine the primary context yet, so delay attaching the
// lookup table until we're done with recursive deserialization.
auto *Data = (const unsigned char*)Blob.data();
- PendingVisibleUpdates[ID].push_back(PendingVisibleUpdate{&M, Data});
+ PendingVisibleUpdates[ID].push_back(UpdateData{&M, Data});
+ return false;
+}
+
+// Register the on-disk specialization table at \p Data (owned by module file
+// \p M) under the canonical declaration of \p D.
+void ASTReader::AddSpecializations(const Decl *D, const unsigned char *Data,
+                                   ModuleFile &M) {
+  const Decl *Canonical = D->getCanonicalDecl();
+  reader::LazySpecializationInfoLookupTrait Trait(*this, M);
+  SpecializationsLookups[Canonical].Table.add(&M, Data, Trait);
+}
+
+// Read the DECL_SPECIALIZATIONS record located at bit \p Offset of \p Cursor
+// and attach its blob to \p D as a specialization lookup table. The cursor
+// position is saved on entry via SavedStreamPosition. Returns true on failure
+// (after reporting the error) and false on success.
+bool ASTReader::ReadSpecializations(ModuleFile &M, BitstreamCursor &Cursor,
+                                    uint64_t Offset, Decl *D) {
+  assert(Offset != 0);
+
+  SavedStreamPosition SavedPosition(Cursor);
+  if (llvm::Error JumpErr = Cursor.JumpToBit(Offset)) {
+    Error(std::move(JumpErr));
+    return true;
+  }
+
+  Expected<unsigned> AbbrevCode = Cursor.ReadCode();
+  if (!AbbrevCode) {
+    Error(AbbrevCode.takeError());
+    return true;
+  }
+
+  RecordData Record;
+  StringRef Blob;
+  Expected<unsigned> RecordCode =
+      Cursor.readRecord(AbbrevCode.get(), Record, &Blob);
+  if (!RecordCode) {
+    Error(RecordCode.takeError());
+    return true;
+  }
+
+  if (RecordCode.get() != DECL_SPECIALIZATIONS) {
+    Error("Expected decl specs block");
+    return true;
+  }
+
+  // The blob holds the serialized on-disk hash table.
+  AddSpecializations(D, (const unsigned char *)Blob.data(), M);
return false;
}
@@ -3454,7 +3533,20 @@ llvm::Error ASTReader::ReadASTBlock(ModuleFile &F,
unsigned Idx = 0;
GlobalDeclID ID = ReadDeclID(F, Record, Idx);
auto *Data = (const unsigned char*)Blob.data();
- PendingVisibleUpdates[ID].push_back(PendingVisibleUpdate{&F, Data});
+ PendingVisibleUpdates[ID].push_back(UpdateData{&F, Data});
+ // If we've already loaded the decl, perform the updates when we finish
+ // loading this block.
+ if (Decl *D = GetExistingDecl(ID))
+ PendingUpdateRecords.push_back(
+ PendingUpdateRecord(ID, D, /*JustLoaded=*/false));
+ break;
+ }
+
+ case UPDATE_SPECIALIZATION: {
+ unsigned Idx = 0;
+ GlobalDeclID ID = ReadDeclID(F, Record, Idx);
+ auto *Data = (const unsigned char *)Blob.data();
+ PendingSpecializationsUpdates[ID].push_back(UpdateData{&F, Data});
// If we've already loaded the decl, perform the updates when we finish
// loading this block.
if (Decl *D = GetExistingDecl(ID))
@@ -8037,6 +8129,48 @@ Stmt *ASTReader::GetExternalDeclStmt(uint64_t Offset) {
return ReadStmtFromStream(*Loc.F);
}
+// Load every specialization recorded for \p D, or only the partial ones when
+// \p OnlyPartial is true. Does nothing if no table is registered for \p D.
+void ASTReader::LoadExternalSpecializations(const Decl *D, bool OnlyPartial) {
+  assert(D);
+
+  auto Pos = SpecializationsLookups.find(D);
+  if (Pos == SpecializationsLookups.end())
+    return;
+
+  // Copy the entries out up front: GetDecl below may invalidate iterators
+  // into SpecializationsLookups.
+  llvm::SmallVector<serialization::reader::LazySpecializationInfo, 8> Pending =
+      Pos->second.Table.findAll();
+
+  // A full load consumes everything in the table, so the table can be
+  // dropped. A partial-only load keeps it.
+  if (!OnlyPartial)
+    SpecializationsLookups.erase(Pos);
+
+  Deserializing LookupResults(this);
+  for (auto &Entry : Pending) {
+    if (OnlyPartial && !Entry.IsPartial)
+      continue;
+    GetDecl(Entry.ID);
+  }
+}
+
+// Load the specializations of \p D whose table key equals the ODR hash of
+// \p TemplateArgs. Does nothing if no table is registered for \p D.
+void ASTReader::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  assert(D);
+
+  auto Pos = SpecializationsLookups.find(D);
+  if (Pos == SpecializationsLookups.end())
+    return;
+
+  Deserializing LookupResults(this);
+  const auto ArgsHash = TemplateArgumentList::ComputeODRHash(TemplateArgs);
+
+  // Copy the matching entries out first: GetDecl may invalidate iterators
+  // into SpecializationsLookups.
+  llvm::SmallVector<serialization::reader::LazySpecializationInfo, 8> Matches =
+      Pos->second.Table.find(ArgsHash);
+
+  for (auto &Entry : Matches)
+    GetDecl(Entry.ID);
+}
+
void ASTReader::FindExternalLexicalDecls(
const DeclContext *DC, llvm::function_ref<bool(Decl::Kind)> IsKindWeWant,
SmallVectorImpl<Decl *> &Decls) {
@@ -8216,6 +8350,13 @@ ASTReader::getLoadedLookupTables(DeclContext *Primary) const {
return I == Lookups.end() ? nullptr : &I->second;
}
+// Return the specialization lookup table registered for the canonical decl
+// \p D, or nullptr if there is none.
+serialization::reader::LazySpecializationInfoLookupTable *
+ASTReader::getLoadedSpecializationsLookupTables(const Decl *D) {
+  assert(D->isCanonicalDecl());
+  auto Pos = SpecializationsLookups.find(D);
+  if (Pos == SpecializationsLookups.end())
+    return nullptr;
+  return &Pos->second;
+}
+
/// Under non-PCH compilation the consumer receives the objc methods
/// before receiving the implementation, and codegen depends on this.
/// We simulate this by deserializing and passing to consumer the methods of the
diff --git a/clang/lib/Serialization/ASTReaderDecl.cpp b/clang/lib/Serialization/ASTReaderDecl.cpp
index 6816bbcd45dcbe..87c491a977ffeb 100644
--- a/clang/lib/Serialization/ASTReaderDecl.cpp
+++ b/clang/lib/Serialization/ASTReaderDecl.cpp
@@ -342,6 +342,9 @@ class ASTDeclReader : public DeclVisitor<ASTDeclReader, void> {
static void markIncompleteDeclChainImpl(Redeclarable<DeclT> *D);
static void markIncompleteDeclChainImpl(...);
+ void ReadSpecializations(ModuleFile &M, Decl *D,
+ llvm::BitstreamCursor &DeclsCursor);
+
void ReadFunctionDefinition(FunctionDecl *FD);
void Visit(Decl *D);
@@ -2430,6 +2433,14 @@ void ASTDeclReader::VisitImplicitConceptSpecializationDecl(
void ASTDeclReader::VisitRequiresExprBodyDecl(RequiresExprBodyDecl *D) {
}
+// Read the local offset of the specializations record from the current decl
+// record and hand it to the reader to load the table for \p D. Failure is
+// only checked by an assertion.
+void ASTDeclReader::ReadSpecializations(ModuleFile &M, Decl *D,
+                                        llvm::BitstreamCursor &DeclsCursor) {
+  const uint64_t Offset = ReadLocalOffset();
+  [[maybe_unused]] const bool Failed =
+      Reader.ReadSpecializations(M, DeclsCursor, Offset, D);
+  assert(!Failed);
+}
+
RedeclarableResult
ASTDeclReader::VisitRedeclarableTemplateDecl(RedeclarableTemplateDecl *D) {
RedeclarableResult Redecl = VisitRedeclarable(D);
@@ -2471,6 +2482,7 @@ void ASTDeclReader::VisitClassTemplateDecl(ClassTemplateDecl *D) {
SmallVector<LazySpecializationInfo, 32> SpecIDs;
readDeclIDList(SpecIDs);
ASTDeclReader::AddLazySpecializations(D, SpecIDs);
+ ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor);
}
if (D->getTemplatedDecl()->TemplateOrInstantiation) {
@@ -2499,6 +2511,7 @@ void ASTDeclReader::VisitVarTemplateDecl(VarTemplateDecl *D) {
SmallVector<LazySpecializationInfo, 32> SpecIDs;
readDeclIDList(SpecIDs);
ASTDeclReader::AddLazySpecializations(D, SpecIDs);
+ ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor);
}
}
@@ -2600,6 +2613,7 @@ void ASTDeclReader::VisitFunctionTemplateDecl(FunctionTemplateDecl *D) {
SmallVector<LazySpecializationInfo, 32> SpecIDs;
readDeclIDList(SpecIDs);
ASTDeclReader::AddLazySpecializations(D, SpecIDs);
+ ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor);
}
}
@@ -3889,6 +3903,7 @@ Decl *ASTReader::ReadDeclRecord(GlobalDeclID ID) {
switch ((DeclCode)MaybeDeclCode.get()) {
case DECL_CONTEXT_LEXICAL:
case DECL_CONTEXT_VISIBLE:
+ case DECL_SPECIALIZATIONS:
llvm_unreachable("Record cannot be de-serialized with readDeclRecord");
case DECL_TYPEDEF:
D = TypedefDecl::CreateDeserialized(Context, ID);
@@ -4352,12 +4367,14 @@ void ASTReader::loadDeclUpdateRecords(PendingUpdateRecord &Record) {
assert((PendingLazySpecializationIDs.empty() || isa<ClassTemplateDecl>(D) ||
isa<FunctionTemplateDecl, VarTemplateDecl>(D)) &&
"Must not have pending specializations");
+ /*
if (auto *CTD = dyn_cast<ClassTemplateDecl>(D))
ASTDeclReader::AddLazySpecializations(CTD, PendingLazySpecializationIDs);
else if (auto *FTD = dyn_cast<FunctionTemplateDecl>(D))
ASTDeclReader::AddLazySpecializations(FTD, PendingLazySpecializationIDs);
else if (auto *VTD = dyn_cast<VarTemplateDecl>(D))
ASTDeclReader::AddLazySpecializations(VTD, PendingLazySpecializationIDs);
+ */
PendingLazySpecializationIDs.clear();
// Load the pending visible updates for this decl context, if it has any.
@@ -4383,6 +4400,16 @@ void ASTReader::loadDeclUpdateRecords(PendingUpdateRecord &Record) {
FunctionToLambdasMap.erase(IT);
}
}
+
+ // Load the pending specializations update for this decl, if it has any.
+ if (auto I = PendingSpecializationsUpdates.find(ID);
+ I != PendingSpecializationsUpdates.end()) {
+ auto SpecializationUpdates = std::move(I->second);
+ PendingSpecializationsUpdates.erase(I);
+
+ for (const auto &Update : SpecializationUpdates)
+ AddSpecializations(D, Update.Data, *Update.Mod);
+ }
}
void ASTReader::loadPendingDeclChain(Decl *FirstLocal, uint64_t LocalOffset) {
diff --git a/clang/lib/Serialization/ASTReaderInternals.h b/clang/lib/Serialization/ASTReaderInternals.h
index 4f7e6f4b2741b7..b921d8d174c3a2 100644
--- a/clang/lib/Serialization/ASTReaderInternals.h
+++ b/clang/lib/Serialization/ASTReaderInternals.h
@@ -119,6 +119,102 @@ struct DeclContextLookupTable {
MultiOnDiskHashTable<ASTDeclContextNameLookupTrait> Table;
};
+/// One entry of a specialization lookup table: the ID of a specialization
+/// plus whether it is a partial specialization.
+struct LazySpecializationInfo {
+  // The Decl ID for the specialization.
+  GlobalDeclID ID;
+  // Whether or not this specialization is partial.
+  bool IsPartial;
+
+  /// Two entries are equal when their IDs are equal. Entries with the same ID
+  /// are expected to agree on IsPartial (checked by the assertion).
+  /// Const-qualified so it can be used on const objects and references.
+  bool operator==(const LazySpecializationInfo &Other) const {
+    assert(ID != Other.ID || IsPartial == Other.IsPartial);
+    return ID == Other.ID;
+  }
+
+  // Number of bytes one entry occupies in the OnDiskHashTable. This is
+  // computed from the field sizes rather than sizeof(LazySpecializationInfo),
+  // which may be larger because of alignment padding.
+  static constexpr unsigned Length = sizeof(DeclID) + sizeof(IsPartial);
+};
+
+/// Class that performs lookup to specialized decls.
+///
+/// This is the reader-side trait plugged into MultiOnDiskHashTable: keys are
+/// unsigned hash values and each key maps to a list of
+/// LazySpecializationInfo entries.
+class LazySpecializationInfoLookupTrait {
+  ASTReader &Reader;
+  ModuleFile &F;
+
+public:
+  // Maximum number of lookup tables we allow before condensing the tables.
+  static const int MaxTables = 4;
+
+  /// The lookup result is a list of specialization entries (global
+  /// declaration ID plus partial flag).
+  using data_type = SmallVector<LazySpecializationInfo, 4>;
+
+  /// Appends entries to a data_type while skipping duplicates.
+  struct data_type_builder {
+    data_type &Data;
+    llvm::DenseSet<LazySpecializationInfo> Found;
+
+    data_type_builder(data_type &D) : Data(D) {}
+
+    /// Append \p Info to Data unless an equal entry is already present.
+    void insert(LazySpecializationInfo Info) {
+      // Just use a linear scan unless we have more than a few IDs.
+      if (Found.empty() && !Data.empty()) {
+        if (Data.size() <= 4) {
+          // Found is empty on this path, so the existing entries must be
+          // looked up in Data itself; scanning Found would never detect a
+          // duplicate.
+          for (auto &I : Data)
+            if (I == Info)
+              return;
+          Data.push_back(Info);
+          return;
+        }
+
+        // Switch to tracking found IDs in the set.
+        Found.insert(Data.begin(), Data.end());
+      }
+
+      if (Found.insert(Info).second)
+        Data.push_back(Info);
+    }
+  };
+  using hash_value_type = unsigned;
+  using offset_type = unsigned;
+  using file_type = ModuleFile *;
+
+  using external_key_type = unsigned;
+  using internal_key_type = unsigned;
+
+  explicit LazySpecializationInfoLookupTrait(ASTReader &Reader, ModuleFile &F)
+      : Reader(Reader), F(F) {}
+
+  static bool EqualKey(const internal_key_type &a, const internal_key_type &b) {
+    return a == b;
+  }
+
+  // The key already is a hash value, so it is used unchanged.
+  static hash_value_type ComputeHash(const internal_key_type &Key) {
+    return Key;
+  }
+
+  static internal_key_type GetInternalKey(const external_key_type &Name) {
+    return Name;
+  }
+
+  static std::pair<unsigned, unsigned>
+  ReadKeyDataLength(const unsigned char *&d);
+
+  internal_key_type ReadKey(const unsigned char *d, unsigned);
+
+  void ReadDataInto(internal_key_type, const unsigned char *d, unsigned DataLen,
+                    data_type_builder &Val);
+
+  /// Insert every entry of \p From into \p To, skipping duplicates.
+  static void MergeDataInto(const data_type &From, data_type_builder &To) {
+    To.Data.reserve(To.Data.size() + From.size());
+    for (LazySpecializationInfo Info : From)
+      To.insert(Info);
+  }
+
+  file_type ReadFileRef(const unsigned char *&d);
+};
+
+// Wrapper holding the multi-file on-disk hash table of specialization entries,
+// parameterized with LazySpecializationInfoLookupTrait.
+struct LazySpecializationInfoLookupTable {
+ MultiOnDiskHashTable<LazySpecializationInfoLookupTrait> Table;
+};
+
/// Base class for the trait describing the on-disk hash table for the
/// identifiers in an AST file.
///
@@ -288,4 +384,32 @@ using HeaderFileInfoLookupTable =
} // namespace clang
+namespace llvm {
+// DenseMapInfo for LazySpecializationInfo. Hashing and equality look at the
+// ID only: the ID is unique per entry, so including IsPartial would be
+// redundant.
+template <>
+struct DenseMapInfo<clang::serialization::reader::LazySpecializationInfo> {
+  using LazySpecializationInfo =
+      clang::serialization::reader::LazySpecializationInfo;
+  using Wrapped = DenseMapInfo<clang::serialization::DeclID>;
+
+  static inline LazySpecializationInfo getEmptyKey() {
+    LazySpecializationInfo Empty{clang::GlobalDeclID(Wrapped::getEmptyKey()),
+                                 false};
+    return Empty;
+  }
+
+  static inline LazySpecializationInfo getTombstoneKey() {
+    LazySpecializationInfo Tombstone{
+        clang::GlobalDeclID(Wrapped::getTombstoneKey()), false};
+    return Tombstone;
+  }
+
+  static unsigned getHashValue(const LazySpecializationInfo &Key) {
+    const auto RawID = Key.ID.getRawValue();
+    return Wrapped::getHashValue(RawID);
+  }
+
+  static bool isEqual(const LazySpecializationInfo &LHS,
+                      const LazySpecializationInfo &RHS) {
+    return LHS.ID == RHS.ID;
+  }
+};
+} // end namespace llvm
+
#endif // LLVM_CLANG_LIB_SERIALIZATION_ASTREADERINTERNALS_H
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 455e8f0e749435..adfe9fc8b369d6 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -4106,6 +4106,155 @@ class ASTDeclContextNameLookupTrait {
} // namespace
+namespace {
+// Writer-side trait for the specialization on-disk hash table. Entries are
+// accumulated in Specs; each table value is a [begin, end) index range into
+// that vector.
+class LazySpecializationInfoLookupTrait {
+  ASTWriter &Writer;
+  llvm::SmallVector<serialization::reader::LazySpecializationInfo, 64> Specs;
+
+public:
+  using key_type = unsigned;
+  using key_type_ref = key_type;
+
+  /// A start and end index into Specs, representing a sequence of decls.
+  using data_type = std::pair<unsigned, unsigned>;
+  using data_type_ref = const data_type &;
+
+  using hash_value_type = unsigned;
+  using offset_type = unsigned;
+
+  explicit LazySpecializationInfoLookupTrait(ASTWriter &Writer)
+      : Writer(Writer) {}
+
+  /// Append one entry per declaration in \p C to Specs and return the index
+  /// range covering the appended entries.
+  template <typename Col> data_type getData(Col &&C) {
+    const unsigned Begin = Specs.size();
+    for (auto *D : C) {
+      const bool Partial = isa<ClassTemplatePartialSpecializationDecl,
+                               VarTemplatePartialSpecializationDecl>(D);
+      NamedDecl *Local = getDeclForLocalLookup(Writer.getLangOpts(),
+                                               const_cast<NamedDecl *>(D));
+      Specs.push_back(
+          {GlobalDeclID(Writer.GetDeclRef(Local).getRawValue()), Partial});
+    }
+    return std::make_pair(Begin, Specs.size());
+  }
+
+  /// Append entries that were read back by the reader-side trait and return
+  /// the index range covering them.
+  data_type ImportData(
+      const reader::LazySpecializationInfoLookupTrait::data_type &FromReader) {
+    const unsigned Begin = Specs.size();
+    for (auto Entry : FromReader)
+      Specs.push_back(Entry);
+    return std::make_pair(Begin, Specs.size());
+  }
+
+  static bool EqualKey(key_type_ref a, key_type_ref b) { return a == b; }
+
+  // The key already is a hash value, so it is used unchanged.
+  hash_value_type ComputeHash(key_type Name) { return Name; }
+
+  /// Write the ID of module file \p F as a little-endian 32-bit value.
+  void EmitFileRef(raw_ostream &Out, ModuleFile *F) const {
+    assert(Writer.hasChain() &&
+           "have reference to loaded module file but no chain?");
+
+    using namespace llvm::support;
+    const auto FileID = Writer.getChain()->getModuleFileID(F);
+    endian::write<uint32_t>(Out, FileID, llvm::endianness::little);
+  }
+
+  /// Emit and return the (key length, data length) pair for one table entry.
+  std::pair<unsigned, unsigned> EmitKeyDataLength(raw_ostream &Out,
+                                                  key_type HashValue,
+                                                  data_type_ref Lookup) {
+    // 4 bytes for each slot.
+    const unsigned KeyLen = 4;
+    const unsigned NumEntries = Lookup.second - Lookup.first;
+    const unsigned DataLen =
+        serialization::reader::LazySpecializationInfo::Length * NumEntries;
+
+    return emitULEBKeyDataLength(KeyLen, DataLen, Out);
+  }
+
+  /// Write the key as a little-endian 32-bit value.
+  void EmitKey(raw_ostream &Out, key_type HashValue, unsigned) {
+    using namespace llvm::support;
+
+    endian::Writer LE(Out, llvm::endianness::little);
+    LE.write<uint32_t>(HashValue);
+  }
+
+  /// Write every entry in the range \p Lookup as a raw DeclID followed by the
+  /// IsPartial flag; the bytes written must total \p DataLen.
+  void EmitData(raw_ostream &Out, key_type_ref, data_type Lookup,
+                unsigned DataLen) {
+    using namespace llvm::support;
+
+    endian::Writer LE(Out, llvm::endianness::little);
+    [[maybe_unused]] const uint64_t StartPos = Out.tell();
+    for (unsigned Idx = Lookup.first; Idx != Lookup.second; ++Idx) {
+      const auto &Entry = Specs[Idx];
+      LE.write<DeclID>(Entry.ID.getRawValue());
+      LE.write<bool>(Entry.IsPartial);
+    }
+    assert(Out.tell() - StartPos == DataLen && "Data length is wrong");
+  }
+};
+
+// Compute the ODR hash of the template arguments of \p Spec, which must be a
+// class template specialization, a variable template specialization, or a
+// function declaration.
+unsigned CalculateODRHashForSpecs(const Decl *Spec) {
+  if (auto *ClassSpec = dyn_cast<ClassTemplateSpecializationDecl>(Spec))
+    return TemplateArgumentList::ComputeODRHash(
+        ClassSpec->getTemplateArgs().asArray());
+
+  if (auto *VarSpec = dyn_cast<VarTemplateSpecializationDecl>(Spec))
+    return TemplateArgumentList::ComputeODRHash(
+        VarSpec->getTemplateArgs().asArray());
+
+  if (auto *Function = dyn_cast<FunctionDecl>(Spec))
+    return TemplateArgumentList::ComputeODRHash(
+        Function->getTemplateSpecializationArgs()->asArray());
+
+  llvm_unreachable("New Specialization Kind?");
+}
+} // namespace
+
+// Build the serialized on-disk hash table for the specializations of \p D and
+// store it in \p LookupTable. Specializations are grouped by the ODR hash of
+// their template arguments, which is the table key. If there is a chained
+// reader that already has tables for \p D, they are passed to the generator
+// when emitting.
+void ASTWriter::GenerateSpecializationInfoLookupTable(
+    const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations,
+    llvm::SmallVectorImpl<char> &LookupTable) {
+  assert(D->isFirstDecl());
+
+  // Create the on-disk hash table representation.
+  MultiOnDiskHashTableGenerator<reader::LazySpecializationInfoLookupTrait,
+                                LazySpecializationInfoLookupTrait>
+      Generator;
+  LazySpecializationInfoLookupTrait Trait(*this);
+
+  llvm::DenseMap<unsigned, llvm::SmallVector<const NamedDecl *, 4>>
+      SpecializationMaps;
+
+  // operator[] creates the bucket on first use, so a single lookup per
+  // specialization is enough.
+  for (auto *Specialization : Specializations)
+    SpecializationMaps[CalculateODRHashForSpecs(Specialization)].push_back(
+        cast<NamedDecl>(Specialization));
+
+  // Iterate by reference: a by-value loop variable would copy every bucket's
+  // vector.
+  for (auto &Bucket : SpecializationMaps)
+    Generator.insert(Bucket.first, Trait.getData(Bucket.second), Trait);
+
+  auto *Lookups =
+      Chain ? Chain->getLoadedSpecializationsLookupTables(D) : nullptr;
+  Generator.emit(LookupTable, Trait, Lookups ? &Lookups->Table : nullptr);
+}
+
+// Emit a DECL_SPECIALIZATIONS record whose blob is the specialization lookup
+// table for \p D, and return the bit offset at which the record starts.
+uint64_t ASTWriter::WriteSpecializationInfoLookupTable(
+    const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations) {
+  llvm::SmallString<4096> TableBytes;
+  GenerateSpecializationInfoLookupTable(D, Specializations, TableBytes);
+
+  // Capture the position before emitting so that it points at the record.
+  const uint64_t RecordStart = Stream.GetCurrentBitNo();
+
+  RecordData::value_type Record[] = {DECL_SPECIALIZATIONS};
+  Stream.EmitRecordWithBlob(DeclSpecializationsAbbrev, Record, TableBytes);
+
+  return RecordStart;
+}
+
bool ASTWriter::isLookupResultExternal(StoredDeclsList &Result,
DeclContext *DC) {
return Result.hasExternalDecls() &&
@@ -5650,7 +5799,7 @@ void ASTWriter::WriteDeclAndTypes(ASTContext &Context) {
// Keep writing types, declarations, and declaration update records
// until we've emitted all of them.
RecordData DeclUpdatesOffsetsRecord;
- Stream.EnterSubblock(DECLTYPES_BLOCK_ID, /*bits for abbreviations*/5);
+ Stream.EnterSubblock(DECLTYPES_BLOCK_ID, /*bits for abbreviations*/6);
DeclTypesBlockStartOffset = Stream.GetCurrentBitNo();
WriteTypeAbbrevs();
WriteDeclAbbrevs();
@@ -5724,6 +5873,9 @@ void ASTWriter::WriteDeclAndTypes(ASTContext &Context) {
FunctionToLambdaMapAbbrev);
}
+ if (!SpecializationsUpdates.empty())
+ WriteSpecializationsUpdates();
+
const TranslationUnitDecl *TU = Context.getTranslationUnitDecl();
// Create a lexical update block containing all of the declarations in the
// translation unit that do not come from other AST files.
@@ -5767,6 +5919,26 @@ void ASTWriter::WriteDeclAndTypes(ASTContext &Context) {
WriteDeclContextVisibleUpdate(Context, DC);
}
+/// Emit one UPDATE_SPECIALIZATION record per entry of SpecializationsUpdates.
+/// Each record holds the raw decl ID of the entry's key and, as a blob, the
+/// lookup table generated from the entry's specializations.
+void ASTWriter::WriteSpecializationsUpdates() {
+  // Record layout: [UPDATE_SPECIALIZATION, VBR6 value, blob].
+  auto Abbrev = std::make_shared<llvm::BitCodeAbbrev>();
+  Abbrev->Add(llvm::BitCodeAbbrevOp(UPDATE_SPECIALIZATION));
+  Abbrev->Add(llvm::BitCodeAbbrevOp(llvm::BitCodeAbbrevOp::VBR, 6));
+  Abbrev->Add(llvm::BitCodeAbbrevOp(llvm::BitCodeAbbrevOp::Blob));
+  auto UpdateSpecializationAbbrev = Stream.EmitAbbrev(std::move(Abbrev));
+
+  for (auto &Update : SpecializationsUpdates) {
+    const NamedDecl *Owner = Update.first;
+
+    llvm::SmallString<4096> TableBytes;
+    GenerateSpecializationInfoLookupTable(Owner, Update.second, TableBytes);
+
+    // Write the lookup table.
+    RecordData::value_type Record[] = {UPDATE_SPECIALIZATION,
+                                       getDeclID(Owner).getRawValue()};
+    Stream.EmitRecordWithBlob(UpdateSpecializationAbbrev, Record, TableBytes);
+  }
+}
+
void ASTWriter::WriteDeclUpdatesBlocks(ASTContext &Context,
RecordDataImpl &OffsetsRecord) {
if (DeclUpdates.empty())
diff --git a/clang/lib/Serialization/ASTWriterDecl.cpp b/clang/lib/Serialization/ASTWriterDecl.cpp
index 71a9bd1aefade6..14c2f23700e8c3 100644
--- a/clang/lib/Serialization/ASTWriterDecl.cpp
+++ b/clang/lib/Serialization/ASTWriterDecl.cpp
@@ -207,15 +207,17 @@ namespace clang {
/// file that provides a declaration of D. We store the DeclId and an
/// ODRHash of the template arguments of D which should provide enough
/// information to load D only if the template instantiator needs it.
- void AddFirstSpecializationDeclFromEachModule(const Decl *D,
- bool IncludeLocal) {
- assert(isa<ClassTemplateSpecializationDecl>(D) ||
- isa<VarTemplateSpecializationDecl>(D) ||
- isa<FunctionDecl>(D) && "Must not be called with other decls");
+ void AddFirstSpecializationDeclFromEachModule(
+ const Decl *D, llvm::SmallVectorImpl<const Decl *> &SpecsInMap) {
+ assert((isa<ClassTemplateSpecializationDecl>(D) ||
+ isa<VarTemplateSpecializationDecl>(D) || isa<FunctionDecl>(D)) &&
+ "Must not be called with other decls");
llvm::MapVector<ModuleFile *, const Decl *> Firsts;
- CollectFirstDeclFromEachModule(D, IncludeLocal, Firsts);
+ CollectFirstDeclFromEachModule(D, /*IncludeLocal*/ true, Firsts);
for (const auto &F : Firsts) {
+ SpecsInMap.push_back(F.second);
+
Record.AddDeclRef(F.second);
ArrayRef<TemplateArgument> Args;
if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(D))
@@ -282,10 +284,15 @@ namespace clang {
for (auto &Entry : getPartialSpecializations(Common))
Specs.push_back(getSpecializationDecl(Entry));
+ llvm::SmallVector<const Decl *, 16> SpecsInOnDiskMap = Specs;
+
for (auto *D : Specs) {
assert(D->isCanonicalDecl() && "non-canonical decl in set");
- AddFirstSpecializationDeclFromEachModule(D, /*IncludeLocal*/ true);
+ AddFirstSpecializationDeclFromEachModule(D, SpecsInOnDiskMap);
}
+
+ // We don't need to insert LazySpecializations to SpecsInOnDiskMap,
+ // since we'll handle that in GenerateSpecializationInfoLookupTable.
for (auto &SpecInfo : LazySpecializations) {
Record.push_back(SpecInfo.DeclID.getRawValue());
Record.push_back(SpecInfo.ODRHash);
@@ -298,6 +305,9 @@ namespace clang {
assert((Record.size() - I - 1) % 3 == 0 &&
"Must be divisible by LazySpecializationInfo count!");
Record[I] = (Record.size() - I - 1) / 3;
+
+ Record.AddOffset(
+ Writer.WriteSpecializationInfoLookupTable(D, SpecsInOnDiskMap));
}
/// Ensure that this template specialization is associated with the specified
@@ -320,6 +330,9 @@ namespace clang {
Writer.DeclUpdates[Template].push_back(ASTWriter::DeclUpdate(
UPD_CXX_ADDED_TEMPLATE_SPECIALIZATION, Specialization));
+
+ Writer.SpecializationsUpdates[cast<NamedDecl>(Template)].push_back(
+ cast<NamedDecl>(Specialization));
}
};
}
@@ -2824,6 +2837,11 @@ void ASTWriter::WriteDeclAbbrevs() {
Abv->Add(BitCodeAbbrevOp(serialization::DECL_CONTEXT_VISIBLE));
Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Blob));
DeclContextVisibleLookupAbbrev = Stream.EmitAbbrev(std::move(Abv));
+
+ Abv = std::make_shared<BitCodeAbbrev>();
+ Abv->Add(BitCodeAbbrevOp(serialization::DECL_SPECIALIZATIONS));
+ Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Blob));
+ DeclSpecializationsAbbrev = Stream.EmitAbbrev(std::move(Abv));
}
/// isRequiredDecl - Check if this is a "required" Decl, which must be seen by
diff --git a/clang/unittests/Serialization/CMakeLists.txt b/clang/unittests/Serialization/CMakeLists.txt
index e7eebd0cb98239..e7005b5d511eba 100644
--- a/clang/unittests/Serialization/CMakeLists.txt
+++ b/clang/unittests/Serialization/CMakeLists.txt
@@ -11,6 +11,7 @@ add_clang_unittest(SerializationTests
ModuleCacheTest.cpp
NoCommentsTest.cpp
PreambleInNamedModulesTest.cpp
+ LoadSpecLazilyTest.cpp
SourceLocationEncodingTest.cpp
VarDeclConstantInitTest.cpp
)
diff --git a/clang/unittests/Serialization/LoadSpecLazilyTest.cpp b/clang/unittests/Serialization/LoadSpecLazilyTest.cpp
new file mode 100644
index 00000000000000..76e3ccae3d3c3d
--- /dev/null
+++ b/clang/unittests/Serialization/LoadSpecLazilyTest.cpp
@@ -0,0 +1,260 @@
+//===- unittests/Serialization/LoadSpecLazilyTest.cpp ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/FrontendAction.h"
+#include "clang/Frontend/FrontendActions.h"
+#include "clang/Parse/ParseAST.h"
+#include "clang/Serialization/ASTDeserializationListener.h"
+#include "clang/Tooling/Tooling.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace clang;
+using namespace clang::tooling;
+
+namespace {
+
+/// Test fixture that owns a unique scratch directory and can precompile C++20
+/// module interface units into it.
+class LoadSpecLazilyTest : public ::testing::Test {
+  void SetUp() override {
+    ASSERT_FALSE(
+        sys::fs::createUniqueDirectory("load-spec-lazily-test", TestDir));
+  }
+
+  void TearDown() override { sys::fs::remove_directories(TestDir); }
+
+public:
+  /// Scratch directory created in SetUp() and removed in TearDown().
+  SmallString<256> TestDir;
+
+  /// Write \p Contents to the file \p Path inside TestDir, creating parent
+  /// directories as needed. \p Path must be relative.
+  void addFile(StringRef Path, StringRef Contents) {
+    ASSERT_FALSE(sys::path::is_absolute(Path));
+
+    SmallString<256> AbsPath(TestDir);
+    sys::path::append(AbsPath, Path);
+
+    ASSERT_FALSE(
+        sys::fs::create_directories(llvm::sys::path::parent_path(AbsPath)));
+
+    std::error_code EC;
+    llvm::raw_fd_ostream OS(AbsPath, EC);
+    ASSERT_FALSE(EC);
+    OS << Contents;
+  }
+
+  /// Write \p Contents to "<ModuleName>.cppm" in TestDir and precompile it to
+  /// "<ModuleName>.pcm" in the same directory.
+  ///
+  /// \returns the path of the .pcm file. The path is returned even when a
+  /// step failed; failures are reported through gtest expectations.
+  std::string GenerateModuleInterface(StringRef ModuleName,
+                                      StringRef Contents) {
+    std::string FileName = llvm::Twine(ModuleName + ".cppm").str();
+    addFile(FileName, Contents);
+
+    IntrusiveRefCntPtr<DiagnosticsEngine> Diags =
+        CompilerInstance::createDiagnostics(new DiagnosticOptions());
+    CreateInvocationOptions CIOpts;
+    CIOpts.Diags = Diags;
+    CIOpts.VFS = llvm::vfs::createPhysicalFileSystem();
+
+    std::string CacheBMIPath =
+        llvm::Twine(TestDir + "/" + ModuleName + ".pcm").str();
+    std::string PrebuiltModulePath =
+        "-fprebuilt-module-path=" + TestDir.str().str();
+    const char *Args[] = {"clang++",
+                          "-std=c++20",
+                          "--precompile",
+                          PrebuiltModulePath.c_str(),
+                          "-working-directory",
+                          TestDir.c_str(),
+                          "-I",
+                          TestDir.c_str(),
+                          FileName.c_str(),
+                          "-o",
+                          CacheBMIPath.c_str()};
+    std::shared_ptr<CompilerInvocation> Invocation =
+        createInvocation(Args, CIOpts);
+    EXPECT_TRUE(Invocation);
+    // EXPECT_TRUE does not stop execution, so bail out here: the failure is
+    // already recorded and the code below would dereference a null pointer.
+    if (!Invocation)
+      return CacheBMIPath;
+
+    CompilerInstance Instance;
+    Instance.setDiagnostics(Diags.get());
+    Instance.setInvocation(Invocation);
+    Instance.getFrontendOpts().OutputFile = CacheBMIPath;
+    GenerateModuleInterfaceAction Action;
+    EXPECT_TRUE(Instance.ExecuteAction(Action));
+    EXPECT_FALSE(Diags->hasErrorOccurred());
+
+    return CacheBMIPath;
+  }
+};
+
+/// How DeclsReaderListener judges a read of the specified name: Forbidden
+/// means the name must never be read, Required means it must be read at least
+/// once before the listener is destroyed.
+enum class CheckingMode { Forbidden, Required };
+
+/// Deserialization listener that records whether any named declaration whose
+/// name contains a given substring has been read, and checks that against the
+/// configured CheckingMode.
+class DeclsReaderListener : public ASTDeserializationListener {
+  StringRef SpecifiedName;
+  CheckingMode Mode;
+
+  bool HasReadSpecifiedName = false;
+
+public:
+  /// Called for every declaration read; updates the "seen" flag and, in
+  /// Forbidden mode, fails the test as soon as the name has been seen.
+  void DeclRead(GlobalDeclID ID, const Decl *D) override {
+    auto *ND = dyn_cast<NamedDecl>(D);
+    // NamedDecl::getName() asserts when the name is not a plain identifier
+    // (e.g. constructors or operators), so skip such declarations.
+    if (!ND || !ND->getDeclName().isIdentifier())
+      return;
+
+    HasReadSpecifiedName |= ND->getName().contains(SpecifiedName);
+    if (Mode == CheckingMode::Forbidden) {
+      EXPECT_FALSE(HasReadSpecifiedName);
+    }
+  }
+
+  /// \param SpecifiedName substring to look for in declaration names; only a
+  /// reference is stored, so the underlying string must outlive the listener.
+  /// \param Mode whether seeing the name is forbidden or required.
+  DeclsReaderListener(StringRef SpecifiedName, CheckingMode Mode)
+      : SpecifiedName(SpecifiedName), Mode(Mode) {}
+
+  /// In Required mode, fails the test if the name was never seen.
+  ~DeclsReaderListener() override {
+    if (Mode == CheckingMode::Required) {
+      EXPECT_TRUE(HasReadSpecifiedName);
+    }
+  }
+};
+
+/// AST consumer whose only job is to own a DeclsReaderListener and expose it
+/// as the deserialization listener.
+class LoadSpecLazilyConsumer : public ASTConsumer {
+  DeclsReaderListener Listener;
+
+public:
+  /// Forwards both arguments to the owned listener.
+  LoadSpecLazilyConsumer(StringRef Name, CheckingMode CheckMode)
+      : Listener(Name, CheckMode) {}
+
+  ASTDeserializationListener *GetASTDeserializationListener() override {
+    ASTDeserializationListener *Owned = &Listener;
+    return Owned;
+  }
+};
+
+/// Frontend action that installs a LoadSpecLazilyConsumer configured with the
+/// name and checking mode given at construction.
+class CheckLoadSpecLazilyAction : public ASTFrontendAction {
+  StringRef SpecifiedName;
+  CheckingMode Mode;
+
+public:
+  /// Stores the name to watch for and how a read of it is judged.
+  CheckLoadSpecLazilyAction(StringRef Name, CheckingMode CheckMode)
+      : SpecifiedName(Name), Mode(CheckMode) {}
+
+  /// Creates the consumer; the compiler instance and input file are unused.
+  std::unique_ptr<ASTConsumer>
+  CreateASTConsumer(CompilerInstance &CI, StringRef /*Unused*/) override {
+    auto Consumer =
+        std::make_unique<LoadSpecLazilyConsumer>(SpecifiedName, Mode);
+    return Consumer;
+  }
+};
+
+/// Module M uses A<ShouldNotBeLoaded> internally. Importing M and using
+/// A<int> must not read any declaration whose name contains
+/// "ShouldNotBeLoaded".
+TEST_F(LoadSpecLazilyTest, BasicTest) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+export class ShouldNotBeLoaded {};
+export class Temp {
+ A<ShouldNotBeLoaded> AS;
+};
+ )cpp");
+
+  const char *ImporterSource = R"cpp(
+import M;
+A<int> a;
+ )cpp";
+
+  const std::string PrebuiltPathArg =
+      "-fprebuilt-module-path=" + TestDir.str().str();
+  const std::vector<std::string> ToolArgs = {"-std=c++20", PrebuiltPathArg,
+                                             "-I", TestDir.c_str()};
+
+  // Any read of the forbidden name is reported by the action's listener.
+  auto Action = std::make_unique<CheckLoadSpecLazilyAction>(
+      "ShouldNotBeLoaded", CheckingMode::Forbidden);
+  EXPECT_TRUE(runToolOnCodeWithArgs(std::move(Action), ImporterSource,
+                                    ToolArgs, "test.cpp"));
+}
+
+/// The template A lives in module M, while module N re-exports M and uses
+/// A<ShouldNotBeLoaded>. Importing N and using A<int> must not read any
+/// declaration whose name contains "ShouldNotBeLoaded".
+TEST_F(LoadSpecLazilyTest, ChainedTest) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+ )cpp");
+
+  GenerateModuleInterface("N", R"cpp(
+export module N;
+export import M;
+export class ShouldNotBeLoaded {};
+export class Temp {
+ A<ShouldNotBeLoaded> AS;
+};
+ )cpp");
+
+  const char *ImporterSource = R"cpp(
+import N;
+A<int> a;
+ )cpp";
+
+  const std::string PrebuiltPathArg =
+      "-fprebuilt-module-path=" + TestDir.str().str();
+  const std::vector<std::string> ToolArgs = {"-std=c++20", PrebuiltPathArg,
+                                             "-I", TestDir.c_str()};
+
+  // Any read of the forbidden name is reported by the action's listener.
+  auto Action = std::make_unique<CheckLoadSpecLazilyAction>(
+      "ShouldNotBeLoaded", CheckingMode::Forbidden);
+  EXPECT_TRUE(runToolOnCodeWithArgs(std::move(Action), ImporterSource,
+                                    ToolArgs, "test.cpp"));
+}
+
+/// Test that we do not crash when the lazy specialization lookup table is
+/// invalidated while specializations are being loaded. The run must also
+/// read a declaration whose name contains "MayBeLoaded".
+TEST_F(LoadSpecLazilyTest, ChainedTest2) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+
+export class B {};
+
+export class C {
+ A<B> D;
+};
+ )cpp");
+
+  GenerateModuleInterface("N", R"cpp(
+export module N;
+export import M;
+export class MayBeLoaded {};
+
+export class Temp {
+ A<MayBeLoaded> AS;
+};
+
+export class ExportedClass {};
+
+export template<> class A<ExportedClass> {
+ A<MayBeLoaded> AS;
+ A<B> AB;
+};
+ )cpp");
+
+  const char *ImporterSource = R"cpp(
+import N;
+Temp T;
+A<ExportedClass> a;
+ )cpp";
+
+  const std::string PrebuiltPathArg =
+      "-fprebuilt-module-path=" + TestDir.str().str();
+  const std::vector<std::string> ToolArgs = {"-std=c++20", PrebuiltPathArg,
+                                             "-I", TestDir.c_str()};
+
+  // The listener fails the test on destruction if the name was never read.
+  auto Action = std::make_unique<CheckLoadSpecLazilyAction>(
+      "MayBeLoaded", CheckingMode::Required);
+  EXPECT_TRUE(runToolOnCodeWithArgs(std::move(Action), ImporterSource,
+                                    ToolArgs, "test.cpp"));
+}
+
+} // namespace
More information about the llvm-branch-commits
mailing list