[llvm] [IR2Vec] Restructuring Vocabulary (PR #145119)

Aaron Ballman via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 15 10:52:49 PDT 2025


================
@@ -251,33 +208,144 @@ void SymbolicEmbedder::computeEmbeddings() const {
 }
 
 // ==----------------------------------------------------------------------===//
-// IR2VecVocabResult and IR2VecVocabAnalysis
+// Vocabulary
 //===----------------------------------------------------------------------===//
 
-IR2VecVocabResult::IR2VecVocabResult(ir2vec::Vocab &&Vocabulary)
-    : Vocabulary(std::move(Vocabulary)), Valid(true) {}
+Vocabulary::Vocabulary(VocabVector &&Vocab)
+    : Vocab(std::move(Vocab)), Valid(true) {}
 
-const ir2vec::Vocab &IR2VecVocabResult::getVocabulary() const {
+bool Vocabulary::isValid() const {
+  return Vocab.size() == (MaxOpcodes + MaxTypeIDs + MaxOperandKinds) && Valid;
+}
+
+size_t Vocabulary::size() const {
   assert(Valid && "IR2Vec Vocabulary is invalid");
-  return Vocabulary;
+  return Vocab.size();
 }
 
-unsigned IR2VecVocabResult::getDimension() const {
+unsigned Vocabulary::getDimension() const {
   assert(Valid && "IR2Vec Vocabulary is invalid");
-  return Vocabulary.begin()->second.size();
+  return Vocab[0].size();
+}
+
+const Embedding &Vocabulary::operator[](unsigned Opcode) const {
+  assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
+  return Vocab[Opcode - 1];
+}
+
+const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const {
+  assert(static_cast<unsigned>(TypeId) < MaxTypeIDs && "Invalid type ID");
+  return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)];
+}
+
+const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
+  OperandKind ArgKind = getOperandKind(Arg);
+  return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
+}
+
+StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
+  switch (TypeID) {
+  case Type::VoidTyID:
+    return "VoidTy";
+  case Type::HalfTyID:
+  case Type::BFloatTyID:
+  case Type::FloatTyID:
+  case Type::DoubleTyID:
+  case Type::X86_FP80TyID:
+  case Type::FP128TyID:
+  case Type::PPC_FP128TyID:
+    return "FloatTy";
+  case Type::IntegerTyID:
+    return "IntegerTy";
+  case Type::FunctionTyID:
+    return "FunctionTy";
+  case Type::StructTyID:
+    return "StructTy";
+  case Type::ArrayTyID:
+    return "ArrayTy";
+  case Type::PointerTyID:
+  case Type::TypedPointerTyID:
+    return "PointerTy";
+  case Type::FixedVectorTyID:
+  case Type::ScalableVectorTyID:
+    return "VectorTy";
+  case Type::LabelTyID:
+    return "LabelTy";
+  case Type::TokenTyID:
+    return "TokenTy";
+  case Type::MetadataTyID:
+    return "MetadataTy";
+  case Type::X86_AMXTyID:
+  case Type::TargetExtTyID:
+  default:
+    return "UnknownTy";
+  }
+}
+
+StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
+  unsigned Index = static_cast<unsigned>(Kind);
+  assert(Index < MaxOperandKinds && "Invalid OperandKind");
+  return OperandKindNames[Index];
+}
+
+Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
+  VocabVector DummyVocab;
+  float DummyVal = 0.1f;
+  // Create a dummy vocabulary with entries for all opcodes, types, and
+  // operand kinds.
+  for (unsigned _ : seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxTypeIDs +
+                                Vocabulary::MaxOperandKinds)) {
+    DummyVocab.push_back(Embedding(Dim, DummyVal));
+    DummyVal += 0.1;
----------------
AaronBallman wrote:

Thank you, I appreciate it!

https://github.com/llvm/llvm-project/pull/145119


More information about the llvm-commits mailing list