[llvm-branch-commits] [llvm] [IR2Vec] Refactor vocabulary to use canonical type IDs (PR #155323)
    S. VenkataKeerthy via llvm-branch-commits 
    llvm-branch-commits at lists.llvm.org
       
    Thu Aug 28 16:04:32 PDT 2025
    
    
  
https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/155323
From 01b9019f3409ce74e0bfcf24538f3d3136235de2 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Mon, 25 Aug 2025 22:58:43 +0000
Subject: [PATCH] Canonicalized type
---
 llvm/include/llvm/Analysis/IR2Vec.h           | 135 +++++++++--
 llvm/lib/Analysis/IR2Vec.cpp                  | 164 ++++++--------
 .../Inputs/reference_default_vocab_print.txt  |  11 +-
 .../Inputs/reference_wtd1_vocab_print.txt     |  11 +-
 .../Inputs/reference_wtd2_vocab_print.txt     |  11 +-
 llvm/test/tools/llvm-ir2vec/entities.ll       |  41 ++--
 llvm/test/tools/llvm-ir2vec/triplets.ll       |  58 ++---
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp        |  14 +-
 llvm/unittests/Analysis/IR2VecTest.cpp        | 213 ++++++++++--------
 9 files changed, 350 insertions(+), 308 deletions(-)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index caa816e2fd76d..c42ca779e097c 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -36,6 +36,7 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/ErrorOr.h"
 #include "llvm/Support/JSON.h"
+#include <array>
 #include <map>
 
 namespace llvm {
@@ -137,13 +138,51 @@ using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
 using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
 
 /// Class for storing and accessing the IR2Vec vocabulary.
-/// Encapsulates all vocabulary-related constants, logic, and access methods.
+/// The Vocabulary class manages seed embeddings for LLVM IR entities. The
+/// seed embeddings are the initial learned representations of the entities
+/// of LLVM IR. The IR2Vec representation for a given IR is derived from these
+/// seed embeddings.
+///
+/// The vocabulary contains the seed embeddings for three types of entities:
+/// instruction opcodes, types, and operands. Types are grouped/canonicalized
+/// for better learning (e.g., all float variants map to FloatTy). The
+/// vocabulary hides this canonicalization: its public APIs accept every known
+/// LLVM IR opcode, type, and operand kind.
+///
+/// This class helps populate the seed embeddings in an internal vector-based
+/// ADT. It provides logic to map every IR entity to a specific slot index or
+/// position in this vector. This enables O(1) embedding lookup and avoids
+/// unnecessary string-based lookups (and their computational cost) when
+/// generating the embeddings.
 class Vocabulary {
   friend class llvm::IR2VecVocabAnalysis;
   using VocabVector = std::vector<ir2vec::Embedding>;
   VocabVector Vocab;
   bool Valid = false;
 
+public:
+  // Slot layout:
+  // [0 .. MaxOpcodes-1]               => Instruction opcodes
+  // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
+  // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
+
+  /// Canonical type IDs supported by IR2Vec Vocabulary
+  enum class CanonicalTypeID : unsigned {
+    FloatTy,
+    VoidTy,
+    LabelTy,
+    MetadataTy,
+    VectorTy,
+    TokenTy,
+    IntegerTy,
+    FunctionTy,
+    PointerTy,
+    StructTy,
+    ArrayTy,
+    UnknownTy,
+    MaxCanonicalType
+  };
+
   /// Operand kinds supported by IR2Vec Vocabulary
   enum class OperandKind : unsigned {
     FunctionID,
@@ -152,20 +191,15 @@ class Vocabulary {
     VariableID,
     MaxOperandKind
   };
-  /// String mappings for OperandKind values
-  static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
-                                                       "Constant", "Variable"};
-  static_assert(std::size(OperandKindNames) ==
-                    static_cast<unsigned>(OperandKind::MaxOperandKind),
-                "OperandKindNames array size must match MaxOperandKind");
 
-public:
   /// Vocabulary layout constants
 #define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
 #include "llvm/IR/Instruction.def"
 #undef LAST_OTHER_INST
 
   static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;
+  static constexpr unsigned MaxCanonicalTypeIDs =
+      static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
   static constexpr unsigned MaxOperandKinds =
       static_cast<unsigned>(OperandKind::MaxOperandKind);
 
@@ -174,33 +208,31 @@ class Vocabulary {
 
   LLVM_ABI bool isValid() const;
   LLVM_ABI unsigned getDimension() const;
-  LLVM_ABI size_t size() const;
+  /// Total number of entries (opcodes + canonicalized types + operand kinds)
+  static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
 
-  static size_t expectedSize() {
-    return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
-  }
-
-  /// Helper function to get vocabulary key for a given Opcode
+  /// Function to get vocabulary key for a given Opcode
   LLVM_ABI static StringRef getVocabKeyForOpcode(unsigned Opcode);
 
-  /// Helper function to get vocabulary key for a given TypeID
+  /// Function to get vocabulary key for a given TypeID
   LLVM_ABI static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
 
-  /// Helper function to get vocabulary key for a given OperandKind
+  /// Function to get vocabulary key for a given OperandKind
   LLVM_ABI static StringRef getVocabKeyForOperandKind(OperandKind Kind);
 
-  /// Helper function to classify an operand into OperandKind
+  /// Function to classify an operand into OperandKind
   LLVM_ABI static OperandKind getOperandKind(const Value *Op);
 
-  /// Helpers to return the IDs of a given Opcode, TypeID, or OperandKind
-  LLVM_ABI static unsigned getNumericID(unsigned Opcode);
-  LLVM_ABI static unsigned getNumericID(Type::TypeID TypeID);
-  LLVM_ABI static unsigned getNumericID(const Value *Op);
+  /// Functions to return the slot index or position of a given Opcode, TypeID,
+  /// or OperandKind in the vocabulary.
+  LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
+  LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
+  LLVM_ABI static unsigned getSlotIndex(const Value *Op);
 
   /// Accessors to get the embedding for a given entity.
   LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
   LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
-  LLVM_ABI const ir2vec::Embedding &operator[](const Value *Arg) const;
+  LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
 
   /// Const Iterator type aliases
   using const_iterator = VocabVector::const_iterator;
@@ -234,6 +266,61 @@ class Vocabulary {
 
   LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
                            ModuleAnalysisManager::Invalidator &Inv) const;
+
+private:
+  static constexpr unsigned NumCanonicalEntries =
+      MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
+
+  /// String mappings for CanonicalTypeID values
+  static constexpr StringLiteral CanonicalTypeNames[] = {
+      "FloatTy",   "VoidTy",   "LabelTy",   "MetadataTy",
+      "VectorTy",  "TokenTy",  "IntegerTy", "FunctionTy",
+      "PointerTy", "StructTy", "ArrayTy",   "UnknownTy"};
+  static_assert(std::size(CanonicalTypeNames) ==
+                    static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType),
+                "CanonicalTypeNames array size must match MaxCanonicalType");
+
+  /// String mappings for OperandKind values
+  static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
+                                                       "Constant", "Variable"};
+  static_assert(std::size(OperandKindNames) ==
+                    static_cast<unsigned>(OperandKind::MaxOperandKind),
+                "OperandKindNames array size must match MaxOperandKind");
+
+  /// Maps each Type::TypeID (llvm/IR/Type.h), in enum order, to its canonical
+  /// type. Left unsized so the static_assert below catches a count mismatch.
+  static constexpr CanonicalTypeID TypeIDMapping[] = {
+      CanonicalTypeID::FloatTy,    // HalfTyID = 0
+      CanonicalTypeID::FloatTy,    // BFloatTyID
+      CanonicalTypeID::FloatTy,    // FloatTyID
+      CanonicalTypeID::FloatTy,    // DoubleTyID
+      CanonicalTypeID::FloatTy,    // X86_FP80TyID
+      CanonicalTypeID::FloatTy,    // FP128TyID
+      CanonicalTypeID::FloatTy,    // PPC_FP128TyID
+      CanonicalTypeID::VoidTy,     // VoidTyID
+      CanonicalTypeID::LabelTy,    // LabelTyID
+      CanonicalTypeID::MetadataTy, // MetadataTyID
+      CanonicalTypeID::VectorTy,   // X86_AMXTyID
+      CanonicalTypeID::TokenTy,    // TokenTyID
+      CanonicalTypeID::IntegerTy,  // IntegerTyID
+      CanonicalTypeID::FunctionTy, // FunctionTyID
+      CanonicalTypeID::PointerTy,  // PointerTyID
+      CanonicalTypeID::StructTy,   // StructTyID
+      CanonicalTypeID::ArrayTy,    // ArrayTyID
+      CanonicalTypeID::VectorTy,   // FixedVectorTyID
+      CanonicalTypeID::VectorTy,   // ScalableVectorTyID
+      CanonicalTypeID::PointerTy,  // TypedPointerTyID
+      CanonicalTypeID::UnknownTy   // TargetExtTyID
+  };
+  static_assert(std::size(TypeIDMapping) == MaxTypeIDs,
+                "TypeIDMapping must cover all Type::TypeID values");
+
+  /// Function to get vocabulary key for canonical type by enum
+  LLVM_ABI static StringRef
+  getVocabKeyForCanonicalTypeID(CanonicalTypeID CType);
+
+  /// Function to convert TypeID to CanonicalTypeID
+  LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
 };
 
 /// Embedder provides the interface to generate embeddings (vector
@@ -262,11 +349,11 @@ class Embedder {
 
   LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
 
-  /// Helper function to compute embeddings. It generates embeddings for all
+  /// Function to compute embeddings. It generates embeddings for all
   /// the instructions and basic blocks in the function F.
   void computeEmbeddings() const;
 
-  /// Helper function to compute the embedding for a given basic block.
+  /// Function to compute the embedding for a given basic block.
   /// Specific to the kind of embeddings being computed.
   virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
 
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index eb54f90a75488..886d22e424ea2 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -32,7 +32,7 @@ using namespace ir2vec;
 #define DEBUG_TYPE "ir2vec"
 
 STATISTIC(VocabMissCounter,
-          "Number of lookups to entites not present in the vocabulary");
+          "Number of lookups to entities not present in the vocabulary");
 
 namespace llvm {
 namespace ir2vec {
@@ -213,7 +213,7 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
   for (const auto &I : BB.instructionsWithoutDebug()) {
     Embedding ArgEmb(Dimension, 0);
     for (const auto &Op : I.operands())
-      ArgEmb += Vocab[Op];
+      ArgEmb += Vocab[*Op];
     auto InstVector =
         Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
     InstVecMap[&I] = InstVector;
@@ -242,8 +242,8 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
       // If the operand is not defined by an instruction, we use the vocabulary
       else {
         LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
-                          << *Op << "=" << Vocab[Op][0] << "\n");
-        ArgEmb += Vocab[Op];
+                          << *Op << "=" << Vocab[*Op][0] << "\n");
+        ArgEmb += Vocab[*Op];
       }
     }
     // Create the instruction vector by combining opcode, type, and arguments
@@ -264,12 +264,7 @@ Vocabulary::Vocabulary(VocabVector &&Vocab)
     : Vocab(std::move(Vocab)), Valid(true) {}
 
 bool Vocabulary::isValid() const {
-  return Vocab.size() == Vocabulary::expectedSize() && Valid;
-}
-
-size_t Vocabulary::size() const {
-  assert(Valid && "IR2Vec Vocabulary is invalid");
-  return Vocab.size();
+  return Vocab.size() == NumCanonicalEntries && Valid;
 }
 
 unsigned Vocabulary::getDimension() const {
@@ -277,19 +272,32 @@ unsigned Vocabulary::getDimension() const {
   return Vocab[0].size();
 }
 
-const Embedding &Vocabulary::operator[](unsigned Opcode) const {
+unsigned Vocabulary::getSlotIndex(unsigned Opcode) {
   assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
-  return Vocab[Opcode - 1];
+  return Opcode - 1; // Convert to zero-based index
+}
+
+unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
+  assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
+  return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));
+}
+
+unsigned Vocabulary::getSlotIndex(const Value *Op) {
+  unsigned Index = static_cast<unsigned>(getOperandKind(Op));
+  assert(Index < MaxOperandKinds && "Invalid OperandKind");
+  return MaxOpcodes + MaxCanonicalTypeIDs + Index;
+}
+
+const Embedding &Vocabulary::operator[](unsigned Opcode) const {
+  return Vocab[getSlotIndex(Opcode)];
 }
 
-const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const {
-  assert(static_cast<unsigned>(TypeId) < MaxTypeIDs && "Invalid type ID");
-  return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)];
+const Embedding &Vocabulary::operator[](Type::TypeID TypeID) const {
+  return Vocab[getSlotIndex(TypeID)];
 }
 
-const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
-  OperandKind ArgKind = getOperandKind(Arg);
-  return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
+const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
+  return Vocab[getSlotIndex(&Arg)];
 }
 
 StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
@@ -303,43 +311,21 @@ StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
   return "UnknownOpcode";
 }
 
+StringRef Vocabulary::getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) {
+  unsigned Index = static_cast<unsigned>(CType);
+  assert(Index < MaxCanonicalTypeIDs && "Invalid CanonicalTypeID");
+  return CanonicalTypeNames[Index];
+}
+
+Vocabulary::CanonicalTypeID
+Vocabulary::getCanonicalTypeID(Type::TypeID TypeID) {
+  unsigned Index = static_cast<unsigned>(TypeID);
+  assert(Index < MaxTypeIDs && "Invalid TypeID");
+  return TypeIDMapping[Index];
+}
+
 StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
-  switch (TypeID) {
-  case Type::VoidTyID:
-    return "VoidTy";
-  case Type::HalfTyID:
-  case Type::BFloatTyID:
-  case Type::FloatTyID:
-  case Type::DoubleTyID:
-  case Type::X86_FP80TyID:
-  case Type::FP128TyID:
-  case Type::PPC_FP128TyID:
-    return "FloatTy";
-  case Type::IntegerTyID:
-    return "IntegerTy";
-  case Type::FunctionTyID:
-    return "FunctionTy";
-  case Type::StructTyID:
-    return "StructTy";
-  case Type::ArrayTyID:
-    return "ArrayTy";
-  case Type::PointerTyID:
-  case Type::TypedPointerTyID:
-    return "PointerTy";
-  case Type::FixedVectorTyID:
-  case Type::ScalableVectorTyID:
-    return "VectorTy";
-  case Type::LabelTyID:
-    return "LabelTy";
-  case Type::TokenTyID:
-    return "TokenTy";
-  case Type::MetadataTyID:
-    return "MetadataTy";
-  case Type::X86_AMXTyID:
-  case Type::TargetExtTyID:
-    return "UnknownTy";
-  }
-  return "UnknownTy";
+  return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID));
 }
 
 StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
@@ -348,20 +334,6 @@ StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
   return OperandKindNames[Index];
 }
 
-Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
-  VocabVector DummyVocab;
-  float DummyVal = 0.1f;
-  // Create a dummy vocabulary with entries for all opcodes, types, and
-  // operand
-  for ([[maybe_unused]] unsigned _ :
-       seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxTypeIDs +
-                   Vocabulary::MaxOperandKinds)) {
-    DummyVocab.push_back(Embedding(Dim, DummyVal));
-    DummyVal += 0.1f;
-  }
-  return DummyVocab;
-}
-
 // Helper function to classify an operand into OperandKind
 Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
   if (isa<Function>(Op))
@@ -373,34 +345,18 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
   return OperandKind::VariableID;
 }
 
-unsigned Vocabulary::getNumericID(unsigned Opcode) {
-  assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
-  return Opcode - 1; // Convert to zero-based index
-}
-
-unsigned Vocabulary::getNumericID(Type::TypeID TypeID) {
-  assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
-  return MaxOpcodes + static_cast<unsigned>(TypeID);
-}
-
-unsigned Vocabulary::getNumericID(const Value *Op) {
-  unsigned Index = static_cast<unsigned>(getOperandKind(Op));
-  assert(Index < MaxOperandKinds && "Invalid OperandKind");
-  return MaxOpcodes + MaxTypeIDs + Index;
-}
-
 StringRef Vocabulary::getStringKey(unsigned Pos) {
-  assert(Pos < Vocabulary::expectedSize() &&
-         "Position out of bounds in vocabulary");
+  assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
   // Opcode
   if (Pos < MaxOpcodes)
     return getVocabKeyForOpcode(Pos + 1);
   // Type
-  if (Pos < MaxOpcodes + MaxTypeIDs)
-    return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes));
+  if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
+    return getVocabKeyForCanonicalTypeID(
+        static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
   // Operand
   return getVocabKeyForOperandKind(
-      static_cast<OperandKind>(Pos - MaxOpcodes - MaxTypeIDs));
+      static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
 }
 
 // For now, assume vocabulary is stable unless explicitly invalidated.
@@ -410,6 +366,21 @@ bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA,
   return !(PAC.preservedWhenStateless());
 }
 
+Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
+  VocabVector DummyVocab;
+  DummyVocab.reserve(NumCanonicalEntries);
+  float DummyVal = 0.1f;
+  // Create a dummy vocabulary with one entry per slot (opcodes, canonical
+  // types, then operand kinds), so its size always matches
+  // NumCanonicalEntries. Every dimension of the first entry is 0.1 and each
+  // subsequent entry is 0.1 larger than the previous one.
+  for ([[maybe_unused]] unsigned _ : seq(0u, NumCanonicalEntries)) {
+    DummyVocab.push_back(Embedding(Dim, DummyVal));
+    DummyVal += 0.1f;
+  }
+  return DummyVocab;
+}
+
 // ==----------------------------------------------------------------------===//
 // IR2VecVocabAnalysis
 //===----------------------------------------------------------------------===//
@@ -502,6 +473,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
   // Handle Opcodes
   std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
                                                  Embedding(Dim, 0));
+  // Index i corresponds to opcode i + 1, matching Vocabulary::getSlotIndex.
   for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
     StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
     auto It = OpcVocab.find(VocabKey.str());
@@ -513,14 +485,15 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
   Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
                NumericOpcodeEmbeddings.end());
 
-  // Handle Types
-  std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,
+  // Handle Types - only canonical types are present in vocabulary
+  std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
                                                Embedding(Dim, 0));
-  for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {
-    StringRef VocabKey =
-        Vocabulary::getVocabKeyForTypeID(static_cast<Type::TypeID>(TypeID));
+  // Index i holds the embedding of CanonicalTypeID i.
+  for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
+    StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
+        static_cast<Vocabulary::CanonicalTypeID>(CTypeID));
     if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {
-      NumericTypeEmbeddings[TypeID] = It->second;
+      NumericTypeEmbeddings[CTypeID] = It->second;
       continue;
     }
     handleMissingEntity(VocabKey.str());
@@ -531,6 +504,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
   // Handle Arguments/Operands
   std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
                                               Embedding(Dim, 0));
+  // Index i holds the embedding of OperandKind i.
   for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
     Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
     StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
index 1b9b3c2acd8a5..df7769c9c6a65 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
@@ -67,25 +67,16 @@ Key: InsertValue:  [ 129.00  130.00 ]
 Key: LandingPad:  [ 131.00  132.00 ]
 Key: Freeze:  [ 133.00  134.00 ]
 Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
 Key: VoidTy:  [ 1.50  2.00 ]
 Key: LabelTy:  [ 2.50  3.00 ]
 Key: MetadataTy:  [ 3.50  4.00 ]
-Key: UnknownTy:  [ 4.50  5.00 ]
+Key: VectorTy:  [ 11.50  12.00 ]
 Key: TokenTy:  [ 5.50  6.00 ]
 Key: IntegerTy:  [ 6.50  7.00 ]
 Key: FunctionTy:  [ 7.50  8.00 ]
 Key: PointerTy:  [ 8.50  9.00 ]
 Key: StructTy:  [ 9.50  10.00 ]
 Key: ArrayTy:  [ 10.50  11.00 ]
-Key: VectorTy:  [ 11.50  12.00 ]
-Key: VectorTy:  [ 11.50  12.00 ]
-Key: PointerTy:  [ 8.50  9.00 ]
 Key: UnknownTy:  [ 4.50  5.00 ]
 Key: Function:  [ 0.20  0.40 ]
 Key: Pointer:  [ 0.60  0.80 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
index 9673e7f23fa5c..f3ce809fd2fd2 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
@@ -67,25 +67,16 @@ Key: InsertValue:  [ 64.50  65.00 ]
 Key: LandingPad:  [ 65.50  66.00 ]
 Key: Freeze:  [ 66.50  67.00 ]
 Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
 Key: VoidTy:  [ 1.50  2.00 ]
 Key: LabelTy:  [ 2.50  3.00 ]
 Key: MetadataTy:  [ 3.50  4.00 ]
-Key: UnknownTy:  [ 4.50  5.00 ]
+Key: VectorTy:  [ 11.50  12.00 ]
 Key: TokenTy:  [ 5.50  6.00 ]
 Key: IntegerTy:  [ 6.50  7.00 ]
 Key: FunctionTy:  [ 7.50  8.00 ]
 Key: PointerTy:  [ 8.50  9.00 ]
 Key: StructTy:  [ 9.50  10.00 ]
 Key: ArrayTy:  [ 10.50  11.00 ]
-Key: VectorTy:  [ 11.50  12.00 ]
-Key: VectorTy:  [ 11.50  12.00 ]
-Key: PointerTy:  [ 8.50  9.00 ]
 Key: UnknownTy:  [ 4.50  5.00 ]
 Key: Function:  [ 0.50  1.00 ]
 Key: Pointer:  [ 1.50  2.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
index 1f575d29092dd..72b25b9bd3d9c 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
@@ -67,25 +67,16 @@ Key: InsertValue:  [ 12.90  13.00 ]
 Key: LandingPad:  [ 13.10  13.20 ]
 Key: Freeze:  [ 13.30  13.40 ]
 Key: FloatTy:  [ 0.00  0.00 ]
-Key: FloatTy:  [ 0.00  0.00 ]
-Key: FloatTy:  [ 0.00  0.00 ]
-Key: FloatTy:  [ 0.00  0.00 ]
-Key: FloatTy:  [ 0.00  0.00 ]
-Key: FloatTy:  [ 0.00  0.00 ]
-Key: FloatTy:  [ 0.00  0.00 ]
 Key: VoidTy:  [ 0.00  0.00 ]
 Key: LabelTy:  [ 0.00  0.00 ]
 Key: MetadataTy:  [ 0.00  0.00 ]
-Key: UnknownTy:  [ 0.00  0.00 ]
+Key: VectorTy:  [ 0.00  0.00 ]
 Key: TokenTy:  [ 0.00  0.00 ]
 Key: IntegerTy:  [ 0.00  0.00 ]
 Key: FunctionTy:  [ 0.00  0.00 ]
 Key: PointerTy:  [ 0.00  0.00 ]
 Key: StructTy:  [ 0.00  0.00 ]
 Key: ArrayTy:  [ 0.00  0.00 ]
-Key: VectorTy:  [ 0.00  0.00 ]
-Key: VectorTy:  [ 0.00  0.00 ]
-Key: PointerTy:  [ 0.00  0.00 ]
 Key: UnknownTy:  [ 0.00  0.00 ]
 Key: Function:  [ 0.00  0.00 ]
 Key: Pointer:  [ 0.00  0.00 ]
diff --git a/llvm/test/tools/llvm-ir2vec/entities.ll b/llvm/test/tools/llvm-ir2vec/entities.ll
index 4ed6400d7a195..4b51adf30bf74 100644
--- a/llvm/test/tools/llvm-ir2vec/entities.ll
+++ b/llvm/test/tools/llvm-ir2vec/entities.ll
@@ -1,6 +1,6 @@
 ; RUN: llvm-ir2vec entities | FileCheck %s
 
-CHECK: 93
+CHECK: 84
 CHECK-NEXT: Ret     0
 CHECK-NEXT: Br      1
 CHECK-NEXT: Switch  2
@@ -70,27 +70,18 @@ CHECK-NEXT: InsertValue     65
 CHECK-NEXT: LandingPad      66
 CHECK-NEXT: Freeze  67
 CHECK-NEXT: FloatTy 68
-CHECK-NEXT: FloatTy 69
-CHECK-NEXT: FloatTy 70
-CHECK-NEXT: FloatTy 71
-CHECK-NEXT: FloatTy 72
-CHECK-NEXT: FloatTy 73
-CHECK-NEXT: FloatTy 74
-CHECK-NEXT: VoidTy  75
-CHECK-NEXT: LabelTy 76
-CHECK-NEXT: MetadataTy      77
-CHECK-NEXT: UnknownTy       78
-CHECK-NEXT: TokenTy 79
-CHECK-NEXT: IntegerTy       80
-CHECK-NEXT: FunctionTy      81
-CHECK-NEXT: PointerTy       82
-CHECK-NEXT: StructTy        83
-CHECK-NEXT: ArrayTy 84
-CHECK-NEXT: VectorTy        85
-CHECK-NEXT: VectorTy        86
-CHECK-NEXT: PointerTy       87
-CHECK-NEXT: UnknownTy       88
-CHECK-NEXT: Function        89
-CHECK-NEXT: Pointer 90
-CHECK-NEXT: Constant        91
-CHECK-NEXT: Variable        92
+CHECK-NEXT: VoidTy  69
+CHECK-NEXT: LabelTy 70
+CHECK-NEXT: MetadataTy      71
+CHECK-NEXT: VectorTy        72
+CHECK-NEXT: TokenTy 73
+CHECK-NEXT: IntegerTy       74
+CHECK-NEXT: FunctionTy      75
+CHECK-NEXT: PointerTy       76
+CHECK-NEXT: StructTy        77
+CHECK-NEXT: ArrayTy 78
+CHECK-NEXT: UnknownTy       79
+CHECK-NEXT: Function        80
+CHECK-NEXT: Pointer 81
+CHECK-NEXT: Constant        82
+CHECK-NEXT: Variable        83
diff --git a/llvm/test/tools/llvm-ir2vec/triplets.ll b/llvm/test/tools/llvm-ir2vec/triplets.ll
index 6f64bab888f6b..7b476f60a07b3 100644
--- a/llvm/test/tools/llvm-ir2vec/triplets.ll
+++ b/llvm/test/tools/llvm-ir2vec/triplets.ll
@@ -25,41 +25,41 @@ entry:
 }
 
 ; TRIPLETS: MAX_RELATION=3
-; TRIPLETS-NEXT: 12      80      0
-; TRIPLETS-NEXT: 12      92      2
-; TRIPLETS-NEXT: 12      92      3
+; TRIPLETS-NEXT: 12      74      0
+; TRIPLETS-NEXT: 12      83      2
+; TRIPLETS-NEXT: 12      83      3
 ; TRIPLETS-NEXT: 12      0       1
-; TRIPLETS-NEXT: 0       75      0
-; TRIPLETS-NEXT: 0       92      2
-; TRIPLETS-NEXT: 16      80      0
-; TRIPLETS-NEXT: 16      92      2
-; TRIPLETS-NEXT: 16      92      3
+; TRIPLETS-NEXT: 0       69      0
+; TRIPLETS-NEXT: 0       83      2
+; TRIPLETS-NEXT: 16      74      0
+; TRIPLETS-NEXT: 16      83      2
+; TRIPLETS-NEXT: 16      83      3
 ; TRIPLETS-NEXT: 16      0       1
-; TRIPLETS-NEXT: 0       75      0
-; TRIPLETS-NEXT: 0       92      2
-; TRIPLETS-NEXT: 30      82      0
-; TRIPLETS-NEXT: 30      91      2
+; TRIPLETS-NEXT: 0       69      0
+; TRIPLETS-NEXT: 0       83      2
+; TRIPLETS-NEXT: 30      76      0
+; TRIPLETS-NEXT: 30      82      2
 ; TRIPLETS-NEXT: 30      30      1
-; TRIPLETS-NEXT: 30      82      0
-; TRIPLETS-NEXT: 30      91      2
+; TRIPLETS-NEXT: 30      76      0
+; TRIPLETS-NEXT: 30      82      2
 ; TRIPLETS-NEXT: 30      32      1
-; TRIPLETS-NEXT: 32      75      0
-; TRIPLETS-NEXT: 32      92      2
-; TRIPLETS-NEXT: 32      90      3
+; TRIPLETS-NEXT: 32      69      0
+; TRIPLETS-NEXT: 32      83      2
+; TRIPLETS-NEXT: 32      81      3
 ; TRIPLETS-NEXT: 32      32      1
-; TRIPLETS-NEXT: 32      75      0
-; TRIPLETS-NEXT: 32      92      2
-; TRIPLETS-NEXT: 32      90      3
+; TRIPLETS-NEXT: 32      69      0
+; TRIPLETS-NEXT: 32      83      2
+; TRIPLETS-NEXT: 32      81      3
 ; TRIPLETS-NEXT: 32      31      1
-; TRIPLETS-NEXT: 31      80      0
-; TRIPLETS-NEXT: 31      90      2
+; TRIPLETS-NEXT: 31      74      0
+; TRIPLETS-NEXT: 31      81      2
 ; TRIPLETS-NEXT: 31      31      1
-; TRIPLETS-NEXT: 31      80      0
-; TRIPLETS-NEXT: 31      90      2
+; TRIPLETS-NEXT: 31      74      0
+; TRIPLETS-NEXT: 31      81      2
 ; TRIPLETS-NEXT: 31      12      1
-; TRIPLETS-NEXT: 12      80      0
-; TRIPLETS-NEXT: 12      92      2
-; TRIPLETS-NEXT: 12      92      3
+; TRIPLETS-NEXT: 12      74      0
+; TRIPLETS-NEXT: 12      83      2
+; TRIPLETS-NEXT: 12      83      3
 ; TRIPLETS-NEXT: 12      0       1
-; TRIPLETS-NEXT: 0       75      0
-; TRIPLETS-NEXT: 0       92      2
+; TRIPLETS-NEXT: 0       69      0
+; TRIPLETS-NEXT: 0       83      2
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index c065aaeedd395..461ded77d9609 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -162,8 +162,8 @@ class IR2VecTool {
 
     for (const BasicBlock &BB : F) {
       for (const auto &I : BB.instructionsWithoutDebug()) {
-        unsigned Opcode = Vocabulary::getNumericID(I.getOpcode());
-        unsigned TypeID = Vocabulary::getNumericID(I.getType()->getTypeID());
+        unsigned Opcode = Vocabulary::getSlotIndex(I.getOpcode());
+        unsigned TypeID = Vocabulary::getSlotIndex(I.getType()->getTypeID());
 
         // Add "Next" relationship with previous instruction
         if (HasPrevOpcode) {
@@ -184,7 +184,7 @@ class IR2VecTool {
         // Add "Arg" relationships
         unsigned ArgIndex = 0;
         for (const Use &U : I.operands()) {
-          unsigned OperandID = Vocabulary::getNumericID(U.get());
+          unsigned OperandID = Vocabulary::getSlotIndex(U.get());
           unsigned RelationID = ArgRelation + ArgIndex;
           OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
 
@@ -211,13 +211,7 @@ class IR2VecTool {
 
   /// Dump entity ID to string mappings
   static void generateEntityMappings(raw_ostream &OS) {
-    // FIXME: Currently, the generated entity mappings are not one-to-one;
-    // Multiple TypeIDs map to same string key (Like Half, BFloat, etc. map to
-    // FloatTy). This would hinder learning good seed embeddings.
-    // We should fix this in the future by ensuring unique string keys either by
-    // post-processing here without changing the mapping in ir2vec::Vocabulary,
-    // or by changing the Vocabulary generation logic to ensure unique keys.
-    auto EntityLen = Vocabulary::expectedSize();
+    auto EntityLen = Vocabulary::getCanonicalSize();
     OS << EntityLen << "\n";
     for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
       OS << Vocabulary::getStringKey(EntityID) << '\t' << EntityID << '\n';
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index f0c81e160ca15..9f5428758d64c 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -336,8 +336,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
   EXPECT_EQ(AddEmb.size(), 2u);
   EXPECT_EQ(RetEmb.size(), 2u);
 
-  EXPECT_TRUE(AddEmb.approximatelyEquals(Embedding(2, 27.9)));
-  EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 17.0)));
+  EXPECT_TRUE(AddEmb.approximatelyEquals(Embedding(2, 25.5)));
+  EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 15.5)));
 }
 
 TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
@@ -353,8 +353,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
   EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
   EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
 
-  EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 27.9)));
-  EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 35.6)));
+  EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 25.5)));
+  EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 32.6)));
 }
 
 TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
@@ -367,9 +367,9 @@ TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
   EXPECT_TRUE(BBMap.count(BB));
   EXPECT_EQ(BBMap.at(BB).size(), 2u);
 
-  // BB vector should be sum of add and ret: {27.9, 27.9} + {17.0, 17.0} =
-  // {44.9, 44.9}
-  EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 44.9)));
+  // BB vector should be sum of add and ret: {25.5, 25.5} + {15.5, 15.5} =
+  // {41.0, 41.0}
+  EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 41.0)));
 }
 
 TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
@@ -382,9 +382,9 @@ TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
   EXPECT_TRUE(BBMap.count(BB));
   EXPECT_EQ(BBMap.at(BB).size(), 2u);
 
-  // BB vector should be sum of add and ret: {27.9, 27.9} + {35.6, 35.6} =
-  // {63.5, 63.5}
-  EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 63.5)));
+  // BB vector should be sum of add and ret: {25.5, 25.5} + {32.6, 32.6} =
+  // {58.1, 58.1}
+  EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 58.1)));
 }
 
 TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
@@ -394,7 +394,7 @@ TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
   const auto &BBVec = Emb->getBBVector(*BB);
 
   EXPECT_EQ(BBVec.size(), 2u);
-  EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 44.9)));
+  EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 41.0)));
 }
 
 TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
@@ -404,7 +404,7 @@ TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
   const auto &BBVec = Emb->getBBVector(*BB);
 
   EXPECT_EQ(BBVec.size(), 2u);
-  EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 63.5)));
+  EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 58.1)));
 }
 
 TEST_F(IR2VecTestFixture, GetFunctionVector_Symbolic) {
@@ -415,8 +415,8 @@ TEST_F(IR2VecTestFixture, GetFunctionVector_Symbolic) {
 
   EXPECT_EQ(FuncVec.size(), 2u);
 
-  // Function vector should match BB vector (only one BB): {44.9, 44.9}
-  EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 44.9)));
+  // Function vector should match BB vector (only one BB): {41.0, 41.0}
+  EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 41.0)));
 }
 
 TEST_F(IR2VecTestFixture, GetFunctionVector_FlowAware) {
@@ -426,24 +426,40 @@ TEST_F(IR2VecTestFixture, GetFunctionVector_FlowAware) {
   const auto &FuncVec = Emb->getFunctionVector();
 
   EXPECT_EQ(FuncVec.size(), 2u);
-  // Function vector should match BB vector (only one BB): {63.5, 63.5}
-  EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 63.5)));
+  // Function vector should match BB vector (only one BB): {58.1, 58.1}
+  EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 58.1)));
 }
 
 static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;
 static constexpr unsigned MaxTypeIDs = Vocabulary::MaxTypeIDs;
+static constexpr unsigned MaxCanonicalTypeIDs = Vocabulary::MaxCanonicalTypeIDs;
 static constexpr unsigned MaxOperands = Vocabulary::MaxOperandKinds;
 
+// X(TypeIDToken, CanonicalTypeID, StringKey) rows: a Type::TypeID, the
+// Vocabulary::CanonicalTypeID it maps to, and that type's vocab string key.
+#define IR2VEC_HANDLE_TYPE_BIMAP(X)                                            \
+  X(VoidTyID, VoidTy, "VoidTy")                                                \
+  X(IntegerTyID, IntegerTy, "IntegerTy")                                       \
+  X(FloatTyID, FloatTy, "FloatTy")                                             \
+  X(PointerTyID, PointerTy, "PointerTy")                                       \
+  X(FunctionTyID, FunctionTy, "FunctionTy")                                    \
+  X(StructTyID, StructTy, "StructTy")                                          \
+  X(ArrayTyID, ArrayTy, "ArrayTy")                                             \
+  X(FixedVectorTyID, VectorTy, "VectorTy")                                     \
+  X(LabelTyID, LabelTy, "LabelTy")                                             \
+  X(TokenTyID, TokenTy, "TokenTy")                                             \
+  X(MetadataTyID, MetadataTy, "MetadataTy")
+
 TEST(IR2VecVocabularyTest, DummyVocabTest) {
   for (unsigned Dim = 1; Dim <= 10; ++Dim) {
     auto VocabVec = Vocabulary::createDummyVocabForTest(Dim);
-
+    auto VocabVecSize = VocabVec.size();
     // All embeddings should have the same dimension
     for (const auto &Emb : VocabVec)
       EXPECT_EQ(Emb.size(), Dim);
 
     // Should have the correct total number of embeddings
-    EXPECT_EQ(VocabVec.size(), MaxOpcodes + MaxTypeIDs + MaxOperands);
+    EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands);
 
     auto ExpectedVocab = VocabVec;
 
@@ -454,7 +470,7 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
     Vocabulary Result = VocabAnalysis.run(TestMod, MAM);
     EXPECT_TRUE(Result.isValid());
     EXPECT_EQ(Result.getDimension(), Dim);
-    EXPECT_EQ(Result.size(), MaxOpcodes + MaxTypeIDs + MaxOperands);
+    EXPECT_EQ(Result.getCanonicalSize(), VocabVecSize);
 
     unsigned CurPos = 0;
     for (const auto &Entry : Result)
@@ -462,64 +478,68 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
   }
 }
 
-TEST(IR2VecVocabularyTest, NumericIDMap) {
-  // Test getNumericID for opcodes
-  EXPECT_EQ(Vocabulary::getNumericID(1u), 0u);
-  EXPECT_EQ(Vocabulary::getNumericID(13u), 12u);
-  EXPECT_EQ(Vocabulary::getNumericID(MaxOpcodes), MaxOpcodes - 1);
-
-  // Test getNumericID for Type IDs
-  EXPECT_EQ(Vocabulary::getNumericID(Type::VoidTyID),
-            MaxOpcodes + static_cast<unsigned>(Type::VoidTyID));
-  EXPECT_EQ(Vocabulary::getNumericID(Type::HalfTyID),
-            MaxOpcodes + static_cast<unsigned>(Type::HalfTyID));
-  EXPECT_EQ(Vocabulary::getNumericID(Type::FloatTyID),
-            MaxOpcodes + static_cast<unsigned>(Type::FloatTyID));
-  EXPECT_EQ(Vocabulary::getNumericID(Type::IntegerTyID),
-            MaxOpcodes + static_cast<unsigned>(Type::IntegerTyID));
-  EXPECT_EQ(Vocabulary::getNumericID(Type::PointerTyID),
-            MaxOpcodes + static_cast<unsigned>(Type::PointerTyID));
-
-  // Test getNumericID for Value operands
+TEST(IR2VecVocabularyTest, SlotIdxMapping) {
+  // Test getSlotIndex for Opcodes
+#define EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS)                                 \
+  EXPECT_EQ(Vocabulary::getSlotIndex(NUM), static_cast<unsigned>((NUM) - 1));
+#define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS)
+#include "llvm/IR/Instruction.def"
+#undef HANDLE_INST
+#undef EXPECT_OPCODE_SLOT
+
+  // Test getSlotIndex for Types
+#define EXPECT_TYPE_SLOT(TypeIDTok, CanonEnum, CanonStr)                       \
+  EXPECT_EQ(Vocabulary::getSlotIndex(Type::TypeIDTok),                         \
+            MaxOpcodes + static_cast<unsigned>(                                \
+                             Vocabulary::CanonicalTypeID::CanonEnum));
+
+  IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_TYPE_SLOT)
+
+#undef EXPECT_TYPE_SLOT
+
+  // Test getSlotIndex for Value operands
   LLVMContext Ctx;
   Module M("TestM", Ctx);
   FunctionType *FTy =
       FunctionType::get(Type::getVoidTy(Ctx), {Type::getInt32Ty(Ctx)}, false);
   Function *F = Function::Create(FTy, Function::ExternalLinkage, "testFunc", M);
 
+#define EXPECTED_VOCAB_OPERAND_SLOT(X)                                         \
+  (MaxOpcodes + MaxCanonicalTypeIDs + static_cast<unsigned>(X))
   // Test Function operand
-  EXPECT_EQ(Vocabulary::getNumericID(F),
-            MaxOpcodes + MaxTypeIDs + 0u); // Function = 0
+  EXPECT_EQ(Vocabulary::getSlotIndex(F),
+            EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::FunctionID));
 
   // Test Constant operand
   Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
-  EXPECT_EQ(Vocabulary::getNumericID(C),
-            MaxOpcodes + MaxTypeIDs + 2u); // Constant = 2
+  EXPECT_EQ(Vocabulary::getSlotIndex(C),
+            EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::ConstantID));
 
   // Test Pointer operand
   BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
   AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB);
-  EXPECT_EQ(Vocabulary::getNumericID(PtrVal),
-            MaxOpcodes + MaxTypeIDs + 1u); // Pointer = 1
+  EXPECT_EQ(Vocabulary::getSlotIndex(PtrVal),
+            EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::PointerID));
 
   // Test Variable operand (function argument)
   Argument *Arg = F->getArg(0);
-  EXPECT_EQ(Vocabulary::getNumericID(Arg),
-            MaxOpcodes + MaxTypeIDs + 3u); // Variable = 3
+  EXPECT_EQ(Vocabulary::getSlotIndex(Arg),
+            EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::VariableID));
+#undef EXPECTED_VOCAB_OPERAND_SLOT
 }
 
 #if GTEST_HAS_DEATH_TEST
 #ifndef NDEBUG
 TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) {
   // Test invalid opcode IDs
-  EXPECT_DEATH(Vocabulary::getNumericID(0u), "Invalid opcode");
-  EXPECT_DEATH(Vocabulary::getNumericID(MaxOpcodes + 1), "Invalid opcode");
+  EXPECT_DEATH(Vocabulary::getSlotIndex(0u), "Invalid opcode");
+  EXPECT_DEATH(Vocabulary::getSlotIndex(MaxOpcodes + 1), "Invalid opcode");
 
   // Test invalid type IDs
-  EXPECT_DEATH(Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs)),
+  EXPECT_DEATH(Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs)),
                "Invalid type ID");
   EXPECT_DEATH(
-      Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
+      Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
       "Invalid type ID");
 }
 #endif // NDEBUG
@@ -529,18 +549,46 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) {
   EXPECT_EQ(Vocabulary::getStringKey(0), "Ret");
   EXPECT_EQ(Vocabulary::getStringKey(12), "Add");
 
-  StringRef HalfTypeKey = Vocabulary::getStringKey(MaxOpcodes + 0);
-  StringRef FloatTypeKey = Vocabulary::getStringKey(MaxOpcodes + 2);
-  StringRef VoidTypeKey = Vocabulary::getStringKey(MaxOpcodes + 7);
-  StringRef IntTypeKey = Vocabulary::getStringKey(MaxOpcodes + 12);
-
-  EXPECT_EQ(HalfTypeKey, "FloatTy");
-  EXPECT_EQ(FloatTypeKey, "FloatTy");
-  EXPECT_EQ(VoidTypeKey, "VoidTy");
-  EXPECT_EQ(IntTypeKey, "IntegerTy");
-
-  StringRef FuncArgKey = Vocabulary::getStringKey(MaxOpcodes + MaxTypeIDs + 0);
-  StringRef PtrArgKey = Vocabulary::getStringKey(MaxOpcodes + MaxTypeIDs + 1);
+#define EXPECT_OPCODE(NUM, OPCODE, CLASS)                                      \
+  EXPECT_EQ(Vocabulary::getStringKey(Vocabulary::getSlotIndex(NUM)),           \
+            Vocabulary::getVocabKeyForOpcode(NUM));
+#define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE(NUM, OPCODE, CLASS)
+#include "llvm/IR/Instruction.def"
+#undef HANDLE_INST
+#undef EXPECT_OPCODE
+
+  // Verify CanonicalTypeID -> string mapping
+#define EXPECT_CANONICAL_TYPE_NAME(TypeIDTok, CanonEnum, CanonStr)             \
+  EXPECT_EQ(Vocabulary::getStringKey(                                          \
+                MaxOpcodes + static_cast<unsigned>(                            \
+                                 Vocabulary::CanonicalTypeID::CanonEnum)),     \
+            CanonStr);
+
+  IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_CANONICAL_TYPE_NAME)
+
+#undef EXPECT_CANONICAL_TYPE_NAME
+
+#define HANDLE_OPERAND_KINDS(X)                                                \
+  X(FunctionID, "Function")                                                    \
+  X(PointerID, "Pointer")                                                      \
+  X(ConstantID, "Constant")                                                    \
+  X(VariableID, "Variable")
+
+#define EXPECT_OPERAND_KIND(EnumName, Str)                                     \
+  EXPECT_EQ(Vocabulary::getStringKey(                                          \
+                MaxOpcodes + MaxCanonicalTypeIDs +                             \
+                static_cast<unsigned>(Vocabulary::OperandKind::EnumName)),     \
+            Str);
+
+  HANDLE_OPERAND_KINDS(EXPECT_OPERAND_KIND)
+
+#undef EXPECT_OPERAND_KIND
+#undef HANDLE_OPERAND_KINDS
+
+  StringRef FuncArgKey =
+      Vocabulary::getStringKey(MaxOpcodes + MaxCanonicalTypeIDs + 0);
+  StringRef PtrArgKey =
+      Vocabulary::getStringKey(MaxOpcodes + MaxCanonicalTypeIDs + 1);
   EXPECT_EQ(FuncArgKey, "Function");
   EXPECT_EQ(PtrArgKey, "Pointer");
 }
@@ -578,39 +626,21 @@ TEST(IR2VecVocabularyTest, InvalidAccess) {
 #endif // GTEST_HAS_DEATH_TEST
 
 TEST(IR2VecVocabularyTest, TypeIDStringKeyMapping) {
-  EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
-                                     static_cast<unsigned>(Type::VoidTyID)),
-            "VoidTy");
-  EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
-                                     static_cast<unsigned>(Type::IntegerTyID)),
-            "IntegerTy");
-  EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
-                                     static_cast<unsigned>(Type::FloatTyID)),
-            "FloatTy");
-  EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
-                                     static_cast<unsigned>(Type::PointerTyID)),
-            "PointerTy");
-  EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
-                                     static_cast<unsigned>(Type::FunctionTyID)),
-            "FunctionTy");
-  EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
-                                     static_cast<unsigned>(Type::StructTyID)),
-            "StructTy");
-  EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
-                                     static_cast<unsigned>(Type::ArrayTyID)),
-            "ArrayTy");
-  EXPECT_EQ(Vocabulary::getStringKey(
-                MaxOpcodes + static_cast<unsigned>(Type::FixedVectorTyID)),
-            "VectorTy");
-  EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
-                                     static_cast<unsigned>(Type::LabelTyID)),
-            "LabelTy");
-  EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
-                                     static_cast<unsigned>(Type::TokenTyID)),
-            "TokenTy");
-  EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes +
-                                     static_cast<unsigned>(Type::MetadataTyID)),
-            "MetadataTy");
+#define EXPECT_TYPE_TO_CANONICAL(TypeIDTok, CanonEnum, CanonStr)               \
+  EXPECT_EQ(                                                                   \
+      Vocabulary::getStringKey(Vocabulary::getSlotIndex(Type::TypeIDTok)),     \
+      CanonStr);
+
+  IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_TYPE_TO_CANONICAL)
+
+#undef EXPECT_TYPE_TO_CANONICAL
+
+  // TypeIDs absent from IR2VEC_HANDLE_TYPE_BIMAP must still land on their
+  // canonical slot: half and bfloat collapse onto FloatTy just like float.
+  EXPECT_EQ(Vocabulary::getSlotIndex(Type::HalfTyID),
+            Vocabulary::getSlotIndex(Type::FloatTyID));
+  EXPECT_EQ(Vocabulary::getSlotIndex(Type::BFloatTyID),
+            Vocabulary::getSlotIndex(Type::FloatTyID));
 }
 
 TEST(IR2VecVocabularyTest, InvalidVocabularyConstruction) {
    
    
More information about the llvm-branch-commits
mailing list