[llvm] [IR2Vec] Restructuring Vocabulary (PR #145119)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 14 10:44:11 PDT 2025


================
@@ -251,33 +208,160 @@ void SymbolicEmbedder::computeEmbeddings() const {
 }
 
 // ==----------------------------------------------------------------------===//
-// IR2VecVocabResult and IR2VecVocabAnalysis
+// Vocabulary
 //===----------------------------------------------------------------------===//
 
-IR2VecVocabResult::IR2VecVocabResult(ir2vec::Vocab &&Vocabulary)
-    : Vocabulary(std::move(Vocabulary)), Valid(true) {}
+Vocabulary::Vocabulary(VocabVector &&Vocab)
+    : Vocab(std::move(Vocab)), Valid(true) {}
 
-const ir2vec::Vocab &IR2VecVocabResult::getVocabulary() const {
+bool Vocabulary::isValid() const {
+  return Vocab.size() == (MaxOpcodes + MaxTypeIDs + MaxOperandKinds) && Valid;
+}
+
+size_t Vocabulary::size() const {
   assert(Valid && "IR2Vec Vocabulary is invalid");
-  return Vocabulary;
+  return Vocab.size();
 }
 
-unsigned IR2VecVocabResult::getDimension() const {
+unsigned Vocabulary::getDimension() const {
   assert(Valid && "IR2Vec Vocabulary is invalid");
-  return Vocabulary.begin()->second.size();
+  return Vocab[0].size();
+}
+
+const Embedding &Vocabulary::operator[](unsigned Opcode) const {
+  assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
+  return Vocab[Opcode - 1];
+}
+
+const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const {
+  assert(static_cast<unsigned>(TypeId) < MaxTypeIDs && "Invalid type ID");
+  return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)];
+}
+
+const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
+  OperandKind ArgKind = getOperandKind(Arg);
+  return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
+}
+
+StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
+  switch (TypeID) {
+  case Type::VoidTyID:
+    return "VoidTy";
+  case Type::HalfTyID:
+  case Type::BFloatTyID:
+  case Type::FloatTyID:
+  case Type::DoubleTyID:
+  case Type::X86_FP80TyID:
+  case Type::FP128TyID:
+  case Type::PPC_FP128TyID:
+    return "FloatTy";
+  case Type::IntegerTyID:
+    return "IntegerTy";
+  case Type::FunctionTyID:
+    return "FunctionTy";
+  case Type::StructTyID:
+    return "StructTy";
+  case Type::ArrayTyID:
+    return "ArrayTy";
+  case Type::PointerTyID:
+  case Type::TypedPointerTyID:
+    return "PointerTy";
+  case Type::FixedVectorTyID:
+  case Type::ScalableVectorTyID:
+    return "VectorTy";
+  case Type::LabelTyID:
+    return "LabelTy";
+  case Type::TokenTyID:
+    return "TokenTy";
+  case Type::MetadataTyID:
+    return "MetadataTy";
+  case Type::X86_AMXTyID:
+  case Type::TargetExtTyID:
+  default:
+    return "UnknownTy";
+  }
+}
+
+// Operand kinds supported by IR2Vec - string mappings
+#define OPERAND_KINDS                                                          \
----------------
svkeerthy wrote:

You are right about the duplication. It looks bad! 

However, I'd prefer not to leave a macro globally defined in the header without `#undef`-ing it, to avoid polluting the files that include `IR2Vec.h`. So I refactored a bit to avoid the duplication by removing the X-macros; they do not seem to be necessary here.

https://github.com/llvm/llvm-project/pull/145119


More information about the llvm-commits mailing list