[llvm] 45c5498 - [IR2Vec] Refactor vocabulary to use canonical type IDs (#155323)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 29 14:57:00 PDT 2025
Author: S. VenkataKeerthy
Date: 2025-08-29T14:56:56-07:00
New Revision: 45c549857383eade0c47db75951fb9441260653a
URL: https://github.com/llvm/llvm-project/commit/45c549857383eade0c47db75951fb9441260653a
DIFF: https://github.com/llvm/llvm-project/commit/45c549857383eade0c47db75951fb9441260653a.diff
LOG: [IR2Vec] Refactor vocabulary to use canonical type IDs (#155323)
Refactor IR2Vec vocabulary to use canonical type IDs, improving the embedding representation for LLVM IR types.
The previous implementation used raw Type::TypeID values directly in the vocabulary, which led to redundant entries (e.g., all float variants mapped to "FloatTy" but had separate slots). This change improves the vocabulary by:
1. Making the type representation more consistent by properly canonicalizing types
2. Reducing vocabulary size by eliminating redundant entries
3. Improving the embedding quality by ensuring similar types share the same representation
(Tracking issue - #141817)
Added:
Modified:
llvm/include/llvm/Analysis/IR2Vec.h
llvm/lib/Analysis/IR2Vec.cpp
llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
llvm/test/tools/llvm-ir2vec/entities.ll
llvm/test/tools/llvm-ir2vec/triplets.ll
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
llvm/unittests/Analysis/IR2VecTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index caa816e2fd76d..c42ca779e097c 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -36,6 +36,7 @@
#include "llvm/Support/Compiler.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/JSON.h"
+#include <array>
#include <map>
namespace llvm {
@@ -137,13 +138,51 @@ using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
/// Class for storing and accessing the IR2Vec vocabulary.
-/// Encapsulates all vocabulary-related constants, logic, and access methods.
+/// The Vocabulary class manages seed embeddings for LLVM IR entities. The
+/// seed embeddings are the initial learned representations of the entities
+/// of LLVM IR. The IR2Vec representation for a given IR is derived from these
+/// seed embeddings.
+///
+/// The vocabulary contains the seed embeddings for three types of entities:
+/// instruction opcodes, types, and operands. Types are grouped/canonicalized
+/// for better learning (e.g., all float variants map to FloatTy). The
+/// vocabulary abstracts away the canonicalization effectively; the exposed APIs
+/// handle all the known LLVM IR opcodes, types and operands.
+///
+/// This class helps populate the seed embeddings in an internal vector-based
+/// ADT. It provides logic to map every IR entity to a specific slot index or
+/// position in this vector, enabling O(1) embedding lookup while avoiding
+/// unnecessary computations involving string based lookups while generating the
+/// embeddings.
class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;
using VocabVector = std::vector<ir2vec::Embedding>;
VocabVector Vocab;
bool Valid = false;
+public:
+ // Slot layout:
+ // [0 .. MaxOpcodes-1] => Instruction opcodes
+ // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
+ // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
+
+ /// Canonical type IDs supported by IR2Vec Vocabulary
+ enum class CanonicalTypeID : unsigned {
+ FloatTy,
+ VoidTy,
+ LabelTy,
+ MetadataTy,
+ VectorTy,
+ TokenTy,
+ IntegerTy,
+ FunctionTy,
+ PointerTy,
+ StructTy,
+ ArrayTy,
+ UnknownTy,
+ MaxCanonicalType
+ };
+
/// Operand kinds supported by IR2Vec Vocabulary
enum class OperandKind : unsigned {
FunctionID,
@@ -152,20 +191,15 @@ class Vocabulary {
VariableID,
MaxOperandKind
};
- /// String mappings for OperandKind values
- static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
- "Constant", "Variable"};
- static_assert(std::size(OperandKindNames) ==
- static_cast<unsigned>(OperandKind::MaxOperandKind),
- "OperandKindNames array size must match MaxOperandKind");
-public:
/// Vocabulary layout constants
#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
#include "llvm/IR/Instruction.def"
#undef LAST_OTHER_INST
static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;
+ static constexpr unsigned MaxCanonicalTypeIDs =
+ static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
static constexpr unsigned MaxOperandKinds =
static_cast<unsigned>(OperandKind::MaxOperandKind);
@@ -174,33 +208,31 @@ class Vocabulary {
LLVM_ABI bool isValid() const;
LLVM_ABI unsigned getDimension() const;
- LLVM_ABI size_t size() const;
+ /// Total number of entries (opcodes + canonicalized types + operand kinds)
+ static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
- static size_t expectedSize() {
- return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
- }
-
- /// Helper function to get vocabulary key for a given Opcode
+ /// Function to get vocabulary key for a given Opcode
LLVM_ABI static StringRef getVocabKeyForOpcode(unsigned Opcode);
- /// Helper function to get vocabulary key for a given TypeID
+ /// Function to get vocabulary key for a given TypeID
LLVM_ABI static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
- /// Helper function to get vocabulary key for a given OperandKind
+ /// Function to get vocabulary key for a given OperandKind
LLVM_ABI static StringRef getVocabKeyForOperandKind(OperandKind Kind);
- /// Helper function to classify an operand into OperandKind
+ /// Function to classify an operand into OperandKind
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
- /// Helpers to return the IDs of a given Opcode, TypeID, or OperandKind
- LLVM_ABI static unsigned getNumericID(unsigned Opcode);
- LLVM_ABI static unsigned getNumericID(Type::TypeID TypeID);
- LLVM_ABI static unsigned getNumericID(const Value *Op);
+ /// Functions to return the slot index or position of a given Opcode, TypeID,
+ /// or OperandKind in the vocabulary.
+ LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
+ LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
+ LLVM_ABI static unsigned getSlotIndex(const Value *Op);
/// Accessors to get the embedding for a given entity.
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
- LLVM_ABI const ir2vec::Embedding &operator[](const Value *Arg) const;
+ LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
/// Const Iterator type aliases
using const_iterator = VocabVector::const_iterator;
@@ -234,6 +266,61 @@ class Vocabulary {
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
ModuleAnalysisManager::Invalidator &Inv) const;
+
+private:
+ constexpr static unsigned NumCanonicalEntries =
+ MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
+
+ /// String mappings for CanonicalTypeID values
+ static constexpr StringLiteral CanonicalTypeNames[] = {
+ "FloatTy", "VoidTy", "LabelTy", "MetadataTy",
+ "VectorTy", "TokenTy", "IntegerTy", "FunctionTy",
+ "PointerTy", "StructTy", "ArrayTy", "UnknownTy"};
+ static_assert(std::size(CanonicalTypeNames) ==
+ static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType),
+ "CanonicalTypeNames array size must match MaxCanonicalType");
+
+ /// String mappings for OperandKind values
+ static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
+ "Constant", "Variable"};
+ static_assert(std::size(OperandKindNames) ==
+ static_cast<unsigned>(OperandKind::MaxOperandKind),
+ "OperandKindNames array size must match MaxOperandKind");
+
+ /// Every known TypeID defined in llvm/IR/Type.h is expected to have a
+ /// corresponding mapping here in the same order as enum Type::TypeID.
+ static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{
+ CanonicalTypeID::FloatTy, // HalfTyID = 0
+ CanonicalTypeID::FloatTy, // BFloatTyID
+ CanonicalTypeID::FloatTy, // FloatTyID
+ CanonicalTypeID::FloatTy, // DoubleTyID
+ CanonicalTypeID::FloatTy, // X86_FP80TyID
+ CanonicalTypeID::FloatTy, // FP128TyID
+ CanonicalTypeID::FloatTy, // PPC_FP128TyID
+ CanonicalTypeID::VoidTy, // VoidTyID
+ CanonicalTypeID::LabelTy, // LabelTyID
+ CanonicalTypeID::MetadataTy, // MetadataTyID
+ CanonicalTypeID::VectorTy, // X86_AMXTyID
+ CanonicalTypeID::TokenTy, // TokenTyID
+ CanonicalTypeID::IntegerTy, // IntegerTyID
+ CanonicalTypeID::FunctionTy, // FunctionTyID
+ CanonicalTypeID::PointerTy, // PointerTyID
+ CanonicalTypeID::StructTy, // StructTyID
+ CanonicalTypeID::ArrayTy, // ArrayTyID
+ CanonicalTypeID::VectorTy, // FixedVectorTyID
+ CanonicalTypeID::VectorTy, // ScalableVectorTyID
+ CanonicalTypeID::PointerTy, // TypedPointerTyID
+ CanonicalTypeID::UnknownTy // TargetExtTyID
+ }};
+ static_assert(TypeIDMapping.size() == MaxTypeIDs,
+ "TypeIDMapping must cover all Type::TypeID values");
+
+ /// Function to get vocabulary key for canonical type by enum
+ LLVM_ABI static StringRef
+ getVocabKeyForCanonicalTypeID(CanonicalTypeID CType);
+
+ /// Function to convert TypeID to CanonicalTypeID
+ LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
};
/// Embedder provides the interface to generate embeddings (vector
@@ -262,11 +349,11 @@ class Embedder {
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
- /// Helper function to compute embeddings. It generates embeddings for all
+ /// Function to compute embeddings. It generates embeddings for all
/// the instructions and basic blocks in the function F.
void computeEmbeddings() const;
- /// Helper function to compute the embedding for a given basic block.
+ /// Function to compute the embedding for a given basic block.
/// Specific to the kind of embeddings being computed.
virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index eb54f90a75488..886d22e424ea2 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -32,7 +32,7 @@ using namespace ir2vec;
#define DEBUG_TYPE "ir2vec"
STATISTIC(VocabMissCounter,
- "Number of lookups to entites not present in the vocabulary");
+ "Number of lookups to entities not present in the vocabulary");
namespace llvm {
namespace ir2vec {
@@ -213,7 +213,7 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
for (const auto &I : BB.instructionsWithoutDebug()) {
Embedding ArgEmb(Dimension, 0);
for (const auto &Op : I.operands())
- ArgEmb += Vocab[Op];
+ ArgEmb += Vocab[*Op];
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
InstVecMap[&I] = InstVector;
@@ -242,8 +242,8 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
// If the operand is not defined by an instruction, we use the vocabulary
else {
LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
- << *Op << "=" << Vocab[Op][0] << "\n");
- ArgEmb += Vocab[Op];
+ << *Op << "=" << Vocab[*Op][0] << "\n");
+ ArgEmb += Vocab[*Op];
}
}
// Create the instruction vector by combining opcode, type, and arguments
@@ -264,12 +264,7 @@ Vocabulary::Vocabulary(VocabVector &&Vocab)
: Vocab(std::move(Vocab)), Valid(true) {}
bool Vocabulary::isValid() const {
- return Vocab.size() == Vocabulary::expectedSize() && Valid;
-}
-
-size_t Vocabulary::size() const {
- assert(Valid && "IR2Vec Vocabulary is invalid");
- return Vocab.size();
+ return Vocab.size() == NumCanonicalEntries && Valid;
}
unsigned Vocabulary::getDimension() const {
@@ -277,19 +272,32 @@ unsigned Vocabulary::getDimension() const {
return Vocab[0].size();
}
-const Embedding &Vocabulary::operator[](unsigned Opcode) const {
+unsigned Vocabulary::getSlotIndex(unsigned Opcode) {
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
- return Vocab[Opcode - 1];
+ return Opcode - 1; // Convert to zero-based index
+}
+
+unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
+ assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
+ return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));
+}
+
+unsigned Vocabulary::getSlotIndex(const Value *Op) {
+ unsigned Index = static_cast<unsigned>(getOperandKind(Op));
+ assert(Index < MaxOperandKinds && "Invalid OperandKind");
+ return MaxOpcodes + MaxCanonicalTypeIDs + Index;
+}
+
+const Embedding &Vocabulary::operator[](unsigned Opcode) const {
+ return Vocab[getSlotIndex(Opcode)];
}
-const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const {
- assert(static_cast<unsigned>(TypeId) < MaxTypeIDs && "Invalid type ID");
- return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)];
+const Embedding &Vocabulary::operator[](Type::TypeID TypeID) const {
+ return Vocab[getSlotIndex(TypeID)];
}
-const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
- OperandKind ArgKind = getOperandKind(Arg);
- return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
+const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
+ return Vocab[getSlotIndex(&Arg)];
}
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
@@ -303,43 +311,21 @@ StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
return "UnknownOpcode";
}
+StringRef Vocabulary::getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) {
+ unsigned Index = static_cast<unsigned>(CType);
+ assert(Index < MaxCanonicalTypeIDs && "Invalid CanonicalTypeID");
+ return CanonicalTypeNames[Index];
+}
+
+Vocabulary::CanonicalTypeID
+Vocabulary::getCanonicalTypeID(Type::TypeID TypeID) {
+ unsigned Index = static_cast<unsigned>(TypeID);
+ assert(Index < MaxTypeIDs && "Invalid TypeID");
+ return TypeIDMapping[Index];
+}
+
StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
- switch (TypeID) {
- case Type::VoidTyID:
- return "VoidTy";
- case Type::HalfTyID:
- case Type::BFloatTyID:
- case Type::FloatTyID:
- case Type::DoubleTyID:
- case Type::X86_FP80TyID:
- case Type::FP128TyID:
- case Type::PPC_FP128TyID:
- return "FloatTy";
- case Type::IntegerTyID:
- return "IntegerTy";
- case Type::FunctionTyID:
- return "FunctionTy";
- case Type::StructTyID:
- return "StructTy";
- case Type::ArrayTyID:
- return "ArrayTy";
- case Type::PointerTyID:
- case Type::TypedPointerTyID:
- return "PointerTy";
- case Type::FixedVectorTyID:
- case Type::ScalableVectorTyID:
- return "VectorTy";
- case Type::LabelTyID:
- return "LabelTy";
- case Type::TokenTyID:
- return "TokenTy";
- case Type::MetadataTyID:
- return "MetadataTy";
- case Type::X86_AMXTyID:
- case Type::TargetExtTyID:
- return "UnknownTy";
- }
- return "UnknownTy";
+ return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID));
}
StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
@@ -348,20 +334,6 @@ StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
return OperandKindNames[Index];
}
-Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
- VocabVector DummyVocab;
- float DummyVal = 0.1f;
- // Create a dummy vocabulary with entries for all opcodes, types, and
- // operand
- for ([[maybe_unused]] unsigned _ :
- seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxTypeIDs +
- Vocabulary::MaxOperandKinds)) {
- DummyVocab.push_back(Embedding(Dim, DummyVal));
- DummyVal += 0.1f;
- }
- return DummyVocab;
-}
-
// Helper function to classify an operand into OperandKind
Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
if (isa<Function>(Op))
@@ -373,34 +345,18 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
return OperandKind::VariableID;
}
-unsigned Vocabulary::getNumericID(unsigned Opcode) {
- assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
- return Opcode - 1; // Convert to zero-based index
-}
-
-unsigned Vocabulary::getNumericID(Type::TypeID TypeID) {
- assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
- return MaxOpcodes + static_cast<unsigned>(TypeID);
-}
-
-unsigned Vocabulary::getNumericID(const Value *Op) {
- unsigned Index = static_cast<unsigned>(getOperandKind(Op));
- assert(Index < MaxOperandKinds && "Invalid OperandKind");
- return MaxOpcodes + MaxTypeIDs + Index;
-}
-
StringRef Vocabulary::getStringKey(unsigned Pos) {
- assert(Pos < Vocabulary::expectedSize() &&
- "Position out of bounds in vocabulary");
+ assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
// Opcode
if (Pos < MaxOpcodes)
return getVocabKeyForOpcode(Pos + 1);
// Type
- if (Pos < MaxOpcodes + MaxTypeIDs)
- return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes));
+ if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
+ return getVocabKeyForCanonicalTypeID(
+ static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
// Operand
return getVocabKeyForOperandKind(
- static_cast<OperandKind>(Pos - MaxOpcodes - MaxTypeIDs));
+ static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
}
// For now, assume vocabulary is stable unless explicitly invalidated.
@@ -410,6 +366,21 @@ bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA,
return !(PAC.preservedWhenStateless());
}
+Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
+ VocabVector DummyVocab;
+ DummyVocab.reserve(NumCanonicalEntries);
+ float DummyVal = 0.1f;
+ // Create a dummy vocabulary with entries for all opcodes, types, and
+ // operands
+ for ([[maybe_unused]] unsigned _ :
+ seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
+ Vocabulary::MaxOperandKinds)) {
+ DummyVocab.push_back(Embedding(Dim, DummyVal));
+ DummyVal += 0.1f;
+ }
+ return DummyVocab;
+}
+
// ==----------------------------------------------------------------------===//
// IR2VecVocabAnalysis
//===----------------------------------------------------------------------===//
@@ -502,6 +473,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Opcodes
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
Embedding(Dim, 0));
+ NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes);
for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
auto It = OpcVocab.find(VocabKey.str());
@@ -513,14 +485,15 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
NumericOpcodeEmbeddings.end());
- // Handle Types
- std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,
+ // Handle Types - only canonical types are present in vocabulary
+ std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
Embedding(Dim, 0));
- for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {
- StringRef VocabKey =
- Vocabulary::getVocabKeyForTypeID(static_cast<Type::TypeID>(TypeID));
+ NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs);
+ for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
+ StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
+ static_cast<Vocabulary::CanonicalTypeID>(CTypeID));
if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {
- NumericTypeEmbeddings[TypeID] = It->second;
+ NumericTypeEmbeddings[CTypeID] = It->second;
continue;
}
handleMissingEntity(VocabKey.str());
@@ -531,6 +504,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Arguments/Operands
std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
Embedding(Dim, 0));
+ NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds);
for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
index 1b9b3c2acd8a5..df7769c9c6a65 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
@@ -67,25 +67,16 @@ Key: InsertValue: [ 129.00 130.00 ]
Key: LandingPad: [ 131.00 132.00 ]
Key: Freeze: [ 133.00 134.00 ]
Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
Key: VoidTy: [ 1.50 2.00 ]
Key: LabelTy: [ 2.50 3.00 ]
Key: MetadataTy: [ 3.50 4.00 ]
-Key: UnknownTy: [ 4.50 5.00 ]
+Key: VectorTy: [ 11.50 12.00 ]
Key: TokenTy: [ 5.50 6.00 ]
Key: IntegerTy: [ 6.50 7.00 ]
Key: FunctionTy: [ 7.50 8.00 ]
Key: PointerTy: [ 8.50 9.00 ]
Key: StructTy: [ 9.50 10.00 ]
Key: ArrayTy: [ 10.50 11.00 ]
-Key: VectorTy: [ 11.50 12.00 ]
-Key: VectorTy: [ 11.50 12.00 ]
-Key: PointerTy: [ 8.50 9.00 ]
Key: UnknownTy: [ 4.50 5.00 ]
Key: Function: [ 0.20 0.40 ]
Key: Pointer: [ 0.60 0.80 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
index 9673e7f23fa5c..f3ce809fd2fd2 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
@@ -67,25 +67,16 @@ Key: InsertValue: [ 64.50 65.00 ]
Key: LandingPad: [ 65.50 66.00 ]
Key: Freeze: [ 66.50 67.00 ]
Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
Key: VoidTy: [ 1.50 2.00 ]
Key: LabelTy: [ 2.50 3.00 ]
Key: MetadataTy: [ 3.50 4.00 ]
-Key: UnknownTy: [ 4.50 5.00 ]
+Key: VectorTy: [ 11.50 12.00 ]
Key: TokenTy: [ 5.50 6.00 ]
Key: IntegerTy: [ 6.50 7.00 ]
Key: FunctionTy: [ 7.50 8.00 ]
Key: PointerTy: [ 8.50 9.00 ]
Key: StructTy: [ 9.50 10.00 ]
Key: ArrayTy: [ 10.50 11.00 ]
-Key: VectorTy: [ 11.50 12.00 ]
-Key: VectorTy: [ 11.50 12.00 ]
-Key: PointerTy: [ 8.50 9.00 ]
Key: UnknownTy: [ 4.50 5.00 ]
Key: Function: [ 0.50 1.00 ]
Key: Pointer: [ 1.50 2.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
index 1f575d29092dd..72b25b9bd3d9c 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
@@ -67,25 +67,16 @@ Key: InsertValue: [ 12.90 13.00 ]
Key: LandingPad: [ 13.10 13.20 ]
Key: Freeze: [ 13.30 13.40 ]
Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
Key: VoidTy: [ 0.00 0.00 ]
Key: LabelTy: [ 0.00 0.00 ]
Key: MetadataTy: [ 0.00 0.00 ]
-Key: UnknownTy: [ 0.00 0.00 ]
+Key: VectorTy: [ 0.00 0.00 ]
Key: TokenTy: [ 0.00 0.00 ]
Key: IntegerTy: [ 0.00 0.00 ]
Key: FunctionTy: [ 0.00 0.00 ]
Key: PointerTy: [ 0.00 0.00 ]
Key: StructTy: [ 0.00 0.00 ]
Key: ArrayTy: [ 0.00 0.00 ]
-Key: VectorTy: [ 0.00 0.00 ]
-Key: VectorTy: [ 0.00 0.00 ]
-Key: PointerTy: [ 0.00 0.00 ]
Key: UnknownTy: [ 0.00 0.00 ]
Key: Function: [ 0.00 0.00 ]
Key: Pointer: [ 0.00 0.00 ]
diff --git a/llvm/test/tools/llvm-ir2vec/entities.ll b/llvm/test/tools/llvm-ir2vec/entities.ll
index 4ed6400d7a195..4b51adf30bf74 100644
--- a/llvm/test/tools/llvm-ir2vec/entities.ll
+++ b/llvm/test/tools/llvm-ir2vec/entities.ll
@@ -1,6 +1,6 @@
; RUN: llvm-ir2vec entities | FileCheck %s
-CHECK: 93
+CHECK: 84
CHECK-NEXT: Ret 0
CHECK-NEXT: Br 1
CHECK-NEXT: Switch 2
@@ -70,27 +70,18 @@ CHECK-NEXT: InsertValue 65
CHECK-NEXT: LandingPad 66
CHECK-NEXT: Freeze 67
CHECK-NEXT: FloatTy 68
-CHECK-NEXT: FloatTy 69
-CHECK-NEXT: FloatTy 70
-CHECK-NEXT: FloatTy 71
-CHECK-NEXT: FloatTy 72
-CHECK-NEXT: FloatTy 73
-CHECK-NEXT: FloatTy 74
-CHECK-NEXT: VoidTy 75
-CHECK-NEXT: LabelTy 76
-CHECK-NEXT: MetadataTy 77
-CHECK-NEXT: UnknownTy 78
-CHECK-NEXT: TokenTy 79
-CHECK-NEXT: IntegerTy 80
-CHECK-NEXT: FunctionTy 81
-CHECK-NEXT: PointerTy 82
-CHECK-NEXT: StructTy 83
-CHECK-NEXT: ArrayTy 84
-CHECK-NEXT: VectorTy 85
-CHECK-NEXT: VectorTy 86
-CHECK-NEXT: PointerTy 87
-CHECK-NEXT: UnknownTy 88
-CHECK-NEXT: Function 89
-CHECK-NEXT: Pointer 90
-CHECK-NEXT: Constant 91
-CHECK-NEXT: Variable 92
+CHECK-NEXT: VoidTy 69
+CHECK-NEXT: LabelTy 70
+CHECK-NEXT: MetadataTy 71
+CHECK-NEXT: VectorTy 72
+CHECK-NEXT: TokenTy 73
+CHECK-NEXT: IntegerTy 74
+CHECK-NEXT: FunctionTy 75
+CHECK-NEXT: PointerTy 76
+CHECK-NEXT: StructTy 77
+CHECK-NEXT: ArrayTy 78
+CHECK-NEXT: UnknownTy 79
+CHECK-NEXT: Function 80
+CHECK-NEXT: Pointer 81
+CHECK-NEXT: Constant 82
+CHECK-NEXT: Variable 83
diff --git a/llvm/test/tools/llvm-ir2vec/triplets.ll b/llvm/test/tools/llvm-ir2vec/triplets.ll
index 6f64bab888f6b..7b476f60a07b3 100644
--- a/llvm/test/tools/llvm-ir2vec/triplets.ll
+++ b/llvm/test/tools/llvm-ir2vec/triplets.ll
@@ -25,41 +25,41 @@ entry:
}
; TRIPLETS: MAX_RELATION=3
-; TRIPLETS-NEXT: 12 80 0
-; TRIPLETS-NEXT: 12 92 2
-; TRIPLETS-NEXT: 12 92 3
+; TRIPLETS-NEXT: 12 74 0
+; TRIPLETS-NEXT: 12 83 2
+; TRIPLETS-NEXT: 12 83 3
; TRIPLETS-NEXT: 12 0 1
-; TRIPLETS-NEXT: 0 75 0
-; TRIPLETS-NEXT: 0 92 2
-; TRIPLETS-NEXT: 16 80 0
-; TRIPLETS-NEXT: 16 92 2
-; TRIPLETS-NEXT: 16 92 3
+; TRIPLETS-NEXT: 0 69 0
+; TRIPLETS-NEXT: 0 83 2
+; TRIPLETS-NEXT: 16 74 0
+; TRIPLETS-NEXT: 16 83 2
+; TRIPLETS-NEXT: 16 83 3
; TRIPLETS-NEXT: 16 0 1
-; TRIPLETS-NEXT: 0 75 0
-; TRIPLETS-NEXT: 0 92 2
-; TRIPLETS-NEXT: 30 82 0
-; TRIPLETS-NEXT: 30 91 2
+; TRIPLETS-NEXT: 0 69 0
+; TRIPLETS-NEXT: 0 83 2
+; TRIPLETS-NEXT: 30 76 0
+; TRIPLETS-NEXT: 30 82 2
; TRIPLETS-NEXT: 30 30 1
-; TRIPLETS-NEXT: 30 82 0
-; TRIPLETS-NEXT: 30 91 2
+; TRIPLETS-NEXT: 30 76 0
+; TRIPLETS-NEXT: 30 82 2
; TRIPLETS-NEXT: 30 32 1
-; TRIPLETS-NEXT: 32 75 0
-; TRIPLETS-NEXT: 32 92 2
-; TRIPLETS-NEXT: 32 90 3
+; TRIPLETS-NEXT: 32 69 0
+; TRIPLETS-NEXT: 32 83 2
+; TRIPLETS-NEXT: 32 81 3
; TRIPLETS-NEXT: 32 32 1
-; TRIPLETS-NEXT: 32 75 0
-; TRIPLETS-NEXT: 32 92 2
-; TRIPLETS-NEXT: 32 90 3
+; TRIPLETS-NEXT: 32 69 0
+; TRIPLETS-NEXT: 32 83 2
+; TRIPLETS-NEXT: 32 81 3
; TRIPLETS-NEXT: 32 31 1
-; TRIPLETS-NEXT: 31 80 0
-; TRIPLETS-NEXT: 31 90 2
+; TRIPLETS-NEXT: 31 74 0
+; TRIPLETS-NEXT: 31 81 2
; TRIPLETS-NEXT: 31 31 1
-; TRIPLETS-NEXT: 31 80 0
-; TRIPLETS-NEXT: 31 90 2
+; TRIPLETS-NEXT: 31 74 0
+; TRIPLETS-NEXT: 31 81 2
; TRIPLETS-NEXT: 31 12 1
-; TRIPLETS-NEXT: 12 80 0
-; TRIPLETS-NEXT: 12 92 2
-; TRIPLETS-NEXT: 12 92 3
+; TRIPLETS-NEXT: 12 74 0
+; TRIPLETS-NEXT: 12 83 2
+; TRIPLETS-NEXT: 12 83 3
; TRIPLETS-NEXT: 12 0 1
-; TRIPLETS-NEXT: 0 75 0
-; TRIPLETS-NEXT: 0 92 2
+; TRIPLETS-NEXT: 0 69 0
+; TRIPLETS-NEXT: 0 83 2
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index c065aaeedd395..461ded77d9609 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -162,8 +162,8 @@ class IR2VecTool {
for (const BasicBlock &BB : F) {
for (const auto &I : BB.instructionsWithoutDebug()) {
- unsigned Opcode = Vocabulary::getNumericID(I.getOpcode());
- unsigned TypeID = Vocabulary::getNumericID(I.getType()->getTypeID());
+ unsigned Opcode = Vocabulary::getSlotIndex(I.getOpcode());
+ unsigned TypeID = Vocabulary::getSlotIndex(I.getType()->getTypeID());
// Add "Next" relationship with previous instruction
if (HasPrevOpcode) {
@@ -184,7 +184,7 @@ class IR2VecTool {
// Add "Arg" relationships
unsigned ArgIndex = 0;
for (const Use &U : I.operands()) {
- unsigned OperandID = Vocabulary::getNumericID(U.get());
+ unsigned OperandID = Vocabulary::getSlotIndex(U.get());
unsigned RelationID = ArgRelation + ArgIndex;
OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
@@ -211,13 +211,7 @@ class IR2VecTool {
/// Dump entity ID to string mappings
static void generateEntityMappings(raw_ostream &OS) {
- // FIXME: Currently, the generated entity mappings are not one-to-one;
- // Multiple TypeIDs map to same string key (Like Half, BFloat, etc. map to
- // FloatTy). This would hinder learning good seed embeddings.
- // We should fix this in the future by ensuring unique string keys either by
- // post-processing here without changing the mapping in ir2vec::Vocabulary,
- // or by changing the Vocabulary generation logic to ensure unique keys.
- auto EntityLen = Vocabulary::expectedSize();
+ auto EntityLen = Vocabulary::getCanonicalSize();
OS << EntityLen << "\n";
for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
OS << Vocabulary::getStringKey(EntityID) << '\t' << EntityID << '\n';
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index f0c81e160ca15..9f5428758d64c 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -336,8 +336,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
EXPECT_EQ(AddEmb.size(), 2u);
EXPECT_EQ(RetEmb.size(), 2u);
- EXPECT_TRUE(AddEmb.approximatelyEquals(Embedding(2, 27.9)));
- EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 17.0)));
+ EXPECT_TRUE(AddEmb.approximatelyEquals(Embedding(2, 25.5)));
+ EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 15.5)));
}
TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
@@ -353,8 +353,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
- EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 27.9)));
- EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 35.6)));
+ EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 25.5)));
+ EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 32.6)));
}
TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
@@ -367,9 +367,9 @@ TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
EXPECT_TRUE(BBMap.count(BB));
EXPECT_EQ(BBMap.at(BB).size(), 2u);
- // BB vector should be sum of add and ret: {27.9, 27.9} + {17.0, 17.0} =
- // {44.9, 44.9}
- EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 44.9)));
+ // BB vector should be sum of add and ret: {25.5, 25.5} + {15.5, 15.5} =
+ // {41.0, 41.0}
+ EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 41.0)));
}
TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
@@ -382,9 +382,9 @@ TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
EXPECT_TRUE(BBMap.count(BB));
EXPECT_EQ(BBMap.at(BB).size(), 2u);
- // BB vector should be sum of add and ret: {27.9, 27.9} + {35.6, 35.6} =
- // {63.5, 63.5}
- EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 63.5)));
+ // BB vector should be sum of add and ret: {25.5, 25.5} + {32.6, 32.6} =
+ // {58.1, 58.1}
+ EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 58.1)));
}
TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
@@ -394,7 +394,7 @@ TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
const auto &BBVec = Emb->getBBVector(*BB);
EXPECT_EQ(BBVec.size(), 2u);
- EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 44.9)));
+ EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 41.0)));
}
TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
@@ -404,7 +404,7 @@ TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
const auto &BBVec = Emb->getBBVector(*BB);
EXPECT_EQ(BBVec.size(), 2u);
- EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 63.5)));
+ EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 58.1)));
}
TEST_F(IR2VecTestFixture, GetFunctionVector_Symbolic) {
@@ -415,8 +415,8 @@ TEST_F(IR2VecTestFixture, GetFunctionVector_Symbolic) {
EXPECT_EQ(FuncVec.size(), 2u);
- // Function vector should match BB vector (only one BB): {44.9, 44.9}
- EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 44.9)));
+ // Function vector should match BB vector (only one BB): {41.0, 41.0}
+ EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 41.0)));
}
TEST_F(IR2VecTestFixture, GetFunctionVector_FlowAware) {
@@ -426,24 +426,40 @@ TEST_F(IR2VecTestFixture, GetFunctionVector_FlowAware) {
const auto &FuncVec = Emb->getFunctionVector();
EXPECT_EQ(FuncVec.size(), 2u);
- // Function vector should match BB vector (only one BB): {63.5, 63.5}
- EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 63.5)));
+ // Function vector should match BB vector (only one BB): {58.1, 58.1}
+ EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 58.1)));
}
static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;
static constexpr unsigned MaxTypeIDs = Vocabulary::MaxTypeIDs;
+static constexpr unsigned MaxCanonicalTypeIDs = Vocabulary::MaxCanonicalTypeIDs;
static constexpr unsigned MaxOperands = Vocabulary::MaxOperandKinds;
+// Mapping between LLVM Type::TypeID tokens and Vocabulary::CanonicalTypeID
+// names and their canonical string keys.
+#define IR2VEC_HANDLE_TYPE_BIMAP(X) \
+ X(VoidTyID, VoidTy, "VoidTy") \
+ X(IntegerTyID, IntegerTy, "IntegerTy") \
+ X(FloatTyID, FloatTy, "FloatTy") \
+ X(PointerTyID, PointerTy, "PointerTy") \
+ X(FunctionTyID, FunctionTy, "FunctionTy") \
+ X(StructTyID, StructTy, "StructTy") \
+ X(ArrayTyID, ArrayTy, "ArrayTy") \
+ X(FixedVectorTyID, VectorTy, "VectorTy") \
+ X(LabelTyID, LabelTy, "LabelTy") \
+ X(TokenTyID, TokenTy, "TokenTy") \
+ X(MetadataTyID, MetadataTy, "MetadataTy")
+
TEST(IR2VecVocabularyTest, DummyVocabTest) {
for (unsigned Dim = 1; Dim <= 10; ++Dim) {
auto VocabVec = Vocabulary::createDummyVocabForTest(Dim);
-
+ auto VocabVecSize = VocabVec.size();
// All embeddings should have the same dimension
for (const auto &Emb : VocabVec)
EXPECT_EQ(Emb.size(), Dim);
// Should have the correct total number of embeddings
- EXPECT_EQ(VocabVec.size(), MaxOpcodes + MaxTypeIDs + MaxOperands);
+ EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands);
auto ExpectedVocab = VocabVec;
@@ -454,7 +470,7 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
Vocabulary Result = VocabAnalysis.run(TestMod, MAM);
EXPECT_TRUE(Result.isValid());
EXPECT_EQ(Result.getDimension(), Dim);
- EXPECT_EQ(Result.size(), MaxOpcodes + MaxTypeIDs + MaxOperands);
+ EXPECT_EQ(Result.getCanonicalSize(), VocabVecSize);
unsigned CurPos = 0;
for (const auto &Entry : Result)
@@ -462,64 +478,68 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
}
}
-TEST(IR2VecVocabularyTest, NumericIDMap) {
- // Test getNumericID for opcodes
- EXPECT_EQ(Vocabulary::getNumericID(1u), 0u);
- EXPECT_EQ(Vocabulary::getNumericID(13u), 12u);
- EXPECT_EQ(Vocabulary::getNumericID(MaxOpcodes), MaxOpcodes - 1);
-
- // Test getNumericID for Type IDs
- EXPECT_EQ(Vocabulary::getNumericID(Type::VoidTyID),
- MaxOpcodes + static_cast<unsigned>(Type::VoidTyID));
- EXPECT_EQ(Vocabulary::getNumericID(Type::HalfTyID),
- MaxOpcodes + static_cast<unsigned>(Type::HalfTyID));
- EXPECT_EQ(Vocabulary::getNumericID(Type::FloatTyID),
- MaxOpcodes + static_cast<unsigned>(Type::FloatTyID));
- EXPECT_EQ(Vocabulary::getNumericID(Type::IntegerTyID),
- MaxOpcodes + static_cast<unsigned>(Type::IntegerTyID));
- EXPECT_EQ(Vocabulary::getNumericID(Type::PointerTyID),
- MaxOpcodes + static_cast<unsigned>(Type::PointerTyID));
-
- // Test getNumericID for Value operands
+TEST(IR2VecVocabularyTest, SlotIdxMapping) {
+ // Test getSlotIndex for Opcodes
+#define EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS) \
+ EXPECT_EQ(Vocabulary::getSlotIndex(NUM), static_cast<unsigned>(NUM - 1));
+#define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS)
+#include "llvm/IR/Instruction.def"
+#undef HANDLE_INST
+#undef EXPECT_OPCODE_SLOT
+
+ // Test getSlotIndex for Types
+#define EXPECT_TYPE_SLOT(TypeIDTok, CanonEnum, CanonStr) \
+ EXPECT_EQ(Vocabulary::getSlotIndex(Type::TypeIDTok), \
+ MaxOpcodes + static_cast<unsigned>( \
+ Vocabulary::CanonicalTypeID::CanonEnum));
+
+ IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_TYPE_SLOT)
+
+#undef EXPECT_TYPE_SLOT
+
+ // Test getSlotIndex for Value operands
LLVMContext Ctx;
Module M("TestM", Ctx);
FunctionType *FTy =
FunctionType::get(Type::getVoidTy(Ctx), {Type::getInt32Ty(Ctx)}, false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "testFunc", M);
+#define EXPECTED_VOCAB_OPERAND_SLOT(X) \
+ MaxOpcodes + MaxCanonicalTypeIDs + static_cast<unsigned>(X)
// Test Function operand
- EXPECT_EQ(Vocabulary::getNumericID(F),
- MaxOpcodes + MaxTypeIDs + 0u); // Function = 0
+ EXPECT_EQ(Vocabulary::getSlotIndex(F),
+ EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::FunctionID));
// Test Constant operand
Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
- EXPECT_EQ(Vocabulary::getNumericID(C),
- MaxOpcodes + MaxTypeIDs + 2u); // Constant = 2
+ EXPECT_EQ(Vocabulary::getSlotIndex(C),
+ EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::ConstantID));
// Test Pointer operand
BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB);
- EXPECT_EQ(Vocabulary::getNumericID(PtrVal),
- MaxOpcodes + MaxTypeIDs + 1u); // Pointer = 1
+ EXPECT_EQ(Vocabulary::getSlotIndex(PtrVal),
+ EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::PointerID));
// Test Variable operand (function argument)
Argument *Arg = F->getArg(0);
- EXPECT_EQ(Vocabulary::getNumericID(Arg),
- MaxOpcodes + MaxTypeIDs + 3u); // Variable = 3
+ EXPECT_EQ(Vocabulary::getSlotIndex(Arg),
+ EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::VariableID));
+#undef EXPECTED_VOCAB_OPERAND_SLOT
}
#if GTEST_HAS_DEATH_TEST
#ifndef NDEBUG
TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) {
// Test invalid opcode IDs
- EXPECT_DEATH(Vocabulary::getNumericID(0u), "Invalid opcode");
- EXPECT_DEATH(Vocabulary::getNumericID(MaxOpcodes + 1), "Invalid opcode");
+ EXPECT_DEATH(Vocabulary::getSlotIndex(0u), "Invalid opcode");
+ EXPECT_DEATH(Vocabulary::getSlotIndex(MaxOpcodes + 1), "Invalid opcode");
// Test invalid type IDs
- EXPECT_DEATH(Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs)),
+ EXPECT_DEATH(Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs)),
"Invalid type ID");
EXPECT_DEATH(
- Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
+ Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
"Invalid type ID");
}
#endif // NDEBUG
@@ -529,18 +549,46 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) {
EXPECT_EQ(Vocabulary::getStringKey(0), "Ret");
EXPECT_EQ(Vocabulary::getStringKey(12), "Add");
- StringRef HalfTypeKey = Vocabulary::getStringKey(MaxOpcodes + 0);
- StringRef FloatTypeKey = Vocabulary::getStringKey(MaxOpcodes + 2);
- StringRef VoidTypeKey = Vocabulary::getStringKey(MaxOpcodes + 7);
- StringRef IntTypeKey = Vocabulary::getStringKey(MaxOpcodes + 12);
-
- EXPECT_EQ(HalfTypeKey, "FloatTy");
- EXPECT_EQ(FloatTypeKey, "FloatTy");
- EXPECT_EQ(VoidTypeKey, "VoidTy");
- EXPECT_EQ(IntTypeKey, "IntegerTy");
-
- StringRef FuncArgKey = Vocabulary::getStringKey(MaxOpcodes + MaxTypeIDs + 0);
- StringRef PtrArgKey = Vocabulary::getStringKey(MaxOpcodes + MaxTypeIDs + 1);
+#define EXPECT_OPCODE(NUM, OPCODE, CLASS) \
+ EXPECT_EQ(Vocabulary::getStringKey(Vocabulary::getSlotIndex(NUM)), \
+ Vocabulary::getVocabKeyForOpcode(NUM));
+#define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE(NUM, OPCODE, CLASS)
+#include "llvm/IR/Instruction.def"
+#undef HANDLE_INST
+#undef EXPECT_OPCODE
+
+ // Verify CanonicalTypeID -> string mapping
+#define EXPECT_CANONICAL_TYPE_NAME(TypeIDTok, CanonEnum, CanonStr) \
+ EXPECT_EQ(Vocabulary::getStringKey( \
+ MaxOpcodes + static_cast<unsigned>( \
+ Vocabulary::CanonicalTypeID::CanonEnum)), \
+ CanonStr);
+
+ IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_CANONICAL_TYPE_NAME)
+
+#undef EXPECT_CANONICAL_TYPE_NAME
+
+#define HANDLE_OPERAND_KINDS(X) \
+ X(FunctionID, "Function") \
+ X(PointerID, "Pointer") \
+ X(ConstantID, "Constant") \
+ X(VariableID, "Variable")
+
+#define EXPECT_OPERAND_KIND(EnumName, Str) \
+ EXPECT_EQ(Vocabulary::getStringKey( \
+ MaxOpcodes + MaxCanonicalTypeIDs + \
+ static_cast<unsigned>(Vocabulary::OperandKind::EnumName)), \
+ Str);
+
+ HANDLE_OPERAND_KINDS(EXPECT_OPERAND_KIND)
+
+#undef EXPECT_OPERAND_KIND
+#undef HANDLE_OPERAND_KINDS
+
+ StringRef FuncArgKey =
+ Vocabulary::getStringKey(MaxOpcodes + MaxCanonicalTypeIDs + 0);
+ StringRef PtrArgKey =
+ Vocabulary::getStringKey(MaxOpcodes + MaxCanonicalTypeIDs + 1);
EXPECT_EQ(FuncArgKey, "Function");
EXPECT_EQ(PtrArgKey, "Pointer");
}
@@ -578,39 +626,14 @@ TEST(IR2VecVocabularyTest, InvalidAccess) {
#endif // GTEST_HAS_DEATH_TEST
TEST(IR2VecVocabularyTest, TypeIDStringKeyMapping) {
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::VoidTyID)),
- "VoidTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::IntegerTyID)),
- "IntegerTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::FloatTyID)),
- "FloatTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::PointerTyID)),
- "PointerTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::FunctionTyID)),
- "FunctionTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::StructTyID)),
- "StructTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::ArrayTyID)),
- "ArrayTy");
- EXPECT_EQ(Vocabulary::getStringKey(
- MaxOpcodes + static_cast<unsigned>(Type::FixedVectorTyID)),
- "VectorTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::LabelTyID)),
- "LabelTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::TokenTyID)),
- "TokenTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::MetadataTyID)),
- "MetadataTy");
+#define EXPECT_TYPE_TO_CANONICAL(TypeIDTok, CanonEnum, CanonStr) \
+ EXPECT_EQ( \
+ Vocabulary::getStringKey(Vocabulary::getSlotIndex(Type::TypeIDTok)), \
+ CanonStr);
+
+ IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_TYPE_TO_CANONICAL)
+
+#undef EXPECT_TYPE_TO_CANONICAL
}
TEST(IR2VecVocabularyTest, InvalidVocabularyConstruction) {
More information about the llvm-commits mailing list