[llvm] [IR2Vec] Restructuring Vocabulary (PR #145119)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 1 13:54:33 PDT 2025


================
@@ -128,9 +129,73 @@ struct Embedding {
 
 using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
 using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
-// FIXME: Current the keys are strings. This can be changed to
-// use integers for cheaper lookups.
-using Vocab = std::map<std::string, Embedding>;
+
+/// Class for storing and accessing the IR2Vec vocabulary.
+/// Encapsulates all vocabulary-related constants, logic, and access methods.
+class Vocabulary {
+  friend class llvm::IR2VecVocabAnalysis;
+  using VocabVector = std::vector<ir2vec::Embedding>;
+  VocabVector Vocab;
+  bool Valid = false;
+
+/// Operand kinds supported by IR2Vec Vocabulary
+#define OPERAND_KINDS                                                          \
+  OPERAND_KIND(FunctionID, "Function")                                         \
+  OPERAND_KIND(PointerID, "Pointer")                                           \
+  OPERAND_KIND(ConstantID, "Constant")                                         \
+  OPERAND_KIND(VariableID, "Variable")
+
+  enum class OperandKind : unsigned {
+#define OPERAND_KIND(Name, Str) Name,
+    OPERAND_KINDS
+#undef OPERAND_KIND
+        MaxOperandKind
+  };
+
+#undef OPERAND_KINDS
+
+  /// Vocabulary layout constants
+#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
+#include "llvm/IR/Instruction.def"
+#undef LAST_OTHER_INST
+
+  static constexpr unsigned MaxTypes = Type::TypeID::TargetExtTyID + 1;
+  static constexpr unsigned MaxOperandKinds =
+      static_cast<unsigned>(OperandKind::MaxOperandKind);
+
+  /// Helper function to get vocabulary key for a given OperandKind
+  static StringRef getVocabKeyForOperandKind(OperandKind Kind);
+
+  /// Helper function to classify an operand into OperandKind
+  static OperandKind getOperandKind(const Value *Op);
+
+  /// Helper function to get vocabulary key for a given TypeID
+  static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
+
+public:
+  Vocabulary() = default;
+  Vocabulary(VocabVector &&Vocab);
+
+  bool isValid() const;
+  unsigned getDimension() const;
+  unsigned size() const;
+
+  const ir2vec::Embedding &at(unsigned Position) const;
----------------
svkeerthy wrote:

Currently it's used in the printer pass, where we iterate by position index. The overloaded [] accessors cannot be used for positional indexing in this case.

https://github.com/llvm/llvm-project/pull/145119


More information about the llvm-commits mailing list