[llvm] [llvm-ir2vec] Decoupling Vocab loading from initEmbedding (PR #190507)

via llvm-commits llvm-commits at lists.llvm.org
Sun Apr 5 00:39:27 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlgo

Author: Nishant Sachdeva  (nishant-sachdeva)

<details>
<summary>Changes</summary>

 This saves time when processing an entire dataset: the vocabulary only needs to be loaded once and can then be reused for every module.

@<!-- -->svkeerthy 

---
Full diff: https://github.com/llvm/llvm-project/pull/190507.diff


10 Files Affected:

- (modified) llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getBBEmbMap.py (+4-1) 
- (modified) llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmb.py (+4-1) 
- (modified) llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmbMap.py (+4-1) 
- (modified) llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncNames.py (+4-1) 
- (modified) llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getInstEmbMap.py (+4-1) 
- (modified) llvm/test/tools/llvm-ir2vec/bindings/ir2vec-initEmbedding.py (+105-31) 
- (modified) llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp (+50-25) 
- (modified) llvm/tools/llvm-ir2vec/lib/Utils.cpp (+8-4) 
- (modified) llvm/tools/llvm-ir2vec/lib/Utils.h (+9-3) 
- (modified) llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp (+4-2) 


``````````diff
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getBBEmbMap.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getBBEmbMap.py
index 333feadc6c932..963e0d0adeca5 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getBBEmbMap.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getBBEmbMap.py
@@ -6,7 +6,10 @@
 ll_file = sys.argv[1]
 vocab_path = sys.argv[2]
 
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+vocab = ir2vec.loadVocab(vocab_path)
+tool = ir2vec.initEmbedding(
+    filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab
+)
 
 # Success case
 bb_map = tool.getBBEmbMap("conditional")
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmb.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmb.py
index 61b9464c89757..b7ac30f689da6 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmb.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmb.py
@@ -6,7 +6,10 @@
 ll_file = sys.argv[1]
 vocab_path = sys.argv[2]
 
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+vocab = ir2vec.loadVocab(vocab_path)
+tool = ir2vec.initEmbedding(
+    filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab
+)
 
 # Success case
 emb = tool.getFuncEmb("add")
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmbMap.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmbMap.py
index 7600d5e4a2986..6bc3adbca80ac 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmbMap.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmbMap.py
@@ -6,7 +6,10 @@
 ll_file = sys.argv[1]
 vocab_path = sys.argv[2]
 
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+vocab = ir2vec.loadVocab(vocab_path)
+tool = ir2vec.initEmbedding(
+    filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab
+)
 
 # Success case
 emb_map = tool.getFuncEmbMap()
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncNames.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncNames.py
index 432d80e97edb9..f4420bae6caac 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncNames.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncNames.py
@@ -6,7 +6,10 @@
 ll_file = sys.argv[1]
 vocab_path = sys.argv[2]
 
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+vocab = ir2vec.loadVocab(vocab_path)
+tool = ir2vec.initEmbedding(
+    filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab
+)
 
 # Success case
 func_names = tool.getFuncNames()
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getInstEmbMap.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getInstEmbMap.py
index 3157ae34cfd3c..a76d23ed84146 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getInstEmbMap.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getInstEmbMap.py
@@ -6,7 +6,10 @@
 ll_file = sys.argv[1]
 vocab_path = sys.argv[2]
 
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+vocab = ir2vec.loadVocab(vocab_path)
+tool = ir2vec.initEmbedding(
+    filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab
+)
 
 # Success case
 inst_map = tool.getInstEmbMap("add")
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-initEmbedding.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-initEmbedding.py
index c19935a0c6b7d..f7e37e7e1bdd5 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-initEmbedding.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-initEmbedding.py
@@ -2,61 +2,135 @@
 
 import sys
 import ir2vec
+import tempfile
+import os
 
 ll_file = sys.argv[1]
 vocab_path = sys.argv[2]
 
-# Success case
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
-print(f"SUCCESS: {type(tool).__name__}")
-# CHECK: SUCCESS: IR2VecTool
+# ============================================================
+# loadVocab tests
+# ============================================================
 
-# Error: Invalid mode
-try:
-    ir2vec.initEmbedding(filename=ll_file, mode="invalid", vocabPath=vocab_path)
-except ValueError:
-    print("ERROR: Invalid mode")
-# CHECK: ERROR: Invalid mode
+# Success: Load a valid vocabulary
+vocab = ir2vec.loadVocab(vocab_path)
+print(f"VOCAB: {type(vocab).__name__}")
+# CHECK: VOCAB: Vocab
 
 # Error: Empty vocab path
 try:
-    ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath="")
+    ir2vec.loadVocab("")
 except ValueError:
     print("ERROR: Empty vocab path")
 # CHECK: ERROR: Empty vocab path
 
-# Error: Invalid file
+# Error: Non-existent vocab file
+try:
+    ir2vec.loadVocab("/nonexistent/path/bad.json")
+except ValueError:
+    print("ERROR: Invalid vocab path")
+# CHECK: ERROR: Invalid vocab path
+
+# Error: Malformed JSON vocab file
+with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
+    f.write("{ this is not valid json }")
+    bad_vocab = f.name
+try:
+    ir2vec.loadVocab(bad_vocab)
+except ValueError:
+    print("ERROR: Malformed vocab file")
+finally:
+    os.unlink(bad_vocab)
+# CHECK: ERROR: Malformed vocab file
+
+# Error: Wrong type for vocab path (not a string)
+try:
+    ir2vec.loadVocab(42)
+except TypeError:
+    print("ERROR: Invalid vocab path type")
+# CHECK: ERROR: Invalid vocab path type
+
+# ============================================================
+# initEmbedding tests
+# ============================================================
+
+# Success: Create embedding tool with valid inputs
+tool = ir2vec.initEmbedding(
+    filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab
+)
+print(f"SUCCESS: {type(tool).__name__}")
+# CHECK: SUCCESS: IR2VecTool
+
+# Success: Default mode (Symbolic) when mode is omitted
+tool_default = ir2vec.initEmbedding(filename=ll_file, vocab=vocab)
+print(f"DEFAULT MODE: {type(tool_default).__name__}")
+# CHECK: DEFAULT MODE: IR2VecTool
+
+# Error: Invalid mode (string instead of IR2VecKind enum)
 try:
-    ir2vec.initEmbedding(filename="/bad.ll", mode="sym", vocabPath=vocab_path)
+    ir2vec.initEmbedding(filename=ll_file, mode="invalid", vocab=vocab)
+except TypeError:
+    print("ERROR: Invalid mode type")
+# CHECK: ERROR: Invalid mode type
+
+# Error: Invalid mode (integer instead of IR2VecKind enum)
+try:
+    ir2vec.initEmbedding(filename=ll_file, mode=99, vocab=vocab)
+except TypeError:
+    print("ERROR: Invalid mode int")
+# CHECK: ERROR: Invalid mode int
+
+# Error: Non-existent IR file
+try:
+    ir2vec.initEmbedding(
+        filename="/nonexistent/bad.ll",
+        mode=ir2vec.IR2VecKind.Symbolic,
+        vocab=vocab,
+    )
 except ValueError:
     print("ERROR: Invalid file")
 # CHECK: ERROR: Invalid file
 
 # Error: Empty filename
 try:
-    ir2vec.initEmbedding(filename="", mode="sym", vocabPath=vocab_path)
+    ir2vec.initEmbedding(filename="", mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab)
 except ValueError:
     print("ERROR: Empty filename")
 # CHECK: ERROR: Empty filename
 
-# Error: Invalid vocab file
+# Error: Wrong type for vocab (string instead of Vocab object)
 try:
-    ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath="/bad.json")
-except ValueError:
-    print("ERROR: Invalid vocab")
-# CHECK: ERROR: Invalid vocab
+    ir2vec.initEmbedding(
+        filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab="vocab.json"
+    )
+except TypeError:
+    print("ERROR: Vocab is string")
+# CHECK: ERROR: Vocab is string
 
-# Error: Malformed JSON vocab
-import tempfile
-import os
+# Error: Wrong type for vocab (None)
+try:
+    ir2vec.initEmbedding(filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=None)
+except TypeError:
+    print("ERROR: Vocab is None")
+# CHECK: ERROR: Vocab is None
 
-with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
-    f.write("{ this is not valid json }")
-    bad_vocab = f.name
+# Error: Wrong type for vocab (integer)
 try:
-    ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=bad_vocab)
-except ValueError:
-    print("ERROR: Invalid vocab file")
-finally:
-    os.unlink(bad_vocab)
-# CHECK: ERROR: Invalid vocab file
+    ir2vec.initEmbedding(filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=123)
+except TypeError:
+    print("ERROR: Vocab is int")
+# CHECK: ERROR: Vocab is int
+
+# Error: Missing vocab argument entirely
+try:
+    ir2vec.initEmbedding(filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic)
+except TypeError:
+    print("ERROR: Vocab missing")
+# CHECK: ERROR: Vocab missing
+
+# Error: Wrong type for filename (not a string)
+try:
+    ir2vec.initEmbedding(filename=42, mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab)
+except TypeError:
+    print("ERROR: Filename is int")
+# CHECK: ERROR: Filename is int
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 2f885b11519c7..43b87376e6977 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -38,6 +38,25 @@ std::unique_ptr<Module> getLLVMIR(const std::string &Filename,
   return M;
 }
 
+class PyVocab {
+private:
+  std::shared_ptr<Vocabulary> Vocab;
+
+public:
+  PyVocab(const std::string &VocabPath) {
+    if (VocabPath.empty())
+      throw nb::value_error("Empty vocabulary path not allowed");
+    auto VocabOrErr = ir2vec::loadVocabulary(VocabPath);
+    if (!VocabOrErr)
+      throw nb::value_error(
+          ("Failed to load vocabulary: " + toString(VocabOrErr.takeError()))
+              .c_str());
+    Vocab = std::move(*VocabOrErr);
+  }
+
+  std::shared_ptr<Vocabulary> getVocab() const { return Vocab; }
+};
+
 class PyIR2VecTool {
 private:
   std::unique_ptr<LLVMContext> Ctx;
@@ -46,28 +65,19 @@ class PyIR2VecTool {
   IR2VecKind OutputEmbeddingMode;
 
 public:
-  PyIR2VecTool(const std::string &Filename, const std::string &Mode,
-               const std::string &VocabPath) {
-    OutputEmbeddingMode = [](const std::string &Mode) -> IR2VecKind {
-      if (Mode == "sym")
-        return IR2VecKind::Symbolic;
-      if (Mode == "fa")
-        return IR2VecKind::FlowAware;
-      throw nb::value_error("Invalid mode. Use 'sym' or 'fa'");
-    }(Mode);
+  PyIR2VecTool(const std::string &Filename, IR2VecKind Mode, PyVocab &Vocab) {
+    if (!Vocab.getVocab())
+      throw nb::value_error("Vocabulary object is not initialized");
 
-    if (VocabPath.empty())
-      throw nb::value_error("Empty Vocab Path not allowed");
+    if (Filename.empty())
+      throw nb::value_error("Empty filename not allowed");
+
+    OutputEmbeddingMode = Mode;
 
     Ctx = std::make_unique<LLVMContext>();
     M = getLLVMIR(Filename, *Ctx);
     Tool = std::make_unique<IR2VecTool>(*M);
-
-    if (auto Err = Tool->initializeVocabulary(VocabPath)) {
-      throw nb::value_error(("Failed to initialize IR2Vec vocabulary: " +
-                             toString(std::move(Err)))
-                                .c_str());
-    }
+    Tool->setVocabulary(Vocab.getVocab());
   }
 
   nb::list getFuncNames() {
@@ -200,10 +210,26 @@ class PyIR2VecTool {
 NB_MODULE(ir2vec, m) {
   m.doc() = std::string("Python bindings for ") + ToolName;
 
+  nb::enum_<IR2VecKind>(m, "IR2VecKind",
+                        "Embedding mode for IR2Vec representations")
+      .value("Symbolic", IR2VecKind::Symbolic, "Symbolic encodings only")
+      .value("FlowAware", IR2VecKind::FlowAware,
+             "Flow-aware encodings (includes data/control flow)")
+      .export_values();
+
+  nb::class_<PyVocab>(m, "Vocab");
+
+  m.def(
+      "loadVocab",
+      [](const std::string &vocabPath) {
+        return std::make_unique<PyVocab>(vocabPath);
+      },
+      nb::arg("vocabPath"), "Load an IR2Vec vocabulary from a JSON file",
+      nb::rv_policy::take_ownership);
+
   nb::class_<PyIR2VecTool>(m, "IR2VecTool")
-      .def(nb::init<const std::string &, const std::string &,
-                    const std::string &>(),
-           nb::arg("filename"), nb::arg("mode"), nb::arg("vocabPath"))
+      .def(nb::init<const std::string &, IR2VecKind, PyVocab &>(),
+           nb::arg("filename"), nb::arg("mode"), nb::arg("vocab"))
       .def("getFuncNames", &PyIR2VecTool::getFuncNames,
            "Get list of all defined functions in the module\n"
            "Returns: list[str] - Function names")
@@ -228,10 +254,9 @@ NB_MODULE(ir2vec, m) {
 
   m.def(
       "initEmbedding",
-      [](const std::string &filename, const std::string &mode,
-         const std::string &vocabPath) {
-        return std::make_unique<PyIR2VecTool>(filename, mode, vocabPath);
+      [](const std::string &filename, IR2VecKind mode, PyVocab &vocab) {
+        return std::make_unique<PyIR2VecTool>(filename, mode, vocab);
       },
-      nb::arg("filename"), nb::arg("mode") = "sym", nb::arg("vocabPath"),
-      nb::rv_policy::take_ownership);
+      nb::arg("filename"), nb::arg("mode") = IR2VecKind::Symbolic,
+      nb::arg("vocab"), nb::rv_policy::take_ownership);
 }
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 1334aa1783583..0d78d93ef71b1 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -43,17 +43,21 @@ namespace llvm {
 
 namespace ir2vec {
 
-Error IR2VecTool::initializeVocabulary(StringRef VocabPath) {
+Expected<std::shared_ptr<Vocabulary>> loadVocabulary(StringRef VocabPath) {
   auto VocabOrErr = Vocabulary::fromFile(VocabPath);
   if (!VocabOrErr)
     return VocabOrErr.takeError();
 
-  Vocab = std::make_unique<Vocabulary>(std::move(*VocabOrErr));
+  auto V = std::make_shared<Vocabulary>(std::move(*VocabOrErr));
 
-  if (!Vocab->isValid())
+  if (!V->isValid())
     return createStringError(errc::invalid_argument,
                              "Failed to initialize IR2Vec vocabulary");
-  return Error::success();
+  return V;
+}
+
+void IR2VecTool::setVocabulary(std::shared_ptr<Vocabulary> V) {
+  Vocab = std::move(V);
 }
 
 TripletResult IR2VecTool::generateTriplets(const Function &F) const {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index ae2f931a90cf9..be0f319297069 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -84,12 +84,15 @@ enum RelationType {
   ArgRelation = 2   ///< Instruction to operand relationship (ArgRelation + N)
 };
 
+/// Load an IR2Vec vocabulary from a JSON file on disk.
+Expected<std::shared_ptr<Vocabulary>> loadVocabulary(StringRef VocabPath);
+
 /// Helper class for collecting IR triplets and generating embeddings
 class IR2VecTool {
 private:
   Module &M;
   ModuleAnalysisManager MAM;
-  std::unique_ptr<Vocabulary> Vocab;
+  std::shared_ptr<Vocabulary> Vocab;
 
 public:
   explicit IR2VecTool(Module &M) : M(M) {}
@@ -98,8 +101,11 @@ class IR2VecTool {
   Expected<std::unique_ptr<Embedder>>
   createIR2VecEmbedder(const Function &F, IR2VecKind Kind) const;
 
-  /// Initialize the IR2Vec vocabulary from the specified file path.
-  Error initializeVocabulary(StringRef VocabPath);
+  /// Set the vocabulary used for embedding generation. Taking a shared
+  /// pointer allows one loaded vocabulary to be shared across multiple
+  /// IR2VecTool instances (e.g. one per module), so the vocabulary file does
+  /// not need to be re-read for every module.
+  void setVocabulary(std::shared_ptr<Vocabulary> V);
 
   /// Generate triplets for a single function
   /// Returns a TripletResult with:
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index cf290ba931023..78a2e3f657705 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -161,8 +161,10 @@ static Error processModule(Module &M, raw_ostream &OS) {
           "You may need to set it using --ir2vec-vocab-path");
     }
 
-    if (Error Err = Tool.initializeVocabulary(VocabFile))
-      return Err;
+    auto VocabOrErr = ir2vec::loadVocabulary(VocabFile);
+    if (!VocabOrErr)
+      return VocabOrErr.takeError();
+    Tool.setVocabulary(std::move(*VocabOrErr));
 
     if (!FunctionName.empty()) {
       // Process single function

``````````

</details>


https://github.com/llvm/llvm-project/pull/190507


More information about the llvm-commits mailing list