[llvm-branch-commits] [llvm] VocabStorage (PR #158376)
S. VenkataKeerthy via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Sep 12 15:07:40 PDT 2025
https://github.com/svkeerthy created https://github.com/llvm/llvm-project/pull/158376
None
>From 81a84b27f4b2aeaf6ca1421b2abb2a960c4e7a50 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Fri, 12 Sep 2025 22:06:44 +0000
Subject: [PATCH] VocabStorage
---
llvm/include/llvm/Analysis/IR2Vec.h | 145 +++++++--
llvm/lib/Analysis/IR2Vec.cpp | 230 +++++++++----
llvm/lib/Analysis/InlineAdvisor.cpp | 2 +-
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 6 +-
.../FunctionPropertiesAnalysisTest.cpp | 8 +-
llvm/unittests/Analysis/IR2VecTest.cpp | 301 ++++++++++++++++--
6 files changed, 570 insertions(+), 122 deletions(-)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 4a6db5d895a62..7d51a7320d194 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -45,6 +45,7 @@
#include "llvm/Support/JSON.h"
#include <array>
#include <map>
+#include <optional>
namespace llvm {
@@ -144,6 +145,73 @@ struct Embedding {
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
+/// Generic storage class for section-based vocabularies.
+/// VocabStorage provides a generic foundation for storing and accessing
+/// embeddings organized into sections.
+class VocabStorage {
+private:
+ /// Section-based storage
+ std::vector<std::vector<Embedding>> Sections;
+
+ size_t TotalSize = 0;
+ unsigned Dimension = 0;
+
+public:
+ /// Default constructor creates empty storage (invalid state)
+ VocabStorage() : Sections(), TotalSize(0), Dimension(0) {}
+
+ /// Create a VocabStorage with pre-organized section data
+ VocabStorage(std::vector<std::vector<Embedding>> &&SectionData);
+
+ VocabStorage(VocabStorage &&) = default;
+ VocabStorage &operator=(VocabStorage &&Other);
+
+ VocabStorage(const VocabStorage &) = delete;
+ VocabStorage &operator=(const VocabStorage &) = delete;
+
+ /// Get total number of entries across all sections
+ size_t size() const { return TotalSize; }
+
+ /// Get number of sections
+ unsigned getNumSections() const {
+ return static_cast<unsigned>(Sections.size());
+ }
+
+ /// Section-based access: Storage[sectionId][localIndex]
+ const std::vector<Embedding> &operator[](unsigned SectionId) const {
+ assert(SectionId < Sections.size() && "Invalid section ID");
+ return Sections[SectionId];
+ }
+
+ /// Get vocabulary dimension
+ unsigned getDimension() const { return Dimension; }
+
+ /// Check if vocabulary is valid (has data)
+ bool isValid() const { return TotalSize > 0; }
+
+ /// Iterator support for section-based access
+ class const_iterator {
+ const VocabStorage *Storage;
+ unsigned SectionId;
+ size_t LocalIndex;
+
+ public:
+ const_iterator(const VocabStorage *Storage, unsigned SectionId,
+ size_t LocalIndex)
+ : Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}
+
+ LLVM_ABI const Embedding &operator*() const;
+ LLVM_ABI const_iterator &operator++();
+ LLVM_ABI bool operator==(const const_iterator &Other) const;
+ LLVM_ABI bool operator!=(const const_iterator &Other) const;
+ };
+
+ const_iterator begin() const { return const_iterator(this, 0, 0); }
+ const_iterator end() const {
+ return const_iterator(this, getNumSections(), 0);
+ }
+};
+
/// Class for storing and accessing the IR2Vec vocabulary.
/// The Vocabulary class manages seed embeddings for LLVM IR entities. The
/// seed embeddings are the initial learned representations of the entities
@@ -164,7 +232,7 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;
- // Vocabulary Slot Layout:
+ // Vocabulary Layout:
// +----------------+------------------------------------------------------+
// | Entity Type | Index Range |
// +----------------+------------------------------------------------------+
@@ -175,8 +243,16 @@ class Vocabulary {
// Note: "Similar" LLVM Types are grouped/canonicalized together.
// Operands include Comparison predicates (ICmp/FCmp).
// This can be extended to include other specializations in future.
- using VocabVector = std::vector<ir2vec::Embedding>;
- VocabVector Vocab;
+ enum class Section : unsigned {
+ Opcodes = 0,
+ CanonicalTypes = 1,
+ Operands = 2,
+ Predicates = 3,
+ MaxSections
+ };
+
+ // Use section-based storage for better organization and efficiency
+ VocabStorage Storage;
static constexpr unsigned NumICmpPredicates =
static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
@@ -228,9 +304,18 @@ class Vocabulary {
NumICmpPredicates + NumFCmpPredicates;
Vocabulary() = default;
- LLVM_ABI Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
+ LLVM_ABI Vocabulary(VocabStorage &&Storage) : Storage(std::move(Storage)) {}
+
+ Vocabulary(const Vocabulary &) = delete;
+ Vocabulary &operator=(const Vocabulary &) = delete;
+
+ Vocabulary(Vocabulary &&) = default;
+ Vocabulary &operator=(Vocabulary &&Other);
+
+ LLVM_ABI bool isValid() const {
+ return Storage.size() == NumCanonicalEntries;
+ }
- LLVM_ABI bool isValid() const { return Vocab.size() == NumCanonicalEntries; };
LLVM_ABI unsigned getDimension() const;
/// Total number of entries (opcodes + canonicalized types + operand kinds +
/// predicates)
@@ -251,12 +336,11 @@ class Vocabulary {
/// Function to get vocabulary key for a given predicate
LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
- /// Functions to return the slot index or position of a given Opcode, TypeID,
- /// or OperandKind in the vocabulary.
- LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
- LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
- LLVM_ABI static unsigned getSlotIndex(const Value &Op);
- LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
+ /// Functions to return the flat index of an opcode, type ID, operand, or predicate in the vocabulary.
+ LLVM_ABI static unsigned getIndex(unsigned Opcode);
+ LLVM_ABI static unsigned getIndex(Type::TypeID TypeID);
+ LLVM_ABI static unsigned getIndex(const Value &Op);
+ LLVM_ABI static unsigned getIndex(CmpInst::Predicate P);
/// Accessors to get the embedding for a given entity.
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
@@ -265,26 +349,21 @@ class Vocabulary {
LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
/// Const Iterator type aliases
- using const_iterator = VocabVector::const_iterator;
+ using const_iterator = VocabStorage::const_iterator;
+
const_iterator begin() const {
assert(isValid() && "IR2Vec Vocabulary is invalid");
- return Vocab.begin();
+ return Storage.begin();
}
- const_iterator cbegin() const {
- assert(isValid() && "IR2Vec Vocabulary is invalid");
- return Vocab.cbegin();
- }
+ const_iterator cbegin() const { return begin(); }
const_iterator end() const {
assert(isValid() && "IR2Vec Vocabulary is invalid");
- return Vocab.end();
+ return Storage.end();
}
- const_iterator cend() const {
- assert(isValid() && "IR2Vec Vocabulary is invalid");
- return Vocab.cend();
- }
+ const_iterator cend() const { return end(); }
/// Returns the string key for a given index position in the vocabulary.
/// This is useful for debugging or printing the vocabulary. Do not use this
@@ -292,7 +371,7 @@ class Vocabulary {
LLVM_ABI static StringRef getStringKey(unsigned Pos);
/// Create a dummy vocabulary for testing purposes.
- LLVM_ABI static VocabVector createDummyVocabForTest(unsigned Dim = 1);
+ LLVM_ABI static VocabStorage createDummyVocabForTest(unsigned Dim = 1);
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
ModuleAnalysisManager::Invalidator &Inv) const;
@@ -301,12 +380,16 @@ class Vocabulary {
constexpr static unsigned NumCanonicalEntries =
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
- // Base offsets for slot layout to simplify index computation
+ // Base offsets for flat index computation
constexpr static unsigned OperandBaseOffset =
MaxOpcodes + MaxCanonicalTypeIDs;
constexpr static unsigned PredicateBaseOffset =
OperandBaseOffset + MaxOperandKinds;
+ /// Helpers to convert a predicate to its index among the predicate entries, and back
+ static unsigned getPredicateLocalIndex(CmpInst::Predicate P);
+ static CmpInst::Predicate getPredicateFromLocalIndex(unsigned LocalIndex);
+
/// String mappings for CanonicalTypeID values
static constexpr StringLiteral CanonicalTypeNames[] = {
"FloatTy", "VoidTy", "LabelTy", "MetadataTy",
@@ -452,22 +535,22 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder {
/// mapping between an entity of the IR (like opcode, type, argument, etc.) and
/// its corresponding embedding.
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
- using VocabVector = std::vector<ir2vec::Embedding>;
using VocabMap = std::map<std::string, ir2vec::Embedding>;
- VocabMap OpcVocab, TypeVocab, ArgVocab;
- VocabVector Vocab;
+ std::optional<ir2vec::VocabStorage> Vocab;
- Error readVocabulary();
+ Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab,
+ VocabMap &ArgVocab);
Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
VocabMap &TargetVocab, unsigned &Dim);
- void generateNumMappedVocab();
+ void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
+ VocabMap &ArgVocab);
void emitError(Error Err, LLVMContext &Ctx);
public:
LLVM_ABI static AnalysisKey Key;
IR2VecVocabAnalysis() = default;
- LLVM_ABI explicit IR2VecVocabAnalysis(const VocabVector &Vocab);
- LLVM_ABI explicit IR2VecVocabAnalysis(VocabVector &&Vocab);
+ LLVM_ABI explicit IR2VecVocabAnalysis(ir2vec::VocabStorage &&Vocab)
+ : Vocab(std::move(Vocab)) {}
using Result = ir2vec::Vocabulary;
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM);
};
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index f51f0898cb37e..eeba109eb7dbd 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -15,6 +15,7 @@
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Module.h"
@@ -261,55 +262,121 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
BBVecMap[&BB] = BBVector;
}
+// ==----------------------------------------------------------------------===//
+// VocabStorage
+//===----------------------------------------------------------------------===//
+
+VocabStorage::VocabStorage(std::vector<std::vector<Embedding>> &&SectionData)
+ : Sections(std::move(SectionData)) {
+ TotalSize = 0;
+ Dimension = 0;
+ assert(!Sections.empty() && "Vocabulary has no sections");
+ assert(!Sections[0].empty() && "First section of vocabulary is empty");
+
+ // Compute total size across all sections
+ for (const auto &Section : Sections)
+ TotalSize += Section.size();
+
+ // Take the dimension from the first embedding in the first section; all
+ // embeddings are expected to share it (this constructor does not verify that)
+ Dimension = static_cast<unsigned>(Sections[0][0].size());
+}
+
+VocabStorage &VocabStorage::operator=(VocabStorage &&Other) {
+ if (this != &Other) {
+ Sections = std::move(Other.Sections);
+ TotalSize = Other.TotalSize;
+ Dimension = Other.Dimension;
+ Other.TotalSize = 0;
+ Other.Dimension = 0;
+ }
+ return *this;
+}
+
+const Embedding &VocabStorage::const_iterator::operator*() const {
+ assert(SectionId < Storage->Sections.size() && "Invalid section ID");
+ assert(LocalIndex < Storage->Sections[SectionId].size() &&
+ "Local index out of range");
+ return Storage->Sections[SectionId][LocalIndex];
+}
+
+VocabStorage::const_iterator &VocabStorage::const_iterator::operator++() {
+ ++LocalIndex;
+ // Check if we need to move to the next section
+ while (SectionId < Storage->getNumSections() &&
+ LocalIndex >= Storage->Sections[SectionId].size()) {
+ LocalIndex = 0;
+ ++SectionId;
+ }
+ return *this;
+}
+
+bool VocabStorage::const_iterator::operator==(
+ const const_iterator &Other) const {
+ return Storage == Other.Storage && SectionId == Other.SectionId &&
+ LocalIndex == Other.LocalIndex;
+}
+
+bool VocabStorage::const_iterator::operator!=(
+ const const_iterator &Other) const {
+ return !(*this == Other);
+}
+
// ==----------------------------------------------------------------------===//
// Vocabulary
//===----------------------------------------------------------------------===//
+Vocabulary &Vocabulary::operator=(Vocabulary &&Other) {
+ if (this != &Other)
+ Storage = std::move(Other.Storage);
+ return *this;
+}
+
unsigned Vocabulary::getDimension() const {
assert(isValid() && "IR2Vec Vocabulary is invalid");
- return Vocab[0].size();
+ return Storage.getDimension();
}
-unsigned Vocabulary::getSlotIndex(unsigned Opcode) {
+unsigned Vocabulary::getIndex(unsigned Opcode) {
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
return Opcode - 1; // Convert to zero-based index
}
-unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
+unsigned Vocabulary::getIndex(Type::TypeID TypeID) {
assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));
}
-unsigned Vocabulary::getSlotIndex(const Value &Op) {
+unsigned Vocabulary::getIndex(const Value &Op) {
unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
assert(Index < MaxOperandKinds && "Invalid OperandKind");
return OperandBaseOffset + Index;
}
-unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) {
- unsigned PU = static_cast<unsigned>(P);
- unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
- unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);
-
- unsigned PredIdx =
- (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
- return PredicateBaseOffset + PredIdx;
+unsigned Vocabulary::getIndex(CmpInst::Predicate P) {
+ return PredicateBaseOffset + getPredicateLocalIndex(P);
}
const Embedding &Vocabulary::operator[](unsigned Opcode) const {
- return Vocab[getSlotIndex(Opcode)];
+ assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
+ return Storage[static_cast<unsigned>(Section::Opcodes)][Opcode - 1];
}
const Embedding &Vocabulary::operator[](Type::TypeID TypeID) const {
- return Vocab[getSlotIndex(TypeID)];
+ assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
+ unsigned LocalIndex = static_cast<unsigned>(getCanonicalTypeID(TypeID));
+ return Storage[static_cast<unsigned>(Section::CanonicalTypes)][LocalIndex];
}
const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
- return Vocab[getSlotIndex(Arg)];
+ unsigned LocalIndex = static_cast<unsigned>(getOperandKind(&Arg));
+ assert(LocalIndex < MaxOperandKinds && "Invalid OperandKind");
+ return Storage[static_cast<unsigned>(Section::Operands)][LocalIndex];
}
const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const {
- return Vocab[getSlotIndex(P)];
+ unsigned LocalIndex = getPredicateLocalIndex(P);
+ return Storage[static_cast<unsigned>(Section::Predicates)][LocalIndex];
}
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
@@ -359,12 +426,26 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
assert(Index < MaxPredicateKinds && "Invalid predicate index");
- unsigned PredEnumVal =
- (Index < NumFCmpPredicates)
- ? (static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + Index)
- : (static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) +
- (Index - NumFCmpPredicates));
- return static_cast<CmpInst::Predicate>(PredEnumVal);
+ return getPredicateFromLocalIndex(Index);
+}
+
+unsigned Vocabulary::getPredicateLocalIndex(CmpInst::Predicate P) {
+ if (P >= CmpInst::FIRST_FCMP_PREDICATE && P <= CmpInst::LAST_FCMP_PREDICATE)
+ return P - CmpInst::FIRST_FCMP_PREDICATE;
+ else
+ return P - CmpInst::FIRST_ICMP_PREDICATE +
+ (CmpInst::LAST_FCMP_PREDICATE - CmpInst::FIRST_FCMP_PREDICATE + 1);
+}
+
+CmpInst::Predicate Vocabulary::getPredicateFromLocalIndex(unsigned LocalIndex) {
+ unsigned fcmpRange =
+ CmpInst::LAST_FCMP_PREDICATE - CmpInst::FIRST_FCMP_PREDICATE + 1;
+ if (LocalIndex < fcmpRange)
+ return static_cast<CmpInst::Predicate>(CmpInst::FIRST_FCMP_PREDICATE +
+ LocalIndex);
+ else
+ return static_cast<CmpInst::Predicate>(CmpInst::FIRST_ICMP_PREDICATE +
+ LocalIndex - fcmpRange);
}
StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
@@ -401,17 +482,51 @@ bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA,
return !(PAC.preservedWhenStateless());
}
-Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
- VocabVector DummyVocab;
- DummyVocab.reserve(NumCanonicalEntries);
+VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) {
float DummyVal = 0.1f;
- // Create a dummy vocabulary with entries for all opcodes, types, operands
- // and predicates
- for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) {
- DummyVocab.push_back(Embedding(Dim, DummyVal));
+
+ // Create sections for opcodes, types, operands, and predicates
+ // Order must match Vocabulary::Section enum
+ std::vector<std::vector<Embedding>> Sections;
+ Sections.reserve(4);
+
+ // Opcodes section
+ std::vector<Embedding> OpcodeSec;
+ OpcodeSec.reserve(MaxOpcodes);
+ for (unsigned I = 0; I < MaxOpcodes; ++I) {
+ OpcodeSec.emplace_back(Dim, DummyVal);
+ DummyVal += 0.1f;
+ }
+ Sections.push_back(std::move(OpcodeSec));
+
+ // Types section
+ std::vector<Embedding> TypeSec;
+ TypeSec.reserve(MaxCanonicalTypeIDs);
+ for (unsigned I = 0; I < MaxCanonicalTypeIDs; ++I) {
+ TypeSec.emplace_back(Dim, DummyVal);
+ DummyVal += 0.1f;
+ }
+ Sections.push_back(std::move(TypeSec));
+
+ // Operands section
+ std::vector<Embedding> OperandSec;
+ OperandSec.reserve(MaxOperandKinds);
+ for (unsigned I = 0; I < MaxOperandKinds; ++I) {
+ OperandSec.emplace_back(Dim, DummyVal);
DummyVal += 0.1f;
}
- return DummyVocab;
+ Sections.push_back(std::move(OperandSec));
+
+ // Predicates section
+ std::vector<Embedding> PredicateSec;
+ PredicateSec.reserve(MaxPredicateKinds);
+ for (unsigned I = 0; I < MaxPredicateKinds; ++I) {
+ PredicateSec.emplace_back(Dim, DummyVal);
+ DummyVal += 0.1f;
+ }
+ Sections.push_back(std::move(PredicateSec));
+
+ return VocabStorage(std::move(Sections));
}
// ==----------------------------------------------------------------------===//
@@ -457,7 +572,9 @@ Error IR2VecVocabAnalysis::parseVocabSection(
// FIXME: Make this optional. We can avoid file reads
// by auto-generating a default vocabulary during the build time.
-Error IR2VecVocabAnalysis::readVocabulary() {
+Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,
+ VocabMap &TypeVocab,
+ VocabMap &ArgVocab) {
auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
if (!BufOrError)
return createFileError(VocabFile, BufOrError.getError());
@@ -488,7 +605,9 @@ Error IR2VecVocabAnalysis::readVocabulary() {
return Error::success();
}
-void IR2VecVocabAnalysis::generateNumMappedVocab() {
+void IR2VecVocabAnalysis::generateVocabStorage(VocabMap &OpcVocab,
+ VocabMap &TypeVocab,
+ VocabMap &ArgVocab) {
// Helper for handling missing entities in the vocabulary.
// Currently, we use a zero vector. In the future, we will throw an error to
@@ -506,7 +625,6 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Opcodes
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
Embedding(Dim));
- NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes);
for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
auto It = OpcVocab.find(VocabKey.str());
@@ -515,13 +633,10 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
else
handleMissingEntity(VocabKey.str());
}
- Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
- NumericOpcodeEmbeddings.end());
// Handle Types - only canonical types are present in vocabulary
std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
Embedding(Dim));
- NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs);
for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
static_cast<Vocabulary::CanonicalTypeID>(CTypeID));
@@ -531,13 +646,10 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
}
handleMissingEntity(VocabKey.str());
}
- Vocab.insert(Vocab.end(), NumericTypeEmbeddings.begin(),
- NumericTypeEmbeddings.end());
// Handle Arguments/Operands
std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
Embedding(Dim));
- NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds);
for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
@@ -548,14 +660,11 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
}
handleMissingEntity(VocabKey.str());
}
- Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
- NumericArgEmbeddings.end());
// Handle Predicates: part of Operands section. We look up predicate keys
// in ArgVocab.
std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
Embedding(Dim, 0));
- NumericPredEmbeddings.reserve(Vocabulary::MaxPredicateKinds);
for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
StringRef VocabKey =
Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
@@ -566,15 +675,22 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
}
handleMissingEntity(VocabKey.str());
}
- Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(),
- NumericPredEmbeddings.end());
-}
-IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)
- : Vocab(Vocab) {}
+ // Create section-based storage instead of flat vocabulary
+ // Order must match Vocabulary::Section enum
+ std::vector<std::vector<Embedding>> Sections(4);
+ Sections[static_cast<unsigned>(Vocabulary::Section::Opcodes)] =
+ std::move(NumericOpcodeEmbeddings); // Section::Opcodes
+ Sections[static_cast<unsigned>(Vocabulary::Section::CanonicalTypes)] =
+ std::move(NumericTypeEmbeddings); // Section::CanonicalTypes
+ Sections[static_cast<unsigned>(Vocabulary::Section::Operands)] =
+ std::move(NumericArgEmbeddings); // Section::Operands
+ Sections[static_cast<unsigned>(Vocabulary::Section::Predicates)] =
+ std::move(NumericPredEmbeddings); // Section::Predicates
-IR2VecVocabAnalysis::IR2VecVocabAnalysis(VocabVector &&Vocab)
- : Vocab(std::move(Vocab)) {}
+ // Create VocabStorage from organized sections
+ Vocab.emplace(std::move(Sections));
+}
void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
@@ -586,8 +702,8 @@ IR2VecVocabAnalysis::Result
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
auto Ctx = &M.getContext();
// If vocabulary is already populated by the constructor, use it.
- if (!Vocab.empty())
- return Vocabulary(std::move(Vocab));
+ if (Vocab.has_value())
+ return Vocabulary(std::move(Vocab.value()));
// Otherwise, try to read from the vocabulary file.
if (VocabFile.empty()) {
@@ -596,7 +712,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
"set it using --ir2vec-vocab-path");
return Vocabulary(); // Return invalid result
}
- if (auto Err = readVocabulary()) {
+
+ VocabMap OpcVocab, TypeVocab, ArgVocab;
+ if (auto Err = readVocabulary(OpcVocab, TypeVocab, ArgVocab)) {
emitError(std::move(Err), *Ctx);
return Vocabulary();
}
@@ -611,9 +729,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
scaleVocabSection(ArgVocab, ArgWeight);
// Generate the numeric lookup vocabulary
- generateNumMappedVocab();
+ generateVocabStorage(OpcVocab, TypeVocab, ArgVocab);
- return Vocabulary(std::move(Vocab));
+ return Vocabulary(std::move(Vocab.value()));
}
// ==----------------------------------------------------------------------===//
@@ -622,7 +740,7 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
PreservedAnalyses IR2VecPrinterPass::run(Module &M,
ModuleAnalysisManager &MAM) {
- auto Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
+ auto &Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
for (Function &F : M) {
@@ -664,7 +782,7 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
ModuleAnalysisManager &MAM) {
- auto IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
+ auto &IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid");
// Print each entry
diff --git a/llvm/lib/Analysis/InlineAdvisor.cpp b/llvm/lib/Analysis/InlineAdvisor.cpp
index 28b14c2562df1..0fa804f2959e8 100644
--- a/llvm/lib/Analysis/InlineAdvisor.cpp
+++ b/llvm/lib/Analysis/InlineAdvisor.cpp
@@ -217,7 +217,7 @@ AnalysisKey PluginInlineAdvisorAnalysis::Key;
bool InlineAdvisorAnalysis::initializeIR2VecVocabIfRequested(
Module &M, ModuleAnalysisManager &MAM) {
if (!IR2VecVocabFile.empty()) {
- auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
+ auto &IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
if (!IR2VecVocabResult.isValid()) {
M.getContext().emitError("Failed to load IR2Vec vocabulary");
return false;
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 1c656b8fcf4e7..434449c7c5117 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -162,8 +162,8 @@ class IR2VecTool {
for (const BasicBlock &BB : F) {
for (const auto &I : BB.instructionsWithoutDebug()) {
- unsigned Opcode = Vocabulary::getSlotIndex(I.getOpcode());
- unsigned TypeID = Vocabulary::getSlotIndex(I.getType()->getTypeID());
+ unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
+ unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
// Add "Next" relationship with previous instruction
if (HasPrevOpcode) {
@@ -184,7 +184,7 @@ class IR2VecTool {
// Add "Arg" relationships
unsigned ArgIndex = 0;
for (const Use &U : I.operands()) {
- unsigned OperandID = Vocabulary::getSlotIndex(*U.get());
+ unsigned OperandID = Vocabulary::getIndex(*U.get());
unsigned RelationID = ArgRelation + ArgIndex;
OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
index dc6059dcf6827..442f703f08d0c 100644
--- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
@@ -43,8 +43,10 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
public:
FunctionPropertiesAnalysisTest() {
auto VocabVector = ir2vec::Vocabulary::createDummyVocabForTest(1);
- MAM.registerPass([&] { return IR2VecVocabAnalysis(VocabVector); });
- IR2VecVocab = ir2vec::Vocabulary(std::move(VocabVector));
+ MAM.registerPass([VocabVector = std::move(VocabVector)]() mutable {
+ return IR2VecVocabAnalysis(std::move(VocabVector));
+ });
+ IR2VecVocab = ir2vec::Vocabulary(ir2vec::Vocabulary::createDummyVocabForTest(1));
MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
FAM.registerPass([&] { return DominatorTreeAnalysis(); });
@@ -78,7 +80,7 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
FunctionPropertiesInfo buildFPI(Function &F) {
// FunctionPropertiesInfo assumes IR2VecVocabAnalysis has been run to
// use IR2Vec.
- auto VocabResult = MAM.getResult<IR2VecVocabAnalysis>(*F.getParent());
+ auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(*F.getParent());
(void)VocabResult;
return FunctionPropertiesInfo::getFunctionPropertiesInfo(F, FAM);
}
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 9bc48e45eab5e..d915920eccda0 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -464,7 +464,10 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands +
MaxPredicateKinds);
- auto ExpectedVocab = VocabVec;
+ // Collect embeddings for later comparison before moving VocabVec
+ std::vector<Embedding> ExpectedVocab;
+ for (const auto &Emb : VocabVec)
+ ExpectedVocab.push_back(Emb);
IR2VecVocabAnalysis VocabAnalysis(std::move(VocabVec));
LLVMContext TestCtx;
@@ -482,17 +485,17 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
}
TEST(IR2VecVocabularyTest, SlotIdxMapping) {
- // Test getSlotIndex for Opcodes
+ // Test getIndex for Opcodes
#define EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS) \
- EXPECT_EQ(Vocabulary::getSlotIndex(NUM), static_cast<unsigned>(NUM - 1));
+ EXPECT_EQ(Vocabulary::getIndex(NUM), static_cast<unsigned>(NUM - 1));
#define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS)
#include "llvm/IR/Instruction.def"
#undef HANDLE_INST
#undef EXPECT_OPCODE_SLOT
- // Test getSlotIndex for Types
+ // Test getIndex for Types
#define EXPECT_TYPE_SLOT(TypeIDTok, CanonEnum, CanonStr) \
- EXPECT_EQ(Vocabulary::getSlotIndex(Type::TypeIDTok), \
+ EXPECT_EQ(Vocabulary::getIndex(Type::TypeIDTok), \
MaxOpcodes + static_cast<unsigned>( \
Vocabulary::CanonicalTypeID::CanonEnum));
@@ -500,7 +503,7 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
#undef EXPECT_TYPE_SLOT
- // Test getSlotIndex for Value operands
+ // Test getIndex for Value operands
LLVMContext Ctx;
Module M("TestM", Ctx);
FunctionType *FTy =
@@ -510,27 +513,27 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
#define EXPECTED_VOCAB_OPERAND_SLOT(X) \
MaxOpcodes + MaxCanonicalTypeIDs + static_cast<unsigned>(X)
// Test Function operand
- EXPECT_EQ(Vocabulary::getSlotIndex(*F),
+ EXPECT_EQ(Vocabulary::getIndex(*F),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::FunctionID));
// Test Constant operand
Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
- EXPECT_EQ(Vocabulary::getSlotIndex(*C),
+ EXPECT_EQ(Vocabulary::getIndex(*C),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::ConstantID));
// Test Pointer operand
BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB);
- EXPECT_EQ(Vocabulary::getSlotIndex(*PtrVal),
+ EXPECT_EQ(Vocabulary::getIndex(*PtrVal),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::PointerID));
// Test Variable operand (function argument)
Argument *Arg = F->getArg(0);
- EXPECT_EQ(Vocabulary::getSlotIndex(*Arg),
+ EXPECT_EQ(Vocabulary::getIndex(*Arg),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::VariableID));
#undef EXPECTED_VOCAB_OPERAND_SLOT
- // Test getSlotIndex for predicates
+ // Test getIndex for predicates
#define EXPECTED_VOCAB_PREDICATE_SLOT(X) \
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + static_cast<unsigned>(X)
for (unsigned P = CmpInst::FIRST_FCMP_PREDICATE;
@@ -538,7 +541,7 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
CmpInst::Predicate Pred = static_cast<CmpInst::Predicate>(P);
unsigned ExpectedIdx =
EXPECTED_VOCAB_PREDICATE_SLOT((P - CmpInst::FIRST_FCMP_PREDICATE));
- EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx);
+ EXPECT_EQ(Vocabulary::getIndex(Pred), ExpectedIdx);
}
auto ICMP_Start = CmpInst::LAST_FCMP_PREDICATE + 1;
for (unsigned P = CmpInst::FIRST_ICMP_PREDICATE;
@@ -546,7 +549,7 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
CmpInst::Predicate Pred = static_cast<CmpInst::Predicate>(P);
unsigned ExpectedIdx = EXPECTED_VOCAB_PREDICATE_SLOT(
ICMP_Start + P - CmpInst::FIRST_ICMP_PREDICATE);
- EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx);
+ EXPECT_EQ(Vocabulary::getIndex(Pred), ExpectedIdx);
}
#undef EXPECTED_VOCAB_PREDICATE_SLOT
}
@@ -555,15 +558,14 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
#ifndef NDEBUG
TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) {
// Test invalid opcode IDs
- EXPECT_DEATH(Vocabulary::getSlotIndex(0u), "Invalid opcode");
- EXPECT_DEATH(Vocabulary::getSlotIndex(MaxOpcodes + 1), "Invalid opcode");
+ EXPECT_DEATH(Vocabulary::getIndex(0u), "Invalid opcode");
+ EXPECT_DEATH(Vocabulary::getIndex(MaxOpcodes + 1), "Invalid opcode");
// Test invalid type IDs
- EXPECT_DEATH(Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs)),
+ EXPECT_DEATH(Vocabulary::getIndex(static_cast<Type::TypeID>(MaxTypeIDs)),
+ "Invalid type ID");
+ EXPECT_DEATH(Vocabulary::getIndex(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
"Invalid type ID");
- EXPECT_DEATH(
- Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
- "Invalid type ID");
}
#endif // NDEBUG
#endif // GTEST_HAS_DEATH_TEST
@@ -573,7 +575,7 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) {
EXPECT_EQ(Vocabulary::getStringKey(12), "Add");
#define EXPECT_OPCODE(NUM, OPCODE, CLASS) \
- EXPECT_EQ(Vocabulary::getStringKey(Vocabulary::getSlotIndex(NUM)), \
+ EXPECT_EQ(Vocabulary::getStringKey(Vocabulary::getIndex(NUM)), \
Vocabulary::getVocabKeyForOpcode(NUM));
#define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE(NUM, OPCODE, CLASS)
#include "llvm/IR/Instruction.def"
@@ -672,10 +674,12 @@ TEST(IR2VecVocabularyTest, InvalidAccess) {
#endif // GTEST_HAS_DEATH_TEST
TEST(IR2VecVocabularyTest, TypeIDStringKeyMapping) {
+ Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest());
#define EXPECT_TYPE_TO_CANONICAL(TypeIDTok, CanonEnum, CanonStr) \
- EXPECT_EQ( \
- Vocabulary::getStringKey(Vocabulary::getSlotIndex(Type::TypeIDTok)), \
- CanonStr);
+ do { \
+ unsigned FlatIdx = V.getIndex(Type::TypeIDTok); \
+ EXPECT_EQ(Vocabulary::getStringKey(FlatIdx), CanonStr); \
+ } while (0);
IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_TYPE_TO_CANONICAL)
@@ -683,14 +687,20 @@ TEST(IR2VecVocabularyTest, TypeIDStringKeyMapping) {
}
TEST(IR2VecVocabularyTest, InvalidVocabularyConstruction) {
- std::vector<Embedding> InvalidVocab;
- InvalidVocab.push_back(Embedding(2, 1.0));
- InvalidVocab.push_back(Embedding(2, 2.0));
-
- Vocabulary V(std::move(InvalidVocab));
+ // Test 1: Create invalid VocabStorage with insufficient sections
+ std::vector<std::vector<Embedding>> InvalidSectionData;
+ // Only add one section with 2 embeddings, but the vocabulary needs 4 sections
+ std::vector<Embedding> Section1;
+ Section1.push_back(Embedding(2, 1.0));
+ Section1.push_back(Embedding(2, 2.0));
+ InvalidSectionData.push_back(std::move(Section1));
+
+ VocabStorage InvalidStorage(std::move(InvalidSectionData));
+ Vocabulary V(std::move(InvalidStorage));
EXPECT_FALSE(V.isValid());
{
+ // Test 2: Default-constructed vocabulary should be invalid
Vocabulary InvalidResult;
EXPECT_FALSE(InvalidResult.isValid());
#if GTEST_HAS_DEATH_TEST
@@ -701,4 +711,239 @@ TEST(IR2VecVocabularyTest, InvalidVocabularyConstruction) {
}
}
+TEST(VocabStorageTest, DefaultConstructor) {
+  // Default construction yields empty storage in the invalid state.
+  VocabStorage EmptyStorage;
+  EXPECT_EQ(EmptyStorage.size(), 0u);
+  EXPECT_EQ(EmptyStorage.getNumSections(), 0u);
+  EXPECT_EQ(EmptyStorage.getDimension(), 0u);
+  EXPECT_FALSE(EmptyStorage.isValid());
+
+  // No embeddings are stored, so begin() must already equal end().
+  EXPECT_EQ(EmptyStorage.begin(), EmptyStorage.end());
+}
+
+TEST(VocabStorageTest, BasicConstruction) {
+  // Build three sections that all hold 3-dimensional embeddings.
+  std::vector<std::vector<Embedding>> SectionData;
+
+  // First section: two embeddings.
+  std::vector<Embedding> FirstSection;
+  FirstSection.emplace_back(std::vector<double>{1.0, 2.0, 3.0});
+  FirstSection.emplace_back(std::vector<double>{4.0, 5.0, 6.0});
+  SectionData.push_back(std::move(FirstSection));
+
+  // Second section: a single embedding.
+  std::vector<Embedding> SecondSection;
+  SecondSection.emplace_back(std::vector<double>{7.0, 8.0, 9.0});
+  SectionData.push_back(std::move(SecondSection));
+
+  // Third section: three embeddings.
+  std::vector<Embedding> ThirdSection;
+  ThirdSection.emplace_back(std::vector<double>{10.0, 11.0, 12.0});
+  ThirdSection.emplace_back(std::vector<double>{13.0, 14.0, 15.0});
+  ThirdSection.emplace_back(std::vector<double>{16.0, 17.0, 18.0});
+  SectionData.push_back(std::move(ThirdSection));
+
+  VocabStorage Storage(std::move(SectionData));
+
+  EXPECT_EQ(Storage.size(), 6u); // Embeddings overall: 2 + 1 + 3
+  EXPECT_EQ(Storage.getNumSections(), 3u);
+  EXPECT_EQ(Storage.getDimension(), 3u);
+  EXPECT_TRUE(Storage.isValid());
+}
+
+TEST(VocabStorageTest, SectionAccess) {
+  // Two sections of 2-dimensional embeddings, holding 2 and 1 entries.
+  std::vector<std::vector<Embedding>> SectionData;
+
+  std::vector<Embedding> FirstSection;
+  FirstSection.emplace_back(std::vector<double>{1.0, 2.0});
+  FirstSection.emplace_back(std::vector<double>{3.0, 4.0});
+  SectionData.push_back(std::move(FirstSection));
+
+  std::vector<Embedding> SecondSection;
+  SecondSection.emplace_back(std::vector<double>{5.0, 6.0});
+  SectionData.push_back(std::move(SecondSection));
+
+  VocabStorage Storage(std::move(SectionData));
+
+  // Indexing by section ID exposes that section and its embedding count.
+  EXPECT_EQ(Storage[0].size(), 2u);
+  EXPECT_EQ(Storage[1].size(), 1u);
+
+  // Embeddings come back in insertion order with their data intact.
+  EXPECT_THAT(Storage[0][0].getData(), ElementsAre(1.0, 2.0));
+  EXPECT_THAT(Storage[0][1].getData(), ElementsAre(3.0, 4.0));
+  EXPECT_THAT(Storage[1][0].getData(), ElementsAre(5.0, 6.0));
+}
+
+#if GTEST_HAS_DEATH_TEST
+#ifndef NDEBUG
+TEST(VocabStorageTest, InvalidSectionAccess) {
+  std::vector<std::vector<Embedding>> SectionData;
+  std::vector<Embedding> OnlySection;
+  OnlySection.emplace_back(std::vector<double>{1.0, 2.0});
+  SectionData.push_back(std::move(OnlySection));
+
+  VocabStorage Storage(std::move(SectionData));
+  // Only section 0 exists, so any larger section ID is out of range.
+  EXPECT_DEATH(Storage[1], "Invalid section ID");
+  EXPECT_DEATH(Storage[10], "Invalid section ID");
+}
+
+TEST(VocabStorageTest, EmptySection) {
+  std::vector<std::vector<Embedding>> SectionData;
+  std::vector<Embedding> NoEmbeddings; // Deliberately left empty
+  SectionData.push_back(std::move(NoEmbeddings));
+
+  std::vector<Embedding> NonEmptySection;
+  NonEmptySection.emplace_back(std::vector<double>{1.0});
+  SectionData.push_back(std::move(NonEmptySection));
+
+  EXPECT_DEATH(VocabStorage(std::move(SectionData)),
+               "First section of vocabulary is empty");
+}
+
+TEST(VocabStorageTest, NoSections) {
+  std::vector<std::vector<Embedding>> SectionData; // Intentionally empty
+
+  EXPECT_DEATH(VocabStorage(std::move(SectionData)),
+               "Vocabulary has no sections");
+}
+#endif // NDEBUG
+#endif // GTEST_HAS_DEATH_TEST
+
+TEST(VocabStorageTest, MoveAssignment) {
+  // Source: one section holding a single 2-dimensional embedding.
+  std::vector<std::vector<Embedding>> SourceData;
+  std::vector<Embedding> SourceSection;
+  SourceSection.emplace_back(std::vector<double>{1.0, 2.0});
+  SourceData.push_back(std::move(SourceSection));
+  VocabStorage Source(std::move(SourceData));
+
+  // Destination: one section holding a single 3-dimensional embedding.
+  std::vector<std::vector<Embedding>> DestData;
+  std::vector<Embedding> DestSection;
+  DestSection.emplace_back(std::vector<double>{5.0, 6.0, 7.0});
+  DestData.push_back(std::move(DestSection));
+  VocabStorage Dest(std::move(DestData));
+
+  EXPECT_EQ(Dest.getDimension(), 3u); // Before the move: 3-dimensional
+
+  // Replace the destination's contents via move assignment.
+  Dest = std::move(Source);
+
+  // The destination must now report the source's size, dimension and data.
+  EXPECT_EQ(Dest.size(), 1u);
+  EXPECT_EQ(Dest.getDimension(), 2u); // After the move: 2-dimensional
+  EXPECT_TRUE(Dest.isValid());
+  EXPECT_THAT(Dest[0][0].getData(), ElementsAre(1.0, 2.0));
+}
+
+TEST(VocabStorageTest, IteratorBasics) {
+  std::vector<std::vector<Embedding>> SectionData;
+
+  std::vector<Embedding> FirstSection;
+  FirstSection.emplace_back(std::vector<double>{1.0, 2.0});
+  FirstSection.emplace_back(std::vector<double>{3.0, 4.0});
+  SectionData.push_back(std::move(FirstSection));
+
+  std::vector<Embedding> SecondSection;
+  SecondSection.emplace_back(std::vector<double>{5.0, 6.0});
+  SectionData.push_back(std::move(SecondSection));
+
+  VocabStorage Storage(std::move(SectionData));
+
+  // Step through the storage by hand, one embedding at a time.
+  auto It = Storage.begin();
+  auto End = Storage.end();
+
+  EXPECT_NE(It, End);
+
+  // The iterator starts on the first embedding of section 0.
+  EXPECT_THAT((*It).getData(), ElementsAre(1.0, 2.0));
+
+  // One step forward stays within section 0.
+  ++It;
+  EXPECT_NE(It, End);
+  EXPECT_THAT((*It).getData(), ElementsAre(3.0, 4.0));
+
+  // The next step crosses into section 1.
+  ++It;
+  EXPECT_NE(It, End);
+  EXPECT_THAT((*It).getData(), ElementsAre(5.0, 6.0));
+
+  // With all three embeddings visited, one more step reaches end().
+  ++It;
+  EXPECT_EQ(It, End);
+}
+
+TEST(VocabStorageTest, IteratorTraversal) {
+  std::vector<std::vector<Embedding>> SectionData;
+
+  // First section: two 1-dimensional embeddings.
+  std::vector<Embedding> FirstSection;
+  FirstSection.emplace_back(std::vector<double>{10.0});
+  FirstSection.emplace_back(std::vector<double>{20.0});
+  SectionData.push_back(std::move(FirstSection));
+
+  // Middle section: left empty so that iteration has to skip over it.
+  std::vector<Embedding> MiddleSection; // No embeddings
+  SectionData.push_back(std::move(MiddleSection));
+
+  // Last section: three 1-dimensional embeddings.
+  std::vector<Embedding> LastSection;
+  LastSection.emplace_back(std::vector<double>{30.0});
+  LastSection.emplace_back(std::vector<double>{40.0});
+  LastSection.emplace_back(std::vector<double>{50.0});
+  SectionData.push_back(std::move(LastSection));
+
+  VocabStorage Storage(std::move(SectionData));
+
+  // Gather the single component of every embedding the range-for visits.
+  std::vector<double> Visited;
+  for (const auto &Emb : Storage) {
+    EXPECT_EQ(Emb.size(), 1u);
+    Visited.push_back(Emb[0]);
+  }
+
+  // Embeddings appear in section order; the empty section contributes none.
+  EXPECT_THAT(Visited, ElementsAre(10.0, 20.0, 30.0, 40.0, 50.0));
+}
+
+TEST(VocabStorageTest, IteratorComparison) {
+  std::vector<std::vector<Embedding>> SectionData;
+  std::vector<Embedding> OnlySection;
+  OnlySection.emplace_back(std::vector<double>{1.0});
+  OnlySection.emplace_back(std::vector<double>{2.0});
+  SectionData.push_back(std::move(OnlySection));
+
+  VocabStorage Storage(std::move(SectionData));
+
+  auto Lead = Storage.begin();
+  auto Trail = Storage.begin();
+  auto End = Storage.end();
+
+  // Two fresh begin() iterators are equal and differ from end().
+  EXPECT_EQ(Lead, Trail);
+  EXPECT_NE(Lead, End);
+
+  // Moving only the lead iterator makes the pair unequal.
+  ++Lead;
+  EXPECT_NE(Lead, Trail);
+  EXPECT_NE(Lead, End);
+
+  // Once the trailing iterator catches up they are equal again.
+  ++Trail;
+  EXPECT_EQ(Lead, Trail);
+
+  // One more step each exhausts the two embeddings and lands on end().
+  ++Lead;
+  ++Trail;
+  EXPECT_EQ(Lead, End);
+  EXPECT_EQ(Trail, End);
+  EXPECT_EQ(Lead, Trail);
+}
+
} // end anonymous namespace
More information about the llvm-branch-commits
mailing list