[llvm] [IR2Vec] Refactor vocabulary to use canonical type IDs (PR #155323)
S. VenkataKeerthy via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 29 14:30:25 PDT 2025
https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/155323
>From cb881b61d370c8ac4d1cb00dce686d193ef1a0f8 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Mon, 25 Aug 2025 22:58:43 +0000
Subject: [PATCH] Canonicalized type
---
llvm/include/llvm/Analysis/IR2Vec.h | 135 +++++++++--
llvm/lib/Analysis/IR2Vec.cpp | 164 ++++++--------
.../Inputs/reference_default_vocab_print.txt | 11 +-
.../Inputs/reference_wtd1_vocab_print.txt | 11 +-
.../Inputs/reference_wtd2_vocab_print.txt | 11 +-
llvm/test/tools/llvm-ir2vec/entities.ll | 41 ++--
llvm/test/tools/llvm-ir2vec/triplets.ll | 58 ++---
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 14 +-
llvm/unittests/Analysis/IR2VecTest.cpp | 213 ++++++++++--------
9 files changed, 350 insertions(+), 308 deletions(-)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index caa816e2fd76d..c42ca779e097c 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -36,6 +36,7 @@
#include "llvm/Support/Compiler.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/JSON.h"
+#include <array>
#include <map>
namespace llvm {
@@ -137,13 +138,51 @@ using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
/// Class for storing and accessing the IR2Vec vocabulary.
-/// Encapsulates all vocabulary-related constants, logic, and access methods.
+/// The Vocabulary class manages seed embeddings for LLVM IR entities. The
+/// seed embeddings are the initial learned representations of the entities
+/// of LLVM IR. The IR2Vec representation for a given IR is derived from these
+/// seed embeddings.
+///
+/// The vocabulary contains the seed embeddings for three types of entities:
+/// instruction opcodes, types, and operands. Types are grouped/canonicalized
+/// for better learning (e.g., all float variants map to FloatTy). The
+/// vocabulary abstracts this canonicalization away: the exposed APIs accept
+/// all known LLVM IR opcodes, types, and operands.
+///
+/// This class stores the seed embeddings in an internal vector-based ADT.
+/// It maps every IR entity to a fixed slot index (i.e., a position) in this
+/// vector, which enables O(1) embedding lookup and avoids the unnecessary
+/// string-based lookups that would otherwise be needed while generating the
+/// embeddings.
class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;
using VocabVector = std::vector<ir2vec::Embedding>;
VocabVector Vocab;
bool Valid = false;
+public:
+ // Slot layout:
+ // [0 .. MaxOpcodes-1] => Instruction opcodes
+ // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
+ // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
+
+ /// Canonical type IDs supported by IR2Vec Vocabulary
+ enum class CanonicalTypeID : unsigned {
+ FloatTy,
+ VoidTy,
+ LabelTy,
+ MetadataTy,
+ VectorTy,
+ TokenTy,
+ IntegerTy,
+ FunctionTy,
+ PointerTy,
+ StructTy,
+ ArrayTy,
+ UnknownTy,
+ MaxCanonicalType
+ };
+
/// Operand kinds supported by IR2Vec Vocabulary
enum class OperandKind : unsigned {
FunctionID,
@@ -152,20 +191,15 @@ class Vocabulary {
VariableID,
MaxOperandKind
};
- /// String mappings for OperandKind values
- static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
- "Constant", "Variable"};
- static_assert(std::size(OperandKindNames) ==
- static_cast<unsigned>(OperandKind::MaxOperandKind),
- "OperandKindNames array size must match MaxOperandKind");
-public:
/// Vocabulary layout constants
#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
#include "llvm/IR/Instruction.def"
#undef LAST_OTHER_INST
static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;
+ static constexpr unsigned MaxCanonicalTypeIDs =
+ static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
static constexpr unsigned MaxOperandKinds =
static_cast<unsigned>(OperandKind::MaxOperandKind);
@@ -174,33 +208,31 @@ class Vocabulary {
LLVM_ABI bool isValid() const;
LLVM_ABI unsigned getDimension() const;
- LLVM_ABI size_t size() const;
+ /// Total number of entries (opcodes + canonicalized types + operand kinds)
+ static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
- static size_t expectedSize() {
- return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
- }
-
- /// Helper function to get vocabulary key for a given Opcode
+ /// Function to get vocabulary key for a given Opcode
LLVM_ABI static StringRef getVocabKeyForOpcode(unsigned Opcode);
- /// Helper function to get vocabulary key for a given TypeID
+ /// Function to get vocabulary key for a given TypeID
LLVM_ABI static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
- /// Helper function to get vocabulary key for a given OperandKind
+ /// Function to get vocabulary key for a given OperandKind
LLVM_ABI static StringRef getVocabKeyForOperandKind(OperandKind Kind);
- /// Helper function to classify an operand into OperandKind
+ /// Function to classify an operand into OperandKind
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
- /// Helpers to return the IDs of a given Opcode, TypeID, or OperandKind
- LLVM_ABI static unsigned getNumericID(unsigned Opcode);
- LLVM_ABI static unsigned getNumericID(Type::TypeID TypeID);
- LLVM_ABI static unsigned getNumericID(const Value *Op);
+ /// Return the vocabulary slot index (position) of a given Opcode, TypeID
+ /// (after canonicalization), or operand Value.
+ LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
+ LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
+ LLVM_ABI static unsigned getSlotIndex(const Value *Op);
/// Accessors to get the embedding for a given entity.
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
- LLVM_ABI const ir2vec::Embedding &operator[](const Value *Arg) const;
+ LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
/// Const Iterator type aliases
using const_iterator = VocabVector::const_iterator;
@@ -234,6 +266,61 @@ class Vocabulary {
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
ModuleAnalysisManager::Invalidator &Inv) const;
+
+private:
+ constexpr static unsigned NumCanonicalEntries =
+ MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
+
+ /// String mappings for CanonicalTypeID values
+ static constexpr StringLiteral CanonicalTypeNames[] = {
+ "FloatTy", "VoidTy", "LabelTy", "MetadataTy",
+ "VectorTy", "TokenTy", "IntegerTy", "FunctionTy",
+ "PointerTy", "StructTy", "ArrayTy", "UnknownTy"};
+ static_assert(std::size(CanonicalTypeNames) ==
+ static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType),
+ "CanonicalTypeNames array size must match MaxCanonicalType");
+
+ /// String mappings for OperandKind values
+ static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
+ "Constant", "Variable"};
+ static_assert(std::size(OperandKindNames) ==
+ static_cast<unsigned>(OperandKind::MaxOperandKind),
+ "OperandKindNames array size must match MaxOperandKind");
+
+ /// Maps each Type::TypeID (llvm/IR/Type.h), in enum order, to its
+ /// CanonicalTypeID. Left unsized so the static_assert catches gaps.
+ static constexpr CanonicalTypeID TypeIDMapping[] = {
+ CanonicalTypeID::FloatTy, // HalfTyID = 0
+ CanonicalTypeID::FloatTy, // BFloatTyID
+ CanonicalTypeID::FloatTy, // FloatTyID
+ CanonicalTypeID::FloatTy, // DoubleTyID
+ CanonicalTypeID::FloatTy, // X86_FP80TyID
+ CanonicalTypeID::FloatTy, // FP128TyID
+ CanonicalTypeID::FloatTy, // PPC_FP128TyID
+ CanonicalTypeID::VoidTy, // VoidTyID
+ CanonicalTypeID::LabelTy, // LabelTyID
+ CanonicalTypeID::MetadataTy, // MetadataTyID
+ CanonicalTypeID::VectorTy, // X86_AMXTyID
+ CanonicalTypeID::TokenTy, // TokenTyID
+ CanonicalTypeID::IntegerTy, // IntegerTyID
+ CanonicalTypeID::FunctionTy, // FunctionTyID
+ CanonicalTypeID::PointerTy, // PointerTyID
+ CanonicalTypeID::StructTy, // StructTyID
+ CanonicalTypeID::ArrayTy, // ArrayTyID
+ CanonicalTypeID::VectorTy, // FixedVectorTyID
+ CanonicalTypeID::VectorTy, // ScalableVectorTyID
+ CanonicalTypeID::PointerTy, // TypedPointerTyID
+ CanonicalTypeID::UnknownTy // TargetExtTyID
+ };
+ static_assert(std::size(TypeIDMapping) == MaxTypeIDs,
+ "TypeIDMapping must cover all Type::TypeID values");
+
+ /// Function to get vocabulary key for canonical type by enum
+ LLVM_ABI static StringRef
+ getVocabKeyForCanonicalTypeID(CanonicalTypeID CType);
+
+ /// Function to convert TypeID to CanonicalTypeID
+ LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
};
/// Embedder provides the interface to generate embeddings (vector
@@ -262,11 +349,11 @@ class Embedder {
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
- /// Helper function to compute embeddings. It generates embeddings for all
+ /// Function to compute embeddings. It generates embeddings for all
/// the instructions and basic blocks in the function F.
void computeEmbeddings() const;
- /// Helper function to compute the embedding for a given basic block.
+ /// Function to compute the embedding for a given basic block.
/// Specific to the kind of embeddings being computed.
virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index eb54f90a75488..886d22e424ea2 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -32,7 +32,7 @@ using namespace ir2vec;
#define DEBUG_TYPE "ir2vec"
STATISTIC(VocabMissCounter,
- "Number of lookups to entites not present in the vocabulary");
+ "Number of lookups to entities not present in the vocabulary");
namespace llvm {
namespace ir2vec {
@@ -213,7 +213,7 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
for (const auto &I : BB.instructionsWithoutDebug()) {
Embedding ArgEmb(Dimension, 0);
for (const auto &Op : I.operands())
- ArgEmb += Vocab[Op];
+ ArgEmb += Vocab[*Op];
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
InstVecMap[&I] = InstVector;
@@ -242,8 +242,8 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
// If the operand is not defined by an instruction, we use the vocabulary
else {
LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
- << *Op << "=" << Vocab[Op][0] << "\n");
- ArgEmb += Vocab[Op];
+ << *Op << "=" << Vocab[*Op][0] << "\n");
+ ArgEmb += Vocab[*Op];
}
}
// Create the instruction vector by combining opcode, type, and arguments
@@ -264,12 +264,7 @@ Vocabulary::Vocabulary(VocabVector &&Vocab)
: Vocab(std::move(Vocab)), Valid(true) {}
bool Vocabulary::isValid() const {
- return Vocab.size() == Vocabulary::expectedSize() && Valid;
-}
-
-size_t Vocabulary::size() const {
- assert(Valid && "IR2Vec Vocabulary is invalid");
- return Vocab.size();
+ return Vocab.size() == NumCanonicalEntries && Valid;
}
unsigned Vocabulary::getDimension() const {
@@ -277,19 +272,32 @@ unsigned Vocabulary::getDimension() const {
return Vocab[0].size();
}
-const Embedding &Vocabulary::operator[](unsigned Opcode) const {
+unsigned Vocabulary::getSlotIndex(unsigned Opcode) {
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
- return Vocab[Opcode - 1];
+ return Opcode - 1; // Convert to zero-based index
+}
+
+unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
+ assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
+ return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));
+}
+
+unsigned Vocabulary::getSlotIndex(const Value *Op) {
+ unsigned Index = static_cast<unsigned>(getOperandKind(Op));
+ assert(Index < MaxOperandKinds && "Invalid OperandKind");
+ return MaxOpcodes + MaxCanonicalTypeIDs + Index;
+}
+
+const Embedding &Vocabulary::operator[](unsigned Opcode) const {
+ return Vocab[getSlotIndex(Opcode)];
}
-const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const {
- assert(static_cast<unsigned>(TypeId) < MaxTypeIDs && "Invalid type ID");
- return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)];
+const Embedding &Vocabulary::operator[](Type::TypeID TypeID) const {
+ return Vocab[getSlotIndex(TypeID)];
}
-const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
- OperandKind ArgKind = getOperandKind(Arg);
- return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
+const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
+ return Vocab[getSlotIndex(&Arg)];
}
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
@@ -303,43 +311,21 @@ StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
return "UnknownOpcode";
}
+StringRef Vocabulary::getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) {
+ unsigned Index = static_cast<unsigned>(CType);
+ assert(Index < MaxCanonicalTypeIDs && "Invalid CanonicalTypeID");
+ return CanonicalTypeNames[Index];
+}
+
+Vocabulary::CanonicalTypeID
+Vocabulary::getCanonicalTypeID(Type::TypeID TypeID) {
+ unsigned Index = static_cast<unsigned>(TypeID);
+ assert(Index < MaxTypeIDs && "Invalid TypeID");
+ return TypeIDMapping[Index];
+}
+
StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
- switch (TypeID) {
- case Type::VoidTyID:
- return "VoidTy";
- case Type::HalfTyID:
- case Type::BFloatTyID:
- case Type::FloatTyID:
- case Type::DoubleTyID:
- case Type::X86_FP80TyID:
- case Type::FP128TyID:
- case Type::PPC_FP128TyID:
- return "FloatTy";
- case Type::IntegerTyID:
- return "IntegerTy";
- case Type::FunctionTyID:
- return "FunctionTy";
- case Type::StructTyID:
- return "StructTy";
- case Type::ArrayTyID:
- return "ArrayTy";
- case Type::PointerTyID:
- case Type::TypedPointerTyID:
- return "PointerTy";
- case Type::FixedVectorTyID:
- case Type::ScalableVectorTyID:
- return "VectorTy";
- case Type::LabelTyID:
- return "LabelTy";
- case Type::TokenTyID:
- return "TokenTy";
- case Type::MetadataTyID:
- return "MetadataTy";
- case Type::X86_AMXTyID:
- case Type::TargetExtTyID:
- return "UnknownTy";
- }
- return "UnknownTy";
+ return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID));
}
StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
@@ -348,20 +334,6 @@ StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
return OperandKindNames[Index];
}
-Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
- VocabVector DummyVocab;
- float DummyVal = 0.1f;
- // Create a dummy vocabulary with entries for all opcodes, types, and
- // operand
- for ([[maybe_unused]] unsigned _ :
- seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxTypeIDs +
- Vocabulary::MaxOperandKinds)) {
- DummyVocab.push_back(Embedding(Dim, DummyVal));
- DummyVal += 0.1f;
- }
- return DummyVocab;
-}
-
// Helper function to classify an operand into OperandKind
Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
if (isa<Function>(Op))
@@ -373,34 +345,18 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
return OperandKind::VariableID;
}
-unsigned Vocabulary::getNumericID(unsigned Opcode) {
- assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
- return Opcode - 1; // Convert to zero-based index
-}
-
-unsigned Vocabulary::getNumericID(Type::TypeID TypeID) {
- assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
- return MaxOpcodes + static_cast<unsigned>(TypeID);
-}
-
-unsigned Vocabulary::getNumericID(const Value *Op) {
- unsigned Index = static_cast<unsigned>(getOperandKind(Op));
- assert(Index < MaxOperandKinds && "Invalid OperandKind");
- return MaxOpcodes + MaxTypeIDs + Index;
-}
-
StringRef Vocabulary::getStringKey(unsigned Pos) {
- assert(Pos < Vocabulary::expectedSize() &&
- "Position out of bounds in vocabulary");
+ assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
// Opcode
if (Pos < MaxOpcodes)
return getVocabKeyForOpcode(Pos + 1);
// Type
- if (Pos < MaxOpcodes + MaxTypeIDs)
- return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes));
+ if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
+ return getVocabKeyForCanonicalTypeID(
+ static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
// Operand
return getVocabKeyForOperandKind(
- static_cast<OperandKind>(Pos - MaxOpcodes - MaxTypeIDs));
+ static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
}
// For now, assume vocabulary is stable unless explicitly invalidated.
@@ -410,6 +366,21 @@ bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA,
return !(PAC.preservedWhenStateless());
}
+Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
+ VocabVector DummyVocab;
+ DummyVocab.reserve(NumCanonicalEntries);
+ float DummyVal = 0.1f;
+ // Create a dummy vocabulary with entries for all opcodes, types, and
+ // operands
+ for ([[maybe_unused]] unsigned _ :
+ seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
+ Vocabulary::MaxOperandKinds)) {
+ DummyVocab.push_back(Embedding(Dim, DummyVal));
+ DummyVal += 0.1f;
+ }
+ return DummyVocab;
+}
+
// ==----------------------------------------------------------------------===//
// IR2VecVocabAnalysis
//===----------------------------------------------------------------------===//
@@ -502,6 +473,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Opcodes
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
Embedding(Dim, 0));
+ NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes);
for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
auto It = OpcVocab.find(VocabKey.str());
@@ -513,14 +485,15 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
NumericOpcodeEmbeddings.end());
- // Handle Types
- std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,
+ // Handle Types - only canonical types are present in vocabulary
+ std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
Embedding(Dim, 0));
- for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {
- StringRef VocabKey =
- Vocabulary::getVocabKeyForTypeID(static_cast<Type::TypeID>(TypeID));
+ NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs);
+ for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
+ StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
+ static_cast<Vocabulary::CanonicalTypeID>(CTypeID));
if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {
- NumericTypeEmbeddings[TypeID] = It->second;
+ NumericTypeEmbeddings[CTypeID] = It->second;
continue;
}
handleMissingEntity(VocabKey.str());
@@ -531,6 +504,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Arguments/Operands
std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
Embedding(Dim, 0));
+ NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds);
for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
index 1b9b3c2acd8a5..df7769c9c6a65 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
@@ -67,25 +67,16 @@ Key: InsertValue: [ 129.00 130.00 ]
Key: LandingPad: [ 131.00 132.00 ]
Key: Freeze: [ 133.00 134.00 ]
Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
Key: VoidTy: [ 1.50 2.00 ]
Key: LabelTy: [ 2.50 3.00 ]
Key: MetadataTy: [ 3.50 4.00 ]
-Key: UnknownTy: [ 4.50 5.00 ]
+Key: VectorTy: [ 11.50 12.00 ]
Key: TokenTy: [ 5.50 6.00 ]
Key: IntegerTy: [ 6.50 7.00 ]
Key: FunctionTy: [ 7.50 8.00 ]
Key: PointerTy: [ 8.50 9.00 ]
Key: StructTy: [ 9.50 10.00 ]
Key: ArrayTy: [ 10.50 11.00 ]
-Key: VectorTy: [ 11.50 12.00 ]
-Key: VectorTy: [ 11.50 12.00 ]
-Key: PointerTy: [ 8.50 9.00 ]
Key: UnknownTy: [ 4.50 5.00 ]
Key: Function: [ 0.20 0.40 ]
Key: Pointer: [ 0.60 0.80 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
index 9673e7f23fa5c..f3ce809fd2fd2 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
@@ -67,25 +67,16 @@ Key: InsertValue: [ 64.50 65.00 ]
Key: LandingPad: [ 65.50 66.00 ]
Key: Freeze: [ 66.50 67.00 ]
Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
Key: VoidTy: [ 1.50 2.00 ]
Key: LabelTy: [ 2.50 3.00 ]
Key: MetadataTy: [ 3.50 4.00 ]
-Key: UnknownTy: [ 4.50 5.00 ]
+Key: VectorTy: [ 11.50 12.00 ]
Key: TokenTy: [ 5.50 6.00 ]
Key: IntegerTy: [ 6.50 7.00 ]
Key: FunctionTy: [ 7.50 8.00 ]
Key: PointerTy: [ 8.50 9.00 ]
Key: StructTy: [ 9.50 10.00 ]
Key: ArrayTy: [ 10.50 11.00 ]
-Key: VectorTy: [ 11.50 12.00 ]
-Key: VectorTy: [ 11.50 12.00 ]
-Key: PointerTy: [ 8.50 9.00 ]
Key: UnknownTy: [ 4.50 5.00 ]
Key: Function: [ 0.50 1.00 ]
Key: Pointer: [ 1.50 2.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
index 1f575d29092dd..72b25b9bd3d9c 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
@@ -67,25 +67,16 @@ Key: InsertValue: [ 12.90 13.00 ]
Key: LandingPad: [ 13.10 13.20 ]
Key: Freeze: [ 13.30 13.40 ]
Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
Key: VoidTy: [ 0.00 0.00 ]
Key: LabelTy: [ 0.00 0.00 ]
Key: MetadataTy: [ 0.00 0.00 ]
-Key: UnknownTy: [ 0.00 0.00 ]
+Key: VectorTy: [ 0.00 0.00 ]
Key: TokenTy: [ 0.00 0.00 ]
Key: IntegerTy: [ 0.00 0.00 ]
Key: FunctionTy: [ 0.00 0.00 ]
Key: PointerTy: [ 0.00 0.00 ]
Key: StructTy: [ 0.00 0.00 ]
Key: ArrayTy: [ 0.00 0.00 ]
-Key: VectorTy: [ 0.00 0.00 ]
-Key: VectorTy: [ 0.00 0.00 ]
-Key: PointerTy: [ 0.00 0.00 ]
Key: UnknownTy: [ 0.00 0.00 ]
Key: Function: [ 0.00 0.00 ]
Key: Pointer: [ 0.00 0.00 ]
diff --git a/llvm/test/tools/llvm-ir2vec/entities.ll b/llvm/test/tools/llvm-ir2vec/entities.ll
index 4ed6400d7a195..4b51adf30bf74 100644
--- a/llvm/test/tools/llvm-ir2vec/entities.ll
+++ b/llvm/test/tools/llvm-ir2vec/entities.ll
@@ -1,6 +1,6 @@
; RUN: llvm-ir2vec entities | FileCheck %s
-CHECK: 93
+CHECK: 84
CHECK-NEXT: Ret 0
CHECK-NEXT: Br 1
CHECK-NEXT: Switch 2
@@ -70,27 +70,18 @@ CHECK-NEXT: InsertValue 65
CHECK-NEXT: LandingPad 66
CHECK-NEXT: Freeze 67
CHECK-NEXT: FloatTy 68
-CHECK-NEXT: FloatTy 69
-CHECK-NEXT: FloatTy 70
-CHECK-NEXT: FloatTy 71
-CHECK-NEXT: FloatTy 72
-CHECK-NEXT: FloatTy 73
-CHECK-NEXT: FloatTy 74
-CHECK-NEXT: VoidTy 75
-CHECK-NEXT: LabelTy 76
-CHECK-NEXT: MetadataTy 77
-CHECK-NEXT: UnknownTy 78
-CHECK-NEXT: TokenTy 79
-CHECK-NEXT: IntegerTy 80
-CHECK-NEXT: FunctionTy 81
-CHECK-NEXT: PointerTy 82
-CHECK-NEXT: StructTy 83
-CHECK-NEXT: ArrayTy 84
-CHECK-NEXT: VectorTy 85
-CHECK-NEXT: VectorTy 86
-CHECK-NEXT: PointerTy 87
-CHECK-NEXT: UnknownTy 88
-CHECK-NEXT: Function 89
-CHECK-NEXT: Pointer 90
-CHECK-NEXT: Constant 91
-CHECK-NEXT: Variable 92
+CHECK-NEXT: VoidTy 69
+CHECK-NEXT: LabelTy 70
+CHECK-NEXT: MetadataTy 71
+CHECK-NEXT: VectorTy 72
+CHECK-NEXT: TokenTy 73
+CHECK-NEXT: IntegerTy 74
+CHECK-NEXT: FunctionTy 75
+CHECK-NEXT: PointerTy 76
+CHECK-NEXT: StructTy 77
+CHECK-NEXT: ArrayTy 78
+CHECK-NEXT: UnknownTy 79
+CHECK-NEXT: Function 80
+CHECK-NEXT: Pointer 81
+CHECK-NEXT: Constant 82
+CHECK-NEXT: Variable 83
diff --git a/llvm/test/tools/llvm-ir2vec/triplets.ll b/llvm/test/tools/llvm-ir2vec/triplets.ll
index 6f64bab888f6b..7b476f60a07b3 100644
--- a/llvm/test/tools/llvm-ir2vec/triplets.ll
+++ b/llvm/test/tools/llvm-ir2vec/triplets.ll
@@ -25,41 +25,41 @@ entry:
}
; TRIPLETS: MAX_RELATION=3
-; TRIPLETS-NEXT: 12 80 0
-; TRIPLETS-NEXT: 12 92 2
-; TRIPLETS-NEXT: 12 92 3
+; TRIPLETS-NEXT: 12 74 0
+; TRIPLETS-NEXT: 12 83 2
+; TRIPLETS-NEXT: 12 83 3
; TRIPLETS-NEXT: 12 0 1
-; TRIPLETS-NEXT: 0 75 0
-; TRIPLETS-NEXT: 0 92 2
-; TRIPLETS-NEXT: 16 80 0
-; TRIPLETS-NEXT: 16 92 2
-; TRIPLETS-NEXT: 16 92 3
+; TRIPLETS-NEXT: 0 69 0
+; TRIPLETS-NEXT: 0 83 2
+; TRIPLETS-NEXT: 16 74 0
+; TRIPLETS-NEXT: 16 83 2
+; TRIPLETS-NEXT: 16 83 3
; TRIPLETS-NEXT: 16 0 1
-; TRIPLETS-NEXT: 0 75 0
-; TRIPLETS-NEXT: 0 92 2
-; TRIPLETS-NEXT: 30 82 0
-; TRIPLETS-NEXT: 30 91 2
+; TRIPLETS-NEXT: 0 69 0
+; TRIPLETS-NEXT: 0 83 2
+; TRIPLETS-NEXT: 30 76 0
+; TRIPLETS-NEXT: 30 82 2
; TRIPLETS-NEXT: 30 30 1
-; TRIPLETS-NEXT: 30 82 0
-; TRIPLETS-NEXT: 30 91 2
+; TRIPLETS-NEXT: 30 76 0
+; TRIPLETS-NEXT: 30 82 2
; TRIPLETS-NEXT: 30 32 1
-; TRIPLETS-NEXT: 32 75 0
-; TRIPLETS-NEXT: 32 92 2
-; TRIPLETS-NEXT: 32 90 3
+; TRIPLETS-NEXT: 32 69 0
+; TRIPLETS-NEXT: 32 83 2
+; TRIPLETS-NEXT: 32 81 3
; TRIPLETS-NEXT: 32 32 1
-; TRIPLETS-NEXT: 32 75 0
-; TRIPLETS-NEXT: 32 92 2
-; TRIPLETS-NEXT: 32 90 3
+; TRIPLETS-NEXT: 32 69 0
+; TRIPLETS-NEXT: 32 83 2
+; TRIPLETS-NEXT: 32 81 3
; TRIPLETS-NEXT: 32 31 1
-; TRIPLETS-NEXT: 31 80 0
-; TRIPLETS-NEXT: 31 90 2
+; TRIPLETS-NEXT: 31 74 0
+; TRIPLETS-NEXT: 31 81 2
; TRIPLETS-NEXT: 31 31 1
-; TRIPLETS-NEXT: 31 80 0
-; TRIPLETS-NEXT: 31 90 2
+; TRIPLETS-NEXT: 31 74 0
+; TRIPLETS-NEXT: 31 81 2
; TRIPLETS-NEXT: 31 12 1
-; TRIPLETS-NEXT: 12 80 0
-; TRIPLETS-NEXT: 12 92 2
-; TRIPLETS-NEXT: 12 92 3
+; TRIPLETS-NEXT: 12 74 0
+; TRIPLETS-NEXT: 12 83 2
+; TRIPLETS-NEXT: 12 83 3
; TRIPLETS-NEXT: 12 0 1
-; TRIPLETS-NEXT: 0 75 0
-; TRIPLETS-NEXT: 0 92 2
+; TRIPLETS-NEXT: 0 69 0
+; TRIPLETS-NEXT: 0 83 2
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index c065aaeedd395..461ded77d9609 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -162,8 +162,8 @@ class IR2VecTool {
for (const BasicBlock &BB : F) {
for (const auto &I : BB.instructionsWithoutDebug()) {
- unsigned Opcode = Vocabulary::getNumericID(I.getOpcode());
- unsigned TypeID = Vocabulary::getNumericID(I.getType()->getTypeID());
+ unsigned Opcode = Vocabulary::getSlotIndex(I.getOpcode());
+ unsigned TypeID = Vocabulary::getSlotIndex(I.getType()->getTypeID());
// Add "Next" relationship with previous instruction
if (HasPrevOpcode) {
@@ -184,7 +184,7 @@ class IR2VecTool {
// Add "Arg" relationships
unsigned ArgIndex = 0;
for (const Use &U : I.operands()) {
- unsigned OperandID = Vocabulary::getNumericID(U.get());
+ unsigned OperandID = Vocabulary::getSlotIndex(U.get());
unsigned RelationID = ArgRelation + ArgIndex;
OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
@@ -211,13 +211,7 @@ class IR2VecTool {
/// Dump entity ID to string mappings
static void generateEntityMappings(raw_ostream &OS) {
- // FIXME: Currently, the generated entity mappings are not one-to-one;
- // Multiple TypeIDs map to same string key (Like Half, BFloat, etc. map to
- // FloatTy). This would hinder learning good seed embeddings.
- // We should fix this in the future by ensuring unique string keys either by
- // post-processing here without changing the mapping in ir2vec::Vocabulary,
- // or by changing the Vocabulary generation logic to ensure unique keys.
- auto EntityLen = Vocabulary::expectedSize();
+ auto EntityLen = Vocabulary::getCanonicalSize();
OS << EntityLen << "\n";
for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
OS << Vocabulary::getStringKey(EntityID) << '\t' << EntityID << '\n';
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index f0c81e160ca15..9f5428758d64c 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -336,8 +336,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
EXPECT_EQ(AddEmb.size(), 2u);
EXPECT_EQ(RetEmb.size(), 2u);
- EXPECT_TRUE(AddEmb.approximatelyEquals(Embedding(2, 27.9)));
- EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 17.0)));
+ EXPECT_TRUE(AddEmb.approximatelyEquals(Embedding(2, 25.5)));
+ EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 15.5)));
}
TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
@@ -353,8 +353,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
- EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 27.9)));
- EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 35.6)));
+ EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 25.5)));
+ EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 32.6)));
}
TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
@@ -367,9 +367,9 @@ TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
EXPECT_TRUE(BBMap.count(BB));
EXPECT_EQ(BBMap.at(BB).size(), 2u);
- // BB vector should be sum of add and ret: {27.9, 27.9} + {17.0, 17.0} =
- // {44.9, 44.9}
- EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 44.9)));
+ // BB vector should be sum of add and ret: {25.5, 25.5} + {15.5, 15.5} =
+ // {41.0, 41.0}
+ EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 41.0)));
}
TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
@@ -382,9 +382,9 @@ TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
EXPECT_TRUE(BBMap.count(BB));
EXPECT_EQ(BBMap.at(BB).size(), 2u);
- // BB vector should be sum of add and ret: {27.9, 27.9} + {35.6, 35.6} =
- // {63.5, 63.5}
- EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 63.5)));
+ // BB vector should be sum of add and ret: {25.5, 25.5} + {32.6, 32.6} =
+ // {58.1, 58.1}
+ EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 58.1)));
}
TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
@@ -394,7 +394,7 @@ TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
const auto &BBVec = Emb->getBBVector(*BB);
EXPECT_EQ(BBVec.size(), 2u);
- EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 44.9)));
+ EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 41.0)));
}
TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
@@ -404,7 +404,7 @@ TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
const auto &BBVec = Emb->getBBVector(*BB);
EXPECT_EQ(BBVec.size(), 2u);
- EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 63.5)));
+ EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 58.1)));
}
TEST_F(IR2VecTestFixture, GetFunctionVector_Symbolic) {
@@ -415,8 +415,8 @@ TEST_F(IR2VecTestFixture, GetFunctionVector_Symbolic) {
EXPECT_EQ(FuncVec.size(), 2u);
- // Function vector should match BB vector (only one BB): {44.9, 44.9}
- EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 44.9)));
+ // Function vector should match BB vector (only one BB): {41.0, 41.0}
+ EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 41.0)));
}
TEST_F(IR2VecTestFixture, GetFunctionVector_FlowAware) {
@@ -426,24 +426,40 @@ TEST_F(IR2VecTestFixture, GetFunctionVector_FlowAware) {
const auto &FuncVec = Emb->getFunctionVector();
EXPECT_EQ(FuncVec.size(), 2u);
- // Function vector should match BB vector (only one BB): {63.5, 63.5}
- EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 63.5)));
+ // Function vector should match BB vector (only one BB): {58.1, 58.1}
+ EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 58.1)));
}
static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;
static constexpr unsigned MaxTypeIDs = Vocabulary::MaxTypeIDs;
+static constexpr unsigned MaxCanonicalTypeIDs = Vocabulary::MaxCanonicalTypeIDs;
static constexpr unsigned MaxOperands = Vocabulary::MaxOperandKinds;
+// X-macro rows: (Type::TypeID token, Vocabulary::CanonicalTypeID enumerator,
+// canonical string key). Lists one representative TypeID per canonical type.
+#define IR2VEC_HANDLE_TYPE_BIMAP(X) \
+ X(VoidTyID, VoidTy, "VoidTy") \
+ X(IntegerTyID, IntegerTy, "IntegerTy") \
+ X(FloatTyID, FloatTy, "FloatTy") \
+ X(PointerTyID, PointerTy, "PointerTy") \
+ X(FunctionTyID, FunctionTy, "FunctionTy") \
+ X(StructTyID, StructTy, "StructTy") \
+ X(ArrayTyID, ArrayTy, "ArrayTy") \
+ X(FixedVectorTyID, VectorTy, "VectorTy") \
+ X(LabelTyID, LabelTy, "LabelTy") \
+ X(TokenTyID, TokenTy, "TokenTy") \
+ X(MetadataTyID, MetadataTy, "MetadataTy")
+
TEST(IR2VecVocabularyTest, DummyVocabTest) {
for (unsigned Dim = 1; Dim <= 10; ++Dim) {
auto VocabVec = Vocabulary::createDummyVocabForTest(Dim);
-
+ auto VocabVecSize = VocabVec.size();
// All embeddings should have the same dimension
for (const auto &Emb : VocabVec)
EXPECT_EQ(Emb.size(), Dim);
// Should have the correct total number of embeddings
- EXPECT_EQ(VocabVec.size(), MaxOpcodes + MaxTypeIDs + MaxOperands);
+ EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands);
auto ExpectedVocab = VocabVec;
@@ -454,7 +470,7 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
Vocabulary Result = VocabAnalysis.run(TestMod, MAM);
EXPECT_TRUE(Result.isValid());
EXPECT_EQ(Result.getDimension(), Dim);
- EXPECT_EQ(Result.size(), MaxOpcodes + MaxTypeIDs + MaxOperands);
+ EXPECT_EQ(Result.getCanonicalSize(), VocabVecSize);
unsigned CurPos = 0;
for (const auto &Entry : Result)
@@ -462,64 +478,68 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
}
}
-TEST(IR2VecVocabularyTest, NumericIDMap) {
- // Test getNumericID for opcodes
- EXPECT_EQ(Vocabulary::getNumericID(1u), 0u);
- EXPECT_EQ(Vocabulary::getNumericID(13u), 12u);
- EXPECT_EQ(Vocabulary::getNumericID(MaxOpcodes), MaxOpcodes - 1);
-
- // Test getNumericID for Type IDs
- EXPECT_EQ(Vocabulary::getNumericID(Type::VoidTyID),
- MaxOpcodes + static_cast<unsigned>(Type::VoidTyID));
- EXPECT_EQ(Vocabulary::getNumericID(Type::HalfTyID),
- MaxOpcodes + static_cast<unsigned>(Type::HalfTyID));
- EXPECT_EQ(Vocabulary::getNumericID(Type::FloatTyID),
- MaxOpcodes + static_cast<unsigned>(Type::FloatTyID));
- EXPECT_EQ(Vocabulary::getNumericID(Type::IntegerTyID),
- MaxOpcodes + static_cast<unsigned>(Type::IntegerTyID));
- EXPECT_EQ(Vocabulary::getNumericID(Type::PointerTyID),
- MaxOpcodes + static_cast<unsigned>(Type::PointerTyID));
-
- // Test getNumericID for Value operands
+TEST(IR2VecVocabularyTest, SlotIdxMapping) {
+ // Opcodes come first: opcode NUM is expected at slot NUM - 1.
+#define EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS) \
+ EXPECT_EQ(Vocabulary::getSlotIndex(NUM), static_cast<unsigned>((NUM) - 1));
+#define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS)
+#include "llvm/IR/Instruction.def"
+#undef HANDLE_INST
+#undef EXPECT_OPCODE_SLOT
+
+ // Types follow the opcodes: slot = MaxOpcodes + canonical type ID.
+#define EXPECT_TYPE_SLOT(TypeIDTok, CanonEnum, CanonStr) \
+ EXPECT_EQ(Vocabulary::getSlotIndex(Type::TypeIDTok), \
+ MaxOpcodes + static_cast<unsigned>( \
+ Vocabulary::CanonicalTypeID::CanonEnum));
+
+ IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_TYPE_SLOT)
+
+#undef EXPECT_TYPE_SLOT
+
+ // Operands follow the types: slot = MaxOpcodes + MaxCanonicalTypeIDs + kind.
LLVMContext Ctx;
Module M("TestM", Ctx);
FunctionType *FTy =
FunctionType::get(Type::getVoidTy(Ctx), {Type::getInt32Ty(Ctx)}, false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "testFunc", M);
+#define EXPECTED_VOCAB_OPERAND_SLOT(X) \
+ (MaxOpcodes + MaxCanonicalTypeIDs + static_cast<unsigned>(X))
// Test Function operand
- EXPECT_EQ(Vocabulary::getNumericID(F),
- MaxOpcodes + MaxTypeIDs + 0u); // Function = 0
+ EXPECT_EQ(Vocabulary::getSlotIndex(F),
+ EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::FunctionID));
// Test Constant operand
Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
- EXPECT_EQ(Vocabulary::getNumericID(C),
- MaxOpcodes + MaxTypeIDs + 2u); // Constant = 2
+ EXPECT_EQ(Vocabulary::getSlotIndex(C),
+ EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::ConstantID));
// Test Pointer operand
BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB);
- EXPECT_EQ(Vocabulary::getNumericID(PtrVal),
- MaxOpcodes + MaxTypeIDs + 1u); // Pointer = 1
+ EXPECT_EQ(Vocabulary::getSlotIndex(PtrVal),
+ EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::PointerID));
// Test Variable operand (function argument)
Argument *Arg = F->getArg(0);
- EXPECT_EQ(Vocabulary::getNumericID(Arg),
- MaxOpcodes + MaxTypeIDs + 3u); // Variable = 3
+ EXPECT_EQ(Vocabulary::getSlotIndex(Arg),
+ EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::VariableID));
+#undef EXPECTED_VOCAB_OPERAND_SLOT
}
#if GTEST_HAS_DEATH_TEST
#ifndef NDEBUG
-TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) {
+TEST(IR2VecVocabularyTest, SlotIdxMappingInvalidInputs) {
// Test invalid opcode IDs
- EXPECT_DEATH(Vocabulary::getNumericID(0u), "Invalid opcode");
- EXPECT_DEATH(Vocabulary::getNumericID(MaxOpcodes + 1), "Invalid opcode");
+ EXPECT_DEATH(Vocabulary::getSlotIndex(0u), "Invalid opcode");
+ EXPECT_DEATH(Vocabulary::getSlotIndex(MaxOpcodes + 1), "Invalid opcode");
// Test invalid type IDs
- EXPECT_DEATH(Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs)),
+ EXPECT_DEATH(Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs)),
"Invalid type ID");
EXPECT_DEATH(
- Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
+ Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
"Invalid type ID");
}
#endif // NDEBUG
@@ -529,18 +549,46 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) {
EXPECT_EQ(Vocabulary::getStringKey(0), "Ret");
EXPECT_EQ(Vocabulary::getStringKey(12), "Add");
- StringRef HalfTypeKey = Vocabulary::getStringKey(MaxOpcodes + 0);
- StringRef FloatTypeKey = Vocabulary::getStringKey(MaxOpcodes + 2);
- StringRef VoidTypeKey = Vocabulary::getStringKey(MaxOpcodes + 7);
- StringRef IntTypeKey = Vocabulary::getStringKey(MaxOpcodes + 12);
-
- EXPECT_EQ(HalfTypeKey, "FloatTy");
- EXPECT_EQ(FloatTypeKey, "FloatTy");
- EXPECT_EQ(VoidTypeKey, "VoidTy");
- EXPECT_EQ(IntTypeKey, "IntegerTy");
-
- StringRef FuncArgKey = Vocabulary::getStringKey(MaxOpcodes + MaxTypeIDs + 0);
- StringRef PtrArgKey = Vocabulary::getStringKey(MaxOpcodes + MaxTypeIDs + 1);
+#define EXPECT_OPCODE(NUM, OPCODE, CLASS) \
+ EXPECT_EQ(Vocabulary::getStringKey(Vocabulary::getSlotIndex(NUM)), \
+ Vocabulary::getVocabKeyForOpcode(NUM));
+#define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE(NUM, OPCODE, CLASS)
+#include "llvm/IR/Instruction.def"
+#undef HANDLE_INST
+#undef EXPECT_OPCODE
+
+ // Verify CanonicalTypeID -> string mapping
+#define EXPECT_CANONICAL_TYPE_NAME(TypeIDTok, CanonEnum, CanonStr) \
+ EXPECT_EQ(Vocabulary::getStringKey( \
+ MaxOpcodes + static_cast<unsigned>( \
+ Vocabulary::CanonicalTypeID::CanonEnum)), \
+ CanonStr);
+
+ IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_CANONICAL_TYPE_NAME)
+
+#undef EXPECT_CANONICAL_TYPE_NAME
+
+#define HANDLE_OPERAND_KINDS(X) \
+ X(FunctionID, "Function") \
+ X(PointerID, "Pointer") \
+ X(ConstantID, "Constant") \
+ X(VariableID, "Variable")
+
+#define EXPECT_OPERAND_KIND(EnumName, Str) \
+ EXPECT_EQ(Vocabulary::getStringKey( \
+ MaxOpcodes + MaxCanonicalTypeIDs + \
+ static_cast<unsigned>(Vocabulary::OperandKind::EnumName)), \
+ Str);
+
+ HANDLE_OPERAND_KINDS(EXPECT_OPERAND_KIND)
+
+#undef EXPECT_OPERAND_KIND
+#undef HANDLE_OPERAND_KINDS
+
+ StringRef FuncArgKey =
+ Vocabulary::getStringKey(MaxOpcodes + MaxCanonicalTypeIDs + 0);
+ StringRef PtrArgKey =
+ Vocabulary::getStringKey(MaxOpcodes + MaxCanonicalTypeIDs + 1);
EXPECT_EQ(FuncArgKey, "Function");
EXPECT_EQ(PtrArgKey, "Pointer");
}
@@ -578,39 +626,20 @@ TEST(IR2VecVocabularyTest, InvalidAccess) {
#endif // GTEST_HAS_DEATH_TEST
TEST(IR2VecVocabularyTest, TypeIDStringKeyMapping) {
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::VoidTyID)),
- "VoidTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::IntegerTyID)),
- "IntegerTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::FloatTyID)),
- "FloatTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::PointerTyID)),
- "PointerTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::FunctionTyID)),
- "FunctionTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::StructTyID)),
- "StructTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::ArrayTyID)),
- "ArrayTy");
- EXPECT_EQ(Vocabulary::getStringKey(
- MaxOpcodes + static_cast<unsigned>(Type::FixedVectorTyID)),
- "VectorTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::LabelTyID)),
- "LabelTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::TokenTyID)),
- "TokenTy");
- EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
- static_cast<unsigned>(Type::MetadataTyID)),
- "MetadataTy");
+ // Each bimap TypeID must reach its canonical key through its slot index.
+#define EXPECT_TYPE_TO_CANONICAL(TypeIDTok, CanonEnum, CanonStr) \
+ EXPECT_EQ( \
+ Vocabulary::getStringKey(Vocabulary::getSlotIndex(Type::TypeIDTok)), \
+ CanonStr);
+
+ IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_TYPE_TO_CANONICAL)
+
+#undef EXPECT_TYPE_TO_CANONICAL
+
+ // TypeIDs absent from the bimap must still resolve to their group's
+ // canonical key; e.g. HalfTyID is keyed as "FloatTy".
+ EXPECT_EQ(Vocabulary::getStringKey(Vocabulary::getSlotIndex(Type::HalfTyID)),
+ "FloatTy");
}
TEST(IR2VecVocabularyTest, InvalidVocabularyConstruction) {
More information about the llvm-commits
mailing list