[llvm] ed1d954 - [IR2Vec] Refactor vocabulary to use section-based storage (#158376)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 1 17:13:17 PDT 2025
Author: S. VenkataKeerthy
Date: 2025-10-01T17:13:13-07:00
New Revision: ed1d9548b5c08142dab82bcfdd9875177d8223a5
URL: https://github.com/llvm/llvm-project/commit/ed1d9548b5c08142dab82bcfdd9875177d8223a5
DIFF: https://github.com/llvm/llvm-project/commit/ed1d9548b5c08142dab82bcfdd9875177d8223a5.diff
LOG: [IR2Vec] Refactor vocabulary to use section-based storage (#158376)
Refactored the IR2Vec vocabulary and introduced `VocabStorage`, a storage class that is agnostic to IR semantics:
- `Vocabulary` *has-a* `VocabStorage`
- `Vocabulary` deals with LLVM IR-specific entities. This separation should allow parts of the logic to be reused efficiently for MIR.
- Storage uses a section-based approach instead of a flat vector, improving organization and access patterns.
Added:
Modified:
llvm/include/llvm/Analysis/IR2Vec.h
llvm/lib/Analysis/IR2Vec.cpp
llvm/lib/Analysis/InlineAdvisor.cpp
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
llvm/unittests/Analysis/IR2VecTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index f3f9de460218b..b7c301580a8a4 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -45,6 +45,7 @@
#include "llvm/Support/JSON.h"
#include <array>
#include <map>
+#include <optional>
namespace llvm {
@@ -144,6 +145,73 @@ struct Embedding {
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
+/// Generic storage class for section-based vocabularies.
+/// VocabStorage provides a generic foundation for storing and accessing
+/// embeddings organized into sections.
+class VocabStorage {
+private:
+ /// Section-based storage
+ std::vector<std::vector<Embedding>> Sections;
+
+ const size_t TotalSize;
+ const unsigned Dimension;
+
+public:
+ /// Default constructor creates empty storage (invalid state)
+ VocabStorage() : Sections(), TotalSize(0), Dimension(0) {}
+
+ /// Create a VocabStorage with pre-organized section data
+ VocabStorage(std::vector<std::vector<Embedding>> &&SectionData);
+
+ VocabStorage(VocabStorage &&) = default;
+ VocabStorage &operator=(VocabStorage &&) = delete;
+
+ VocabStorage(const VocabStorage &) = delete;
+ VocabStorage &operator=(const VocabStorage &) = delete;
+
+ /// Get total number of entries across all sections
+ size_t size() const { return TotalSize; }
+
+ /// Get number of sections
+ unsigned getNumSections() const {
+ return static_cast<unsigned>(Sections.size());
+ }
+
+ /// Section-based access: Storage[sectionId][localIndex]
+ const std::vector<Embedding> &operator[](unsigned SectionId) const {
+ assert(SectionId < Sections.size() && "Invalid section ID");
+ return Sections[SectionId];
+ }
+
+ /// Get vocabulary dimension
+ unsigned getDimension() const { return Dimension; }
+
+ /// Check if vocabulary is valid (has data)
+ bool isValid() const { return TotalSize > 0; }
+
+ /// Iterator support for section-based access
+ class const_iterator {
+ const VocabStorage *Storage;
+ unsigned SectionId = 0;
+ size_t LocalIndex = 0;
+
+ public:
+ const_iterator(const VocabStorage *Storage, unsigned SectionId,
+ size_t LocalIndex)
+ : Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}
+
+ LLVM_ABI const Embedding &operator*() const;
+ LLVM_ABI const_iterator &operator++();
+ LLVM_ABI bool operator==(const const_iterator &Other) const;
+ LLVM_ABI bool operator!=(const const_iterator &Other) const;
+ };
+
+ const_iterator begin() const { return const_iterator(this, 0, 0); }
+ const_iterator end() const {
+ return const_iterator(this, getNumSections(), 0);
+ }
+};
+
/// Class for storing and accessing the IR2Vec vocabulary.
/// The Vocabulary class manages seed embeddings for LLVM IR entities. The
/// seed embeddings are the initial learned representations of the entities
@@ -164,7 +232,7 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;
- // Vocabulary Slot Layout:
+ // Vocabulary Layout:
// +----------------+------------------------------------------------------+
// | Entity Type | Index Range |
// +----------------+------------------------------------------------------+
@@ -180,8 +248,16 @@ class Vocabulary {
// and improves learning. Operands include Comparison predicates
// (ICmp/FCmp) along with other operand types. This can be extended to
// include other specializations in future.
- using VocabVector = std::vector<ir2vec::Embedding>;
- VocabVector Vocab;
+ enum class Section : unsigned {
+ Opcodes = 0,
+ CanonicalTypes = 1,
+ Operands = 2,
+ Predicates = 3,
+ MaxSections
+ };
+
+ // Use section-based storage for better organization and efficiency
+ VocabStorage Storage;
static constexpr unsigned NumICmpPredicates =
static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
@@ -233,10 +309,23 @@ class Vocabulary {
NumICmpPredicates + NumFCmpPredicates;
Vocabulary() = default;
- LLVM_ABI Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
+ LLVM_ABI Vocabulary(VocabStorage &&Storage) : Storage(std::move(Storage)) {}
+
+ Vocabulary(const Vocabulary &) = delete;
+ Vocabulary &operator=(const Vocabulary &) = delete;
+
+ Vocabulary(Vocabulary &&) = default;
+ Vocabulary &operator=(Vocabulary &&Other) = delete;
+
+ LLVM_ABI bool isValid() const {
+ return Storage.size() == NumCanonicalEntries;
+ }
+
+ LLVM_ABI unsigned getDimension() const {
+ assert(isValid() && "IR2Vec Vocabulary is invalid");
+ return Storage.getDimension();
+ }
- LLVM_ABI bool isValid() const { return Vocab.size() == NumCanonicalEntries; };
- LLVM_ABI unsigned getDimension() const;
/// Total number of entries (opcodes + canonicalized types + operand kinds +
/// predicates)
static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
@@ -245,10 +334,16 @@ class Vocabulary {
LLVM_ABI static StringRef getVocabKeyForOpcode(unsigned Opcode);
/// Function to get vocabulary key for a given TypeID
- LLVM_ABI static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
+ LLVM_ABI static StringRef getVocabKeyForTypeID(Type::TypeID TypeID) {
+ return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID));
+ }
/// Function to get vocabulary key for a given OperandKind
- LLVM_ABI static StringRef getVocabKeyForOperandKind(OperandKind Kind);
+ LLVM_ABI static StringRef getVocabKeyForOperandKind(OperandKind Kind) {
+ unsigned Index = static_cast<unsigned>(Kind);
+ assert(Index < MaxOperandKinds && "Invalid OperandKind");
+ return OperandKindNames[Index];
+ }
/// Function to classify an operand into OperandKind
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
@@ -256,40 +351,66 @@ class Vocabulary {
/// Function to get vocabulary key for a given predicate
LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
- /// Functions to return the slot index or position of a given Opcode, TypeID,
- /// or OperandKind in the vocabulary.
- LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
- LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
- LLVM_ABI static unsigned getSlotIndex(const Value &Op);
- LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
+ /// Functions to return flat index
+ LLVM_ABI static unsigned getIndex(unsigned Opcode) {
+ assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
+ return Opcode - 1; // Convert to zero-based index
+ }
+
+ LLVM_ABI static unsigned getIndex(Type::TypeID TypeID) {
+ assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
+ return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));
+ }
+
+ LLVM_ABI static unsigned getIndex(const Value &Op) {
+ unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
+ assert(Index < MaxOperandKinds && "Invalid OperandKind");
+ return OperandBaseOffset + Index;
+ }
+
+ LLVM_ABI static unsigned getIndex(CmpInst::Predicate P) {
+ return PredicateBaseOffset + getPredicateLocalIndex(P);
+ }
/// Accessors to get the embedding for a given entity.
- LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
- LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
- LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
- LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
+ LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const {
+ assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
+ return Storage[static_cast<unsigned>(Section::Opcodes)][Opcode - 1];
+ }
+
+ LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeID) const {
+ assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
+ unsigned LocalIndex = static_cast<unsigned>(getCanonicalTypeID(TypeID));
+ return Storage[static_cast<unsigned>(Section::CanonicalTypes)][LocalIndex];
+ }
+
+ LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const {
+ unsigned LocalIndex = static_cast<unsigned>(getOperandKind(&Arg));
+ assert(LocalIndex < MaxOperandKinds && "Invalid OperandKind");
+ return Storage[static_cast<unsigned>(Section::Operands)][LocalIndex];
+ }
+
+ LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const {
+ unsigned LocalIndex = getPredicateLocalIndex(P);
+ return Storage[static_cast<unsigned>(Section::Predicates)][LocalIndex];
+ }
/// Const Iterator type aliases
- using const_iterator = VocabVector::const_iterator;
+ using const_iterator = VocabStorage::const_iterator;
+
const_iterator begin() const {
assert(isValid() && "IR2Vec Vocabulary is invalid");
- return Vocab.begin();
+ return Storage.begin();
}
- const_iterator cbegin() const {
- assert(isValid() && "IR2Vec Vocabulary is invalid");
- return Vocab.cbegin();
- }
+ const_iterator cbegin() const { return begin(); }
const_iterator end() const {
assert(isValid() && "IR2Vec Vocabulary is invalid");
- return Vocab.end();
+ return Storage.end();
}
- const_iterator cend() const {
- assert(isValid() && "IR2Vec Vocabulary is invalid");
- return Vocab.cend();
- }
+ const_iterator cend() const { return end(); }
/// Returns the string key for a given index position in the vocabulary.
/// This is useful for debugging or printing the vocabulary. Do not use this
@@ -297,7 +418,7 @@ class Vocabulary {
LLVM_ABI static StringRef getStringKey(unsigned Pos);
/// Create a dummy vocabulary for testing purposes.
- LLVM_ABI static VocabVector createDummyVocabForTest(unsigned Dim = 1);
+ LLVM_ABI static VocabStorage createDummyVocabForTest(unsigned Dim = 1);
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
ModuleAnalysisManager::Invalidator &Inv) const;
@@ -306,12 +427,16 @@ class Vocabulary {
constexpr static unsigned NumCanonicalEntries =
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
- // Base offsets for slot layout to simplify index computation
+ // Base offsets for flat index computation
constexpr static unsigned OperandBaseOffset =
MaxOpcodes + MaxCanonicalTypeIDs;
constexpr static unsigned PredicateBaseOffset =
OperandBaseOffset + MaxOperandKinds;
+ /// Functions for predicate index calculations
+ static unsigned getPredicateLocalIndex(CmpInst::Predicate P);
+ static CmpInst::Predicate getPredicateFromLocalIndex(unsigned LocalIndex);
+
/// String mappings for CanonicalTypeID values
static constexpr StringLiteral CanonicalTypeNames[] = {
"FloatTy", "VoidTy", "LabelTy", "MetadataTy",
@@ -358,15 +483,26 @@ class Vocabulary {
/// Function to get vocabulary key for canonical type by enum
LLVM_ABI static StringRef
- getVocabKeyForCanonicalTypeID(CanonicalTypeID CType);
+ getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) {
+ unsigned Index = static_cast<unsigned>(CType);
+ assert(Index < MaxCanonicalTypeIDs && "Invalid CanonicalTypeID");
+ return CanonicalTypeNames[Index];
+ }
/// Function to convert TypeID to CanonicalTypeID
- LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
+ LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID) {
+ unsigned Index = static_cast<unsigned>(TypeID);
+ assert(Index < MaxTypeIDs && "Invalid TypeID");
+ return TypeIDMapping[Index];
+ }
/// Function to get the predicate enum value for a given index. Index is
/// relative to the predicates section of the vocabulary. E.g., Index 0
/// corresponds to the first predicate.
- LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index);
+ LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index) {
+ assert(Index < MaxPredicateKinds && "Invalid predicate index");
+ return getPredicateFromLocalIndex(Index);
+ }
};
/// Embedder provides the interface to generate embeddings (vector
@@ -459,22 +595,22 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder {
/// mapping between an entity of the IR (like opcode, type, argument, etc.) and
/// its corresponding embedding.
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
- using VocabVector = std::vector<ir2vec::Embedding>;
using VocabMap = std::map<std::string, ir2vec::Embedding>;
- VocabMap OpcVocab, TypeVocab, ArgVocab;
- VocabVector Vocab;
+ std::optional<ir2vec::VocabStorage> Vocab;
- Error readVocabulary();
+ Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab,
+ VocabMap &ArgVocab);
Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
VocabMap &TargetVocab, unsigned &Dim);
- void generateNumMappedVocab();
+ void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
+ VocabMap &ArgVocab);
void emitError(Error Err, LLVMContext &Ctx);
public:
LLVM_ABI static AnalysisKey Key;
IR2VecVocabAnalysis() = default;
- LLVM_ABI explicit IR2VecVocabAnalysis(const VocabVector &Vocab);
- LLVM_ABI explicit IR2VecVocabAnalysis(VocabVector &&Vocab);
+ LLVM_ABI explicit IR2VecVocabAnalysis(ir2vec::VocabStorage &&Vocab)
+ : Vocab(std::move(Vocab)) {}
using Result = ir2vec::Vocabulary;
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM);
};
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index f51f0898cb37e..271f004b0a787 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -15,6 +15,7 @@
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Module.h"
@@ -262,55 +263,75 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
}
// ==----------------------------------------------------------------------===//
-// Vocabulary
+// VocabStorage
//===----------------------------------------------------------------------===//
-unsigned Vocabulary::getDimension() const {
- assert(isValid() && "IR2Vec Vocabulary is invalid");
- return Vocab[0].size();
-}
-
-unsigned Vocabulary::getSlotIndex(unsigned Opcode) {
- assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
- return Opcode - 1; // Convert to zero-based index
-}
-
-unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
- assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
- return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));
-}
-
-unsigned Vocabulary::getSlotIndex(const Value &Op) {
- unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
- assert(Index < MaxOperandKinds && "Invalid OperandKind");
- return OperandBaseOffset + Index;
-}
-
-unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) {
- unsigned PU = static_cast<unsigned>(P);
- unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
- unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);
-
- unsigned PredIdx =
- (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
- return PredicateBaseOffset + PredIdx;
-}
-
-const Embedding &Vocabulary::operator[](unsigned Opcode) const {
- return Vocab[getSlotIndex(Opcode)];
+VocabStorage::VocabStorage(std::vector<std::vector<Embedding>> &&SectionData)
+ : Sections(std::move(SectionData)), TotalSize([&] {
+ assert(!Sections.empty() && "Vocabulary has no sections");
+ // Compute total size across all sections
+ size_t Size = 0;
+ for (const auto &Section : Sections) {
+ assert(!Section.empty() && "Vocabulary section is empty");
+ Size += Section.size();
+ }
+ return Size;
+ }()),
+ Dimension([&] {
+ // Get dimension from the first embedding in the first section - all
+ // embeddings must have the same dimension
+ assert(!Sections.empty() && "Vocabulary has no sections");
+ assert(!Sections[0].empty() && "First section of vocabulary is empty");
+ unsigned ExpectedDim = static_cast<unsigned>(Sections[0][0].size());
+
+ // Verify that all embeddings across all sections have the same
+ // dimension
+ auto allSameDim = [ExpectedDim](const std::vector<Embedding> &Section) {
+ return std::all_of(Section.begin(), Section.end(),
+ [ExpectedDim](const Embedding &Emb) {
+ return Emb.size() == ExpectedDim;
+ });
+ };
+ assert(std::all_of(Sections.begin(), Sections.end(), allSameDim) &&
+ "All embeddings must have the same dimension");
+
+ return ExpectedDim;
+ }()) {}
+
+const Embedding &VocabStorage::const_iterator::operator*() const {
+ assert(SectionId < Storage->Sections.size() && "Invalid section ID");
+ assert(LocalIndex < Storage->Sections[SectionId].size() &&
+ "Local index out of range");
+ return Storage->Sections[SectionId][LocalIndex];
+}
+
+VocabStorage::const_iterator &VocabStorage::const_iterator::operator++() {
+ ++LocalIndex;
+ // Check if we need to move to the next section
+ if (SectionId < Storage->getNumSections() &&
+ LocalIndex >= Storage->Sections[SectionId].size()) {
+ assert(LocalIndex == Storage->Sections[SectionId].size() &&
+ "Local index should be at the end of the current section");
+ LocalIndex = 0;
+ ++SectionId;
+ }
+ return *this;
}
-const Embedding &Vocabulary::operator[](Type::TypeID TypeID) const {
- return Vocab[getSlotIndex(TypeID)];
+bool VocabStorage::const_iterator::operator==(
+ const const_iterator &Other) const {
+ return Storage == Other.Storage && SectionId == Other.SectionId &&
+ LocalIndex == Other.LocalIndex;
}
-const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
- return Vocab[getSlotIndex(Arg)];
+bool VocabStorage::const_iterator::operator!=(
+ const const_iterator &Other) const {
+ return !(*this == Other);
}
-const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const {
- return Vocab[getSlotIndex(P)];
-}
+// ==----------------------------------------------------------------------===//
+// Vocabulary
+//===----------------------------------------------------------------------===//
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
@@ -323,29 +344,6 @@ StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
return "UnknownOpcode";
}
-StringRef Vocabulary::getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) {
- unsigned Index = static_cast<unsigned>(CType);
- assert(Index < MaxCanonicalTypeIDs && "Invalid CanonicalTypeID");
- return CanonicalTypeNames[Index];
-}
-
-Vocabulary::CanonicalTypeID
-Vocabulary::getCanonicalTypeID(Type::TypeID TypeID) {
- unsigned Index = static_cast<unsigned>(TypeID);
- assert(Index < MaxTypeIDs && "Invalid TypeID");
- return TypeIDMapping[Index];
-}
-
-StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
- return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID));
-}
-
-StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
- unsigned Index = static_cast<unsigned>(Kind);
- assert(Index < MaxOperandKinds && "Invalid OperandKind");
- return OperandKindNames[Index];
-}
-
// Helper function to classify an operand into OperandKind
Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
if (isa<Function>(Op))
@@ -357,14 +355,23 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
return OperandKind::VariableID;
}
-CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
- assert(Index < MaxPredicateKinds && "Invalid predicate index");
- unsigned PredEnumVal =
- (Index < NumFCmpPredicates)
- ? (static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + Index)
- : (static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) +
- (Index - NumFCmpPredicates));
- return static_cast<CmpInst::Predicate>(PredEnumVal);
+unsigned Vocabulary::getPredicateLocalIndex(CmpInst::Predicate P) {
+ if (P >= CmpInst::FIRST_FCMP_PREDICATE && P <= CmpInst::LAST_FCMP_PREDICATE)
+ return P - CmpInst::FIRST_FCMP_PREDICATE;
+ else
+ return P - CmpInst::FIRST_ICMP_PREDICATE +
+ (CmpInst::LAST_FCMP_PREDICATE - CmpInst::FIRST_FCMP_PREDICATE + 1);
+}
+
+CmpInst::Predicate Vocabulary::getPredicateFromLocalIndex(unsigned LocalIndex) {
+ unsigned fcmpRange =
+ CmpInst::LAST_FCMP_PREDICATE - CmpInst::FIRST_FCMP_PREDICATE + 1;
+ if (LocalIndex < fcmpRange)
+ return static_cast<CmpInst::Predicate>(CmpInst::FIRST_FCMP_PREDICATE +
+ LocalIndex);
+ else
+ return static_cast<CmpInst::Predicate>(CmpInst::FIRST_ICMP_PREDICATE +
+ LocalIndex - fcmpRange);
}
StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
@@ -401,17 +408,51 @@ bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA,
return !(PAC.preservedWhenStateless());
}
-Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
- VocabVector DummyVocab;
- DummyVocab.reserve(NumCanonicalEntries);
+VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) {
float DummyVal = 0.1f;
- // Create a dummy vocabulary with entries for all opcodes, types, operands
- // and predicates
- for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) {
- DummyVocab.push_back(Embedding(Dim, DummyVal));
+
+ // Create sections for opcodes, types, operands, and predicates
+ // Order must match Vocabulary::Section enum
+ std::vector<std::vector<Embedding>> Sections;
+ Sections.reserve(4);
+
+ // Opcodes section
+ std::vector<Embedding> OpcodeSec;
+ OpcodeSec.reserve(MaxOpcodes);
+ for (unsigned I = 0; I < MaxOpcodes; ++I) {
+ OpcodeSec.emplace_back(Dim, DummyVal);
DummyVal += 0.1f;
}
- return DummyVocab;
+ Sections.push_back(std::move(OpcodeSec));
+
+ // Types section
+ std::vector<Embedding> TypeSec;
+ TypeSec.reserve(MaxCanonicalTypeIDs);
+ for (unsigned I = 0; I < MaxCanonicalTypeIDs; ++I) {
+ TypeSec.emplace_back(Dim, DummyVal);
+ DummyVal += 0.1f;
+ }
+ Sections.push_back(std::move(TypeSec));
+
+ // Operands section
+ std::vector<Embedding> OperandSec;
+ OperandSec.reserve(MaxOperandKinds);
+ for (unsigned I = 0; I < MaxOperandKinds; ++I) {
+ OperandSec.emplace_back(Dim, DummyVal);
+ DummyVal += 0.1f;
+ }
+ Sections.push_back(std::move(OperandSec));
+
+ // Predicates section
+ std::vector<Embedding> PredicateSec;
+ PredicateSec.reserve(MaxPredicateKinds);
+ for (unsigned I = 0; I < MaxPredicateKinds; ++I) {
+ PredicateSec.emplace_back(Dim, DummyVal);
+ DummyVal += 0.1f;
+ }
+ Sections.push_back(std::move(PredicateSec));
+
+ return VocabStorage(std::move(Sections));
}
// ==----------------------------------------------------------------------===//
@@ -457,7 +498,9 @@ Error IR2VecVocabAnalysis::parseVocabSection(
// FIXME: Make this optional. We can avoid file reads
// by auto-generating a default vocabulary during the build time.
-Error IR2VecVocabAnalysis::readVocabulary() {
+Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,
+ VocabMap &TypeVocab,
+ VocabMap &ArgVocab) {
auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
if (!BufOrError)
return createFileError(VocabFile, BufOrError.getError());
@@ -488,7 +531,9 @@ Error IR2VecVocabAnalysis::readVocabulary() {
return Error::success();
}
-void IR2VecVocabAnalysis::generateNumMappedVocab() {
+void IR2VecVocabAnalysis::generateVocabStorage(VocabMap &OpcVocab,
+ VocabMap &TypeVocab,
+ VocabMap &ArgVocab) {
// Helper for handling missing entities in the vocabulary.
// Currently, we use a zero vector. In the future, we will throw an error to
@@ -506,7 +551,6 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Opcodes
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
Embedding(Dim));
- NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes);
for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
auto It = OpcVocab.find(VocabKey.str());
@@ -515,13 +559,10 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
else
handleMissingEntity(VocabKey.str());
}
- Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
- NumericOpcodeEmbeddings.end());
// Handle Types - only canonical types are present in vocabulary
std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
Embedding(Dim));
- NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs);
for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
static_cast<Vocabulary::CanonicalTypeID>(CTypeID));
@@ -531,13 +572,10 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
}
handleMissingEntity(VocabKey.str());
}
- Vocab.insert(Vocab.end(), NumericTypeEmbeddings.begin(),
- NumericTypeEmbeddings.end());
// Handle Arguments/Operands
std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
Embedding(Dim));
- NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds);
for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
@@ -548,14 +586,11 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
}
handleMissingEntity(VocabKey.str());
}
- Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
- NumericArgEmbeddings.end());
// Handle Predicates: part of Operands section. We look up predicate keys
// in ArgVocab.
std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
Embedding(Dim, 0));
- NumericPredEmbeddings.reserve(Vocabulary::MaxPredicateKinds);
for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
StringRef VocabKey =
Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
@@ -566,15 +601,22 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
}
handleMissingEntity(VocabKey.str());
}
- Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(),
- NumericPredEmbeddings.end());
-}
-IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)
- : Vocab(Vocab) {}
+ // Create section-based storage instead of flat vocabulary
+ // Order must match Vocabulary::Section enum
+ std::vector<std::vector<Embedding>> Sections(4);
+ Sections[static_cast<unsigned>(Vocabulary::Section::Opcodes)] =
+ std::move(NumericOpcodeEmbeddings); // Section::Opcodes
+ Sections[static_cast<unsigned>(Vocabulary::Section::CanonicalTypes)] =
+ std::move(NumericTypeEmbeddings); // Section::CanonicalTypes
+ Sections[static_cast<unsigned>(Vocabulary::Section::Operands)] =
+ std::move(NumericArgEmbeddings); // Section::Operands
+ Sections[static_cast<unsigned>(Vocabulary::Section::Predicates)] =
+ std::move(NumericPredEmbeddings); // Section::Predicates
-IR2VecVocabAnalysis::IR2VecVocabAnalysis(VocabVector &&Vocab)
- : Vocab(std::move(Vocab)) {}
+ // Create VocabStorage from organized sections
+ Vocab.emplace(std::move(Sections));
+}
void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
@@ -586,8 +628,8 @@ IR2VecVocabAnalysis::Result
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
auto Ctx = &M.getContext();
// If vocabulary is already populated by the constructor, use it.
- if (!Vocab.empty())
- return Vocabulary(std::move(Vocab));
+ if (Vocab.has_value())
+ return Vocabulary(std::move(Vocab.value()));
// Otherwise, try to read from the vocabulary file.
if (VocabFile.empty()) {
@@ -596,7 +638,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
"set it using --ir2vec-vocab-path");
return Vocabulary(); // Return invalid result
}
- if (auto Err = readVocabulary()) {
+
+ VocabMap OpcVocab, TypeVocab, ArgVocab;
+ if (auto Err = readVocabulary(OpcVocab, TypeVocab, ArgVocab)) {
emitError(std::move(Err), *Ctx);
return Vocabulary();
}
@@ -611,9 +655,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
scaleVocabSection(ArgVocab, ArgWeight);
// Generate the numeric lookup vocabulary
- generateNumMappedVocab();
+ generateVocabStorage(OpcVocab, TypeVocab, ArgVocab);
- return Vocabulary(std::move(Vocab));
+ return Vocabulary(std::move(Vocab.value()));
}
// ==----------------------------------------------------------------------===//
@@ -622,7 +666,7 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
PreservedAnalyses IR2VecPrinterPass::run(Module &M,
ModuleAnalysisManager &MAM) {
- auto Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
+ auto &Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
for (Function &F : M) {
@@ -664,7 +708,7 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
ModuleAnalysisManager &MAM) {
- auto IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
+ auto &IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid");
// Print each entry
diff --git a/llvm/lib/Analysis/InlineAdvisor.cpp b/llvm/lib/Analysis/InlineAdvisor.cpp
index 28b14c2562df1..0fa804f2959e8 100644
--- a/llvm/lib/Analysis/InlineAdvisor.cpp
+++ b/llvm/lib/Analysis/InlineAdvisor.cpp
@@ -217,7 +217,7 @@ AnalysisKey PluginInlineAdvisorAnalysis::Key;
bool InlineAdvisorAnalysis::initializeIR2VecVocabIfRequested(
Module &M, ModuleAnalysisManager &MAM) {
if (!IR2VecVocabFile.empty()) {
- auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
+ auto &IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
if (!IR2VecVocabResult.isValid()) {
M.getContext().emitError("Failed to load IR2Vec vocabulary");
return false;
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 1c656b8fcf4e7..434449c7c5117 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -162,8 +162,8 @@ class IR2VecTool {
for (const BasicBlock &BB : F) {
for (const auto &I : BB.instructionsWithoutDebug()) {
- unsigned Opcode = Vocabulary::getSlotIndex(I.getOpcode());
- unsigned TypeID = Vocabulary::getSlotIndex(I.getType()->getTypeID());
+ unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
+ unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
// Add "Next" relationship with previous instruction
if (HasPrevOpcode) {
@@ -184,7 +184,7 @@ class IR2VecTool {
// Add "Arg" relationships
unsigned ArgIndex = 0;
for (const Use &U : I.operands()) {
- unsigned OperandID = Vocabulary::getSlotIndex(*U.get());
+ unsigned OperandID = Vocabulary::getIndex(*U.get());
unsigned RelationID = ArgRelation + ArgIndex;
OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
index dc6059dcf6827..b6e8567ee514d 100644
--- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
@@ -43,8 +43,11 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
public:
FunctionPropertiesAnalysisTest() {
auto VocabVector = ir2vec::Vocabulary::createDummyVocabForTest(1);
- MAM.registerPass([&] { return IR2VecVocabAnalysis(VocabVector); });
- IR2VecVocab = ir2vec::Vocabulary(std::move(VocabVector));
+ MAM.registerPass([VocabVector = std::move(VocabVector)]() mutable {
+ return IR2VecVocabAnalysis(std::move(VocabVector));
+ });
+ IR2VecVocab =
+ new ir2vec::Vocabulary(ir2vec::Vocabulary::createDummyVocabForTest(1));
MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
FAM.registerPass([&] { return DominatorTreeAnalysis(); });
@@ -66,7 +69,7 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
std::unique_ptr<LoopInfo> LI;
FunctionAnalysisManager FAM;
ModuleAnalysisManager MAM;
- ir2vec::Vocabulary IR2VecVocab;
+ ir2vec::Vocabulary *IR2VecVocab;
void TearDown() override {
// Restore original IR2Vec weights
@@ -78,7 +81,7 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
FunctionPropertiesInfo buildFPI(Function &F) {
// FunctionPropertiesInfo assumes IR2VecVocabAnalysis has been run to
// use IR2Vec.
- auto VocabResult = MAM.getResult<IR2VecVocabAnalysis>(*F.getParent());
+ auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(*F.getParent());
(void)VocabResult;
return FunctionPropertiesInfo::getFunctionPropertiesInfo(F, FAM);
}
@@ -106,7 +109,7 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
}
std::unique_ptr<ir2vec::Embedder> createEmbedder(const Function &F) {
- auto Emb = ir2vec::Embedder::create(IR2VecKind::Symbolic, F, IR2VecVocab);
+ auto Emb = ir2vec::Embedder::create(IR2VecKind::Symbolic, F, *IR2VecVocab);
EXPECT_TRUE(static_cast<bool>(Emb));
return Emb;
}
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 9bc48e45eab5e..743628fffac76 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -295,7 +295,7 @@ TEST(IR2VecTest, ZeroDimensionEmbedding) {
// Fixture for IR2Vec tests requiring IR setup.
class IR2VecTestFixture : public ::testing::Test {
protected:
- Vocabulary V;
+ Vocabulary *V;
LLVMContext Ctx;
std::unique_ptr<Module> M;
Function *F = nullptr;
@@ -304,7 +304,7 @@ class IR2VecTestFixture : public ::testing::Test {
Instruction *RetInst = nullptr;
void SetUp() override {
- V = Vocabulary(Vocabulary::createDummyVocabForTest(2));
+ V = new Vocabulary(Vocabulary::createDummyVocabForTest(2));
// Setup IR
M = std::make_unique<Module>("TestM", Ctx);
@@ -322,7 +322,7 @@ class IR2VecTestFixture : public ::testing::Test {
};
TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
- auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &InstMap = Emb->getInstVecMap();
@@ -341,7 +341,7 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
}
TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
- auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &InstMap = Emb->getInstVecMap();
@@ -358,7 +358,7 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
}
TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
- auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBMap = Emb->getBBVecMap();
@@ -373,7 +373,7 @@ TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
}
TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
- auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBMap = Emb->getBBVecMap();
@@ -388,7 +388,7 @@ TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
}
TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
- auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBVec = Emb->getBBVector(*BB);
@@ -398,7 +398,7 @@ TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
}
TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
- auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBVec = Emb->getBBVector(*BB);
@@ -408,7 +408,7 @@ TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
}
TEST_F(IR2VecTestFixture, GetFunctionVector_Symbolic) {
- auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &FuncVec = Emb->getFunctionVector();
@@ -420,7 +420,7 @@ TEST_F(IR2VecTestFixture, GetFunctionVector_Symbolic) {
}
TEST_F(IR2VecTestFixture, GetFunctionVector_FlowAware) {
- auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &FuncVec = Emb->getFunctionVector();
@@ -464,7 +464,10 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands +
MaxPredicateKinds);
- auto ExpectedVocab = VocabVec;
+ // Collect embeddings for later comparison before moving VocabVec
+ std::vector<Embedding> ExpectedVocab;
+ for (const auto &Emb : VocabVec)
+ ExpectedVocab.push_back(Emb);
IR2VecVocabAnalysis VocabAnalysis(std::move(VocabVec));
LLVMContext TestCtx;
@@ -482,17 +485,17 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
}
TEST(IR2VecVocabularyTest, SlotIdxMapping) {
- // Test getSlotIndex for Opcodes
+ // Test getIndex for Opcodes
#define EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS) \
- EXPECT_EQ(Vocabulary::getSlotIndex(NUM), static_cast<unsigned>(NUM - 1));
+ EXPECT_EQ(Vocabulary::getIndex(NUM), static_cast<unsigned>(NUM - 1));
#define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS)
#include "llvm/IR/Instruction.def"
#undef HANDLE_INST
#undef EXPECT_OPCODE_SLOT
- // Test getSlotIndex for Types
+ // Test getIndex for Types
#define EXPECT_TYPE_SLOT(TypeIDTok, CanonEnum, CanonStr) \
- EXPECT_EQ(Vocabulary::getSlotIndex(Type::TypeIDTok), \
+ EXPECT_EQ(Vocabulary::getIndex(Type::TypeIDTok), \
MaxOpcodes + static_cast<unsigned>( \
Vocabulary::CanonicalTypeID::CanonEnum));
@@ -500,7 +503,7 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
#undef EXPECT_TYPE_SLOT
- // Test getSlotIndex for Value operands
+ // Test getIndex for Value operands
LLVMContext Ctx;
Module M("TestM", Ctx);
FunctionType *FTy =
@@ -510,27 +513,27 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
#define EXPECTED_VOCAB_OPERAND_SLOT(X) \
MaxOpcodes + MaxCanonicalTypeIDs + static_cast<unsigned>(X)
// Test Function operand
- EXPECT_EQ(Vocabulary::getSlotIndex(*F),
+ EXPECT_EQ(Vocabulary::getIndex(*F),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::FunctionID));
// Test Constant operand
Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
- EXPECT_EQ(Vocabulary::getSlotIndex(*C),
+ EXPECT_EQ(Vocabulary::getIndex(*C),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::ConstantID));
// Test Pointer operand
BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB);
- EXPECT_EQ(Vocabulary::getSlotIndex(*PtrVal),
+ EXPECT_EQ(Vocabulary::getIndex(*PtrVal),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::PointerID));
// Test Variable operand (function argument)
Argument *Arg = F->getArg(0);
- EXPECT_EQ(Vocabulary::getSlotIndex(*Arg),
+ EXPECT_EQ(Vocabulary::getIndex(*Arg),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::VariableID));
#undef EXPECTED_VOCAB_OPERAND_SLOT
- // Test getSlotIndex for predicates
+ // Test getIndex for predicates
#define EXPECTED_VOCAB_PREDICATE_SLOT(X) \
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + static_cast<unsigned>(X)
for (unsigned P = CmpInst::FIRST_FCMP_PREDICATE;
@@ -538,7 +541,7 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
CmpInst::Predicate Pred = static_cast<CmpInst::Predicate>(P);
unsigned ExpectedIdx =
EXPECTED_VOCAB_PREDICATE_SLOT((P - CmpInst::FIRST_FCMP_PREDICATE));
- EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx);
+ EXPECT_EQ(Vocabulary::getIndex(Pred), ExpectedIdx);
}
auto ICMP_Start = CmpInst::LAST_FCMP_PREDICATE + 1;
for (unsigned P = CmpInst::FIRST_ICMP_PREDICATE;
@@ -546,7 +549,7 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
CmpInst::Predicate Pred = static_cast<CmpInst::Predicate>(P);
unsigned ExpectedIdx = EXPECTED_VOCAB_PREDICATE_SLOT(
ICMP_Start + P - CmpInst::FIRST_ICMP_PREDICATE);
- EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx);
+ EXPECT_EQ(Vocabulary::getIndex(Pred), ExpectedIdx);
}
#undef EXPECTED_VOCAB_PREDICATE_SLOT
}
@@ -555,15 +558,14 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
#ifndef NDEBUG
TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) {
// Test invalid opcode IDs
- EXPECT_DEATH(Vocabulary::getSlotIndex(0u), "Invalid opcode");
- EXPECT_DEATH(Vocabulary::getSlotIndex(MaxOpcodes + 1), "Invalid opcode");
+ EXPECT_DEATH(Vocabulary::getIndex(0u), "Invalid opcode");
+ EXPECT_DEATH(Vocabulary::getIndex(MaxOpcodes + 1), "Invalid opcode");
// Test invalid type IDs
- EXPECT_DEATH(Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs)),
+ EXPECT_DEATH(Vocabulary::getIndex(static_cast<Type::TypeID>(MaxTypeIDs)),
+ "Invalid type ID");
+ EXPECT_DEATH(Vocabulary::getIndex(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
"Invalid type ID");
- EXPECT_DEATH(
- Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
- "Invalid type ID");
}
#endif // NDEBUG
#endif // GTEST_HAS_DEATH_TEST
@@ -573,7 +575,7 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) {
EXPECT_EQ(Vocabulary::getStringKey(12), "Add");
#define EXPECT_OPCODE(NUM, OPCODE, CLASS) \
- EXPECT_EQ(Vocabulary::getStringKey(Vocabulary::getSlotIndex(NUM)), \
+ EXPECT_EQ(Vocabulary::getStringKey(Vocabulary::getIndex(NUM)), \
Vocabulary::getVocabKeyForOpcode(NUM));
#define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE(NUM, OPCODE, CLASS)
#include "llvm/IR/Instruction.def"
@@ -672,10 +674,12 @@ TEST(IR2VecVocabularyTest, InvalidAccess) {
#endif // GTEST_HAS_DEATH_TEST
TEST(IR2VecVocabularyTest, TypeIDStringKeyMapping) {
+ Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest());
#define EXPECT_TYPE_TO_CANONICAL(TypeIDTok, CanonEnum, CanonStr) \
- EXPECT_EQ( \
- Vocabulary::getStringKey(Vocabulary::getSlotIndex(Type::TypeIDTok)), \
- CanonStr);
+ do { \
+ unsigned FlatIdx = V.getIndex(Type::TypeIDTok); \
+ EXPECT_EQ(Vocabulary::getStringKey(FlatIdx), CanonStr); \
+ } while (0);
IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_TYPE_TO_CANONICAL)
@@ -683,14 +687,20 @@ TEST(IR2VecVocabularyTest, TypeIDStringKeyMapping) {
}
TEST(IR2VecVocabularyTest, InvalidVocabularyConstruction) {
- std::vector<Embedding> InvalidVocab;
- InvalidVocab.push_back(Embedding(2, 1.0));
- InvalidVocab.push_back(Embedding(2, 2.0));
-
- Vocabulary V(std::move(InvalidVocab));
+ // Test 1: Create invalid VocabStorage with insufficient sections
+ std::vector<std::vector<Embedding>> InvalidSectionData;
+ // Only add one section with 2 embeddings, but the vocabulary needs 4 sections
+ std::vector<Embedding> Section1;
+ Section1.push_back(Embedding(2, 1.0));
+ Section1.push_back(Embedding(2, 2.0));
+ InvalidSectionData.push_back(std::move(Section1));
+
+ VocabStorage InvalidStorage(std::move(InvalidSectionData));
+ Vocabulary V(std::move(InvalidStorage));
EXPECT_FALSE(V.isValid());
{
+ // Test 2: Default-constructed vocabulary should be invalid
Vocabulary InvalidResult;
EXPECT_FALSE(InvalidResult.isValid());
#if GTEST_HAS_DEATH_TEST
@@ -701,4 +711,265 @@ TEST(IR2VecVocabularyTest, InvalidVocabularyConstruction) {
}
}
+TEST(VocabStorageTest, DefaultConstructor) {
+ VocabStorage storage;
+
+ EXPECT_EQ(storage.size(), 0u);
+ EXPECT_EQ(storage.getNumSections(), 0u);
+ EXPECT_EQ(storage.getDimension(), 0u);
+ EXPECT_FALSE(storage.isValid());
+
+ // Test iterators on empty storage
+ EXPECT_EQ(storage.begin(), storage.end());
+}
+
+TEST(VocabStorageTest, BasicConstruction) {
+ // Create test data with 3 sections
+ std::vector<std::vector<Embedding>> sectionData;
+
+ // Section 0: 2 embeddings of dimension 3
+ std::vector<Embedding> section0;
+ section0.emplace_back(std::vector<double>{1.0, 2.0, 3.0});
+ section0.emplace_back(std::vector<double>{4.0, 5.0, 6.0});
+ sectionData.push_back(std::move(section0));
+
+ // Section 1: 1 embedding of dimension 3
+ std::vector<Embedding> section1;
+ section1.emplace_back(std::vector<double>{7.0, 8.0, 9.0});
+ sectionData.push_back(std::move(section1));
+
+ // Section 2: 3 embeddings of dimension 3
+ std::vector<Embedding> section2;
+ section2.emplace_back(std::vector<double>{10.0, 11.0, 12.0});
+ section2.emplace_back(std::vector<double>{13.0, 14.0, 15.0});
+ section2.emplace_back(std::vector<double>{16.0, 17.0, 18.0});
+ sectionData.push_back(std::move(section2));
+
+ VocabStorage storage(std::move(sectionData));
+
+ EXPECT_EQ(storage.size(), 6u); // Total: 2 + 1 + 3 = 6
+ EXPECT_EQ(storage.getNumSections(), 3u);
+ EXPECT_EQ(storage.getDimension(), 3u);
+ EXPECT_TRUE(storage.isValid());
+}
+
+TEST(VocabStorageTest, SectionAccess) {
+ // Create test data
+ std::vector<std::vector<Embedding>> sectionData;
+
+ std::vector<Embedding> section0;
+ section0.emplace_back(std::vector<double>{1.0, 2.0});
+ section0.emplace_back(std::vector<double>{3.0, 4.0});
+ sectionData.push_back(std::move(section0));
+
+ std::vector<Embedding> section1;
+ section1.emplace_back(std::vector<double>{5.0, 6.0});
+ sectionData.push_back(std::move(section1));
+
+ VocabStorage storage(std::move(sectionData));
+
+ // Test section access
+ EXPECT_EQ(storage[0].size(), 2u);
+ EXPECT_EQ(storage[1].size(), 1u);
+
+ // Test embedding values
+ EXPECT_THAT(storage[0][0].getData(), ElementsAre(1.0, 2.0));
+ EXPECT_THAT(storage[0][1].getData(), ElementsAre(3.0, 4.0));
+ EXPECT_THAT(storage[1][0].getData(), ElementsAre(5.0, 6.0));
+}
+
+#if GTEST_HAS_DEATH_TEST
+#ifndef NDEBUG
+TEST(VocabStorageTest, InvalidSectionAccess) {
+ std::vector<std::vector<Embedding>> sectionData;
+ std::vector<Embedding> section0;
+ section0.emplace_back(std::vector<double>{1.0, 2.0});
+ sectionData.push_back(std::move(section0));
+
+ VocabStorage storage(std::move(sectionData));
+
+ EXPECT_DEATH(storage[1], "Invalid section ID");
+ EXPECT_DEATH(storage[10], "Invalid section ID");
+}
+
+TEST(VocabStorageTest, EmptySection) {
+ std::vector<std::vector<Embedding>> sectionData;
+ std::vector<Embedding> emptySection; // Empty section
+ sectionData.push_back(std::move(emptySection));
+
+ std::vector<Embedding> validSection;
+ validSection.emplace_back(std::vector<double>{1.0});
+ sectionData.push_back(std::move(validSection));
+
+ EXPECT_DEATH(VocabStorage(std::move(sectionData)),
+ "Vocabulary section is empty");
+}
+
+TEST(VocabStorageTest, EmptyMiddleSection) {
+ std::vector<std::vector<Embedding>> sectionData;
+
+ // Valid first section
+ std::vector<Embedding> validSection1;
+ validSection1.emplace_back(std::vector<double>{1.0});
+ sectionData.push_back(std::move(validSection1));
+
+ // Empty middle section
+ std::vector<Embedding> emptySection;
+ sectionData.push_back(std::move(emptySection));
+
+ // Valid last section
+ std::vector<Embedding> validSection2;
+ validSection2.emplace_back(std::vector<double>{2.0});
+ sectionData.push_back(std::move(validSection2));
+
+ EXPECT_DEATH(VocabStorage(std::move(sectionData)),
+ "Vocabulary section is empty");
+}
+
+TEST(VocabStorageTest, NoSections) {
+ std::vector<std::vector<Embedding>> sectionData; // No sections
+
+ EXPECT_DEATH(VocabStorage(std::move(sectionData)),
+ "Vocabulary has no sections");
+}
+
+TEST(VocabStorageTest, MismatchedDimensionsAcrossSections) {
+ std::vector<std::vector<Embedding>> sectionData;
+
+ // Section 0: embeddings with dimension 2
+ std::vector<Embedding> section0;
+ section0.emplace_back(std::vector<double>{1.0, 2.0});
+ section0.emplace_back(std::vector<double>{3.0, 4.0});
+ sectionData.push_back(std::move(section0));
+
+ // Section 1: embedding with dimension 3 (mismatch!)
+ std::vector<Embedding> section1;
+ section1.emplace_back(std::vector<double>{5.0, 6.0, 7.0});
+ sectionData.push_back(std::move(section1));
+
+ EXPECT_DEATH(VocabStorage(std::move(sectionData)),
+ "All embeddings must have the same dimension");
+}
+
+TEST(VocabStorageTest, MismatchedDimensionsWithinSection) {
+ std::vector<std::vector<Embedding>> sectionData;
+
+ // Section 0: first embedding with dimension 2, second with dimension 3
+ std::vector<Embedding> section0;
+ section0.emplace_back(std::vector<double>{1.0, 2.0});
+ section0.emplace_back(std::vector<double>{3.0, 4.0, 5.0}); // Mismatch!
+ sectionData.push_back(std::move(section0));
+
+ EXPECT_DEATH(VocabStorage(std::move(sectionData)),
+ "All embeddings must have the same dimension");
+}
+#endif // NDEBUG
+#endif // GTEST_HAS_DEATH_TEST
+
+TEST(VocabStorageTest, IteratorBasics) {
+ std::vector<std::vector<Embedding>> sectionData;
+
+ std::vector<Embedding> section0;
+ section0.emplace_back(std::vector<double>{1.0, 2.0});
+ section0.emplace_back(std::vector<double>{3.0, 4.0});
+ sectionData.push_back(std::move(section0));
+
+ std::vector<Embedding> section1;
+ section1.emplace_back(std::vector<double>{5.0, 6.0});
+ sectionData.push_back(std::move(section1));
+
+ VocabStorage storage(std::move(sectionData));
+
+ // Test iterator basics
+ auto it = storage.begin();
+ auto end = storage.end();
+
+ EXPECT_NE(it, end);
+
+ // Check first embedding
+ EXPECT_THAT((*it).getData(), ElementsAre(1.0, 2.0));
+
+ // Advance to second embedding
+ ++it;
+ EXPECT_NE(it, end);
+ EXPECT_THAT((*it).getData(), ElementsAre(3.0, 4.0));
+
+ // Advance to third embedding (in section 1)
+ ++it;
+ EXPECT_NE(it, end);
+ EXPECT_THAT((*it).getData(), ElementsAre(5.0, 6.0));
+
+ // Advance past the end
+ ++it;
+ EXPECT_EQ(it, end);
+}
+
+TEST(VocabStorageTest, IteratorTraversal) {
+ std::vector<std::vector<Embedding>> sectionData;
+
+ // Section 0: 2 embeddings
+ std::vector<Embedding> section0;
+ section0.emplace_back(std::vector<double>{10.0});
+ section0.emplace_back(std::vector<double>{20.0});
+ sectionData.push_back(std::move(section0));
+
+ // Section 1: 1 embedding
+ std::vector<Embedding> section1;
+ section1.emplace_back(std::vector<double>{25.0});
+ sectionData.push_back(std::move(section1));
+
+ // Section 2: 3 embeddings
+ std::vector<Embedding> section2;
+ section2.emplace_back(std::vector<double>{30.0});
+ section2.emplace_back(std::vector<double>{40.0});
+ section2.emplace_back(std::vector<double>{50.0});
+ sectionData.push_back(std::move(section2));
+
+ VocabStorage storage(std::move(sectionData));
+
+ // Collect all values using iterator
+ std::vector<double> values;
+ for (const auto &emb : storage) {
+ EXPECT_EQ(emb.size(), 1u);
+ values.push_back(emb[0]);
+ }
+
+ // Should get all embeddings from all sections
+ EXPECT_THAT(values, ElementsAre(10.0, 20.0, 25.0, 30.0, 40.0, 50.0));
+}
+
+TEST(VocabStorageTest, IteratorComparison) {
+ std::vector<std::vector<Embedding>> sectionData;
+ std::vector<Embedding> section0;
+ section0.emplace_back(std::vector<double>{1.0});
+ section0.emplace_back(std::vector<double>{2.0});
+ sectionData.push_back(std::move(section0));
+
+ VocabStorage storage(std::move(sectionData));
+
+ auto it1 = storage.begin();
+ auto it2 = storage.begin();
+ auto end = storage.end();
+
+ // Test equality
+ EXPECT_EQ(it1, it2);
+ EXPECT_NE(it1, end);
+
+ // Advance one iterator
+ ++it1;
+ EXPECT_NE(it1, it2);
+ EXPECT_NE(it1, end);
+
+ // Advance second iterator to match
+ ++it2;
+ EXPECT_EQ(it1, it2);
+
+ // Advance both to end
+ ++it1;
+ ++it2;
+ EXPECT_EQ(it1, end);
+ EXPECT_EQ(it2, end);
+ EXPECT_EQ(it1, it2);
+}
+
} // end anonymous namespace
More information about the llvm-commits
mailing list