[llvm] [llvm-ir2vec] Decoupling Vocab loading from initEmbedding (PR #190507)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Apr 5 00:39:27 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlgo
Author: Nishant Sachdeva (nishant-sachdeva)
<details>
<summary>Changes</summary>
This change saves time when processing an entire dataset: the vocabulary should be loaded only once and then reused, rather than being reloaded on every `initEmbedding` call.
@<!-- -->svkeerthy
---
Full diff: https://github.com/llvm/llvm-project/pull/190507.diff
10 Files Affected:
- (modified) llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getBBEmbMap.py (+4-1)
- (modified) llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmb.py (+4-1)
- (modified) llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmbMap.py (+4-1)
- (modified) llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncNames.py (+4-1)
- (modified) llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getInstEmbMap.py (+4-1)
- (modified) llvm/test/tools/llvm-ir2vec/bindings/ir2vec-initEmbedding.py (+105-31)
- (modified) llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp (+50-25)
- (modified) llvm/tools/llvm-ir2vec/lib/Utils.cpp (+8-4)
- (modified) llvm/tools/llvm-ir2vec/lib/Utils.h (+9-3)
- (modified) llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp (+4-2)
``````````diff
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getBBEmbMap.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getBBEmbMap.py
index 333feadc6c932..963e0d0adeca5 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getBBEmbMap.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getBBEmbMap.py
@@ -6,7 +6,10 @@
ll_file = sys.argv[1]
vocab_path = sys.argv[2]
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+vocab = ir2vec.loadVocab(vocab_path)
+tool = ir2vec.initEmbedding(
+ filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab
+)
# Success case
bb_map = tool.getBBEmbMap("conditional")
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmb.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmb.py
index 61b9464c89757..b7ac30f689da6 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmb.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmb.py
@@ -6,7 +6,10 @@
ll_file = sys.argv[1]
vocab_path = sys.argv[2]
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+vocab = ir2vec.loadVocab(vocab_path)
+tool = ir2vec.initEmbedding(
+ filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab
+)
# Success case
emb = tool.getFuncEmb("add")
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmbMap.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmbMap.py
index 7600d5e4a2986..6bc3adbca80ac 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmbMap.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmbMap.py
@@ -6,7 +6,10 @@
ll_file = sys.argv[1]
vocab_path = sys.argv[2]
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+vocab = ir2vec.loadVocab(vocab_path)
+tool = ir2vec.initEmbedding(
+ filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab
+)
# Success case
emb_map = tool.getFuncEmbMap()
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncNames.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncNames.py
index 432d80e97edb9..f4420bae6caac 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncNames.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncNames.py
@@ -6,7 +6,10 @@
ll_file = sys.argv[1]
vocab_path = sys.argv[2]
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+vocab = ir2vec.loadVocab(vocab_path)
+tool = ir2vec.initEmbedding(
+ filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab
+)
# Success case
func_names = tool.getFuncNames()
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getInstEmbMap.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getInstEmbMap.py
index 3157ae34cfd3c..a76d23ed84146 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getInstEmbMap.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getInstEmbMap.py
@@ -6,7 +6,10 @@
ll_file = sys.argv[1]
vocab_path = sys.argv[2]
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+vocab = ir2vec.loadVocab(vocab_path)
+tool = ir2vec.initEmbedding(
+ filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab
+)
# Success case
inst_map = tool.getInstEmbMap("add")
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-initEmbedding.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-initEmbedding.py
index c19935a0c6b7d..f7e37e7e1bdd5 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-initEmbedding.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-initEmbedding.py
@@ -2,61 +2,135 @@
import sys
import ir2vec
+import tempfile
+import os
ll_file = sys.argv[1]
vocab_path = sys.argv[2]
-# Success case
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
-print(f"SUCCESS: {type(tool).__name__}")
-# CHECK: SUCCESS: IR2VecTool
+# ============================================================
+# loadVocab tests
+# ============================================================
-# Error: Invalid mode
-try:
- ir2vec.initEmbedding(filename=ll_file, mode="invalid", vocabPath=vocab_path)
-except ValueError:
- print("ERROR: Invalid mode")
-# CHECK: ERROR: Invalid mode
+# Success: Load a valid vocabulary
+vocab = ir2vec.loadVocab(vocab_path)
+print(f"VOCAB: {type(vocab).__name__}")
+# CHECK: VOCAB: Vocab
# Error: Empty vocab path
try:
- ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath="")
+ ir2vec.loadVocab("")
except ValueError:
print("ERROR: Empty vocab path")
# CHECK: ERROR: Empty vocab path
-# Error: Invalid file
+# Error: Non-existent vocab file
+try:
+ ir2vec.loadVocab("/nonexistent/path/bad.json")
+except ValueError:
+ print("ERROR: Invalid vocab path")
+# CHECK: ERROR: Invalid vocab path
+
+# Error: Malformed JSON vocab file
+with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
+ f.write("{ this is not valid json }")
+ bad_vocab = f.name
+try:
+ ir2vec.loadVocab(bad_vocab)
+except ValueError:
+ print("ERROR: Malformed vocab file")
+finally:
+ os.unlink(bad_vocab)
+# CHECK: ERROR: Malformed vocab file
+
+# Error: Wrong type for vocab path (not a string)
+try:
+ ir2vec.loadVocab(42)
+except TypeError:
+ print("ERROR: Invalid vocab path type")
+# CHECK: ERROR: Invalid vocab path type
+
+# ============================================================
+# initEmbedding tests
+# ============================================================
+
+# Success: Create embedding tool with valid inputs
+tool = ir2vec.initEmbedding(
+ filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab
+)
+print(f"SUCCESS: {type(tool).__name__}")
+# CHECK: SUCCESS: IR2VecTool
+
+# Success: Default mode (Symbolic) when mode is omitted
+tool_default = ir2vec.initEmbedding(filename=ll_file, vocab=vocab)
+print(f"DEFAULT MODE: {type(tool_default).__name__}")
+# CHECK: DEFAULT MODE: IR2VecTool
+
+# Error: Invalid mode (string instead of IR2VecKind enum)
try:
- ir2vec.initEmbedding(filename="/bad.ll", mode="sym", vocabPath=vocab_path)
+ ir2vec.initEmbedding(filename=ll_file, mode="invalid", vocab=vocab)
+except TypeError:
+ print("ERROR: Invalid mode type")
+# CHECK: ERROR: Invalid mode type
+
+# Error: Invalid mode (integer instead of IR2VecKind enum)
+try:
+ ir2vec.initEmbedding(filename=ll_file, mode=99, vocab=vocab)
+except TypeError:
+ print("ERROR: Invalid mode int")
+# CHECK: ERROR: Invalid mode int
+
+# Error: Non-existent IR file
+try:
+ ir2vec.initEmbedding(
+ filename="/nonexistent/bad.ll",
+ mode=ir2vec.IR2VecKind.Symbolic,
+ vocab=vocab,
+ )
except ValueError:
print("ERROR: Invalid file")
# CHECK: ERROR: Invalid file
# Error: Empty filename
try:
- ir2vec.initEmbedding(filename="", mode="sym", vocabPath=vocab_path)
+ ir2vec.initEmbedding(filename="", mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab)
except ValueError:
print("ERROR: Empty filename")
# CHECK: ERROR: Empty filename
-# Error: Invalid vocab file
+# Error: Wrong type for vocab (string instead of Vocab object)
try:
- ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath="/bad.json")
-except ValueError:
- print("ERROR: Invalid vocab")
-# CHECK: ERROR: Invalid vocab
+ ir2vec.initEmbedding(
+ filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab="vocab.json"
+ )
+except TypeError:
+ print("ERROR: Vocab is string")
+# CHECK: ERROR: Vocab is string
-# Error: Malformed JSON vocab
-import tempfile
-import os
+# Error: Wrong type for vocab (None)
+try:
+ ir2vec.initEmbedding(filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=None)
+except TypeError:
+ print("ERROR: Vocab is None")
+# CHECK: ERROR: Vocab is None
-with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
- f.write("{ this is not valid json }")
- bad_vocab = f.name
+# Error: Wrong type for vocab (integer)
try:
- ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=bad_vocab)
-except ValueError:
- print("ERROR: Invalid vocab file")
-finally:
- os.unlink(bad_vocab)
-# CHECK: ERROR: Invalid vocab file
+ ir2vec.initEmbedding(filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocab=123)
+except TypeError:
+ print("ERROR: Vocab is int")
+# CHECK: ERROR: Vocab is int
+
+# Error: Missing vocab argument entirely
+try:
+ ir2vec.initEmbedding(filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic)
+except TypeError:
+ print("ERROR: Vocab missing")
+# CHECK: ERROR: Vocab missing
+
+# Error: Wrong type for filename (not a string)
+try:
+ ir2vec.initEmbedding(filename=42, mode=ir2vec.IR2VecKind.Symbolic, vocab=vocab)
+except TypeError:
+ print("ERROR: Filename is int")
+# CHECK: ERROR: Filename is int
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 2f885b11519c7..43b87376e6977 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -38,6 +38,25 @@ std::unique_ptr<Module> getLLVMIR(const std::string &Filename,
return M;
}
+class PyVocab {
+private:
+ std::shared_ptr<Vocabulary> Vocab;
+
+public:
+ PyVocab(const std::string &VocabPath) {
+ if (VocabPath.empty())
+ throw nb::value_error("Empty vocabulary path not allowed");
+ auto VocabOrErr = ir2vec::loadVocabulary(VocabPath);
+ if (!VocabOrErr)
+ throw nb::value_error(
+ ("Failed to load vocabulary: " + toString(VocabOrErr.takeError()))
+ .c_str());
+ Vocab = std::move(*VocabOrErr);
+ }
+
+ std::shared_ptr<Vocabulary> getVocab() const { return Vocab; }
+};
+
class PyIR2VecTool {
private:
std::unique_ptr<LLVMContext> Ctx;
@@ -46,28 +65,19 @@ class PyIR2VecTool {
IR2VecKind OutputEmbeddingMode;
public:
- PyIR2VecTool(const std::string &Filename, const std::string &Mode,
- const std::string &VocabPath) {
- OutputEmbeddingMode = [](const std::string &Mode) -> IR2VecKind {
- if (Mode == "sym")
- return IR2VecKind::Symbolic;
- if (Mode == "fa")
- return IR2VecKind::FlowAware;
- throw nb::value_error("Invalid mode. Use 'sym' or 'fa'");
- }(Mode);
+ PyIR2VecTool(const std::string &Filename, IR2VecKind Mode, PyVocab &Vocab) {
+ if (!Vocab.getVocab())
+ throw nb::value_error("Vocabulary object is not initialized");
- if (VocabPath.empty())
- throw nb::value_error("Empty Vocab Path not allowed");
+ if (Filename.empty())
+ throw nb::value_error("Empty filename not allowed");
+
+ OutputEmbeddingMode = Mode;
Ctx = std::make_unique<LLVMContext>();
M = getLLVMIR(Filename, *Ctx);
Tool = std::make_unique<IR2VecTool>(*M);
-
- if (auto Err = Tool->initializeVocabulary(VocabPath)) {
- throw nb::value_error(("Failed to initialize IR2Vec vocabulary: " +
- toString(std::move(Err)))
- .c_str());
- }
+ Tool->setVocabulary(Vocab.getVocab());
}
nb::list getFuncNames() {
@@ -200,10 +210,26 @@ class PyIR2VecTool {
NB_MODULE(ir2vec, m) {
m.doc() = std::string("Python bindings for ") + ToolName;
+ nb::enum_<IR2VecKind>(m, "IR2VecKind",
+ "Embedding mode for IR2Vec representations")
+ .value("Symbolic", IR2VecKind::Symbolic, "Symbolic encodings only")
+ .value("FlowAware", IR2VecKind::FlowAware,
+ "Flow-aware encodings (includes data/control flow)")
+ .export_values();
+
+ nb::class_<PyVocab>(m, "Vocab");
+
+ m.def(
+ "loadVocab",
+ [](const std::string &vocabPath) {
+ return std::make_unique<PyVocab>(vocabPath);
+ },
+ nb::arg("vocabPath"), "Load an IR2Vec vocabulary from a JSON file",
+ nb::rv_policy::take_ownership);
+
nb::class_<PyIR2VecTool>(m, "IR2VecTool")
- .def(nb::init<const std::string &, const std::string &,
- const std::string &>(),
- nb::arg("filename"), nb::arg("mode"), nb::arg("vocabPath"))
+ .def(nb::init<const std::string &, IR2VecKind, PyVocab &>(),
+ nb::arg("filename"), nb::arg("mode"), nb::arg("vocab"))
.def("getFuncNames", &PyIR2VecTool::getFuncNames,
"Get list of all defined functions in the module\n"
"Returns: list[str] - Function names")
@@ -228,10 +254,9 @@ NB_MODULE(ir2vec, m) {
m.def(
"initEmbedding",
- [](const std::string &filename, const std::string &mode,
- const std::string &vocabPath) {
- return std::make_unique<PyIR2VecTool>(filename, mode, vocabPath);
+ [](const std::string &filename, IR2VecKind mode, PyVocab &vocab) {
+ return std::make_unique<PyIR2VecTool>(filename, mode, vocab);
},
- nb::arg("filename"), nb::arg("mode") = "sym", nb::arg("vocabPath"),
- nb::rv_policy::take_ownership);
+ nb::arg("filename"), nb::arg("mode") = IR2VecKind::Symbolic,
+ nb::arg("vocab"), nb::rv_policy::take_ownership);
}
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 1334aa1783583..0d78d93ef71b1 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -43,17 +43,21 @@ namespace llvm {
namespace ir2vec {
-Error IR2VecTool::initializeVocabulary(StringRef VocabPath) {
+Expected<std::shared_ptr<Vocabulary>> loadVocabulary(StringRef VocabPath) {
auto VocabOrErr = Vocabulary::fromFile(VocabPath);
if (!VocabOrErr)
return VocabOrErr.takeError();
- Vocab = std::make_unique<Vocabulary>(std::move(*VocabOrErr));
+ auto V = std::make_shared<Vocabulary>(std::move(*VocabOrErr));
- if (!Vocab->isValid())
+ if (!V->isValid())
return createStringError(errc::invalid_argument,
"Failed to initialize IR2Vec vocabulary");
- return Error::success();
+ return V;
+}
+
+void IR2VecTool::setVocabulary(std::shared_ptr<Vocabulary> V) {
+ Vocab = std::move(V);
}
TripletResult IR2VecTool::generateTriplets(const Function &F) const {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index ae2f931a90cf9..be0f319297069 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -84,12 +84,15 @@ enum RelationType {
ArgRelation = 2 ///< Instruction to operand relationship (ArgRelation + N)
};
+/// Load an IR2Vec vocabulary from a JSON file on disk.
+Expected<std::shared_ptr<Vocabulary>> loadVocabulary(StringRef VocabPath);
+
/// Helper class for collecting IR triplets and generating embeddings
class IR2VecTool {
private:
Module &M;
ModuleAnalysisManager MAM;
- std::unique_ptr<Vocabulary> Vocab;
+ std::shared_ptr<Vocabulary> Vocab;
public:
explicit IR2VecTool(Module &M) : M(M) {}
@@ -98,8 +101,11 @@ class IR2VecTool {
Expected<std::unique_ptr<Embedder>>
createIR2VecEmbedder(const Function &F, IR2VecKind Kind) const;
- /// Initialize the IR2Vec vocabulary from the specified file path.
- Error initializeVocabulary(StringRef VocabPath);
+ /// Load vocabulary from a shared pointer. This allows sharing the same
+ /// vocabulary instance across multiple IR2VecTool instances, which is useful
+ /// for generating embeddings for multiple functions without needing to reload
+ /// the vocabulary each time.
+ void setVocabulary(std::shared_ptr<Vocabulary> V);
/// Generate triplets for a single function
/// Returns a TripletResult with:
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index cf290ba931023..78a2e3f657705 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -161,8 +161,10 @@ static Error processModule(Module &M, raw_ostream &OS) {
"You may need to set it using --ir2vec-vocab-path");
}
- if (Error Err = Tool.initializeVocabulary(VocabFile))
- return Err;
+ auto VocabOrErr = ir2vec::loadVocabulary(VocabFile);
+ if (!VocabOrErr)
+ return VocabOrErr.takeError();
+ Tool.setVocabulary(std::move(*VocabOrErr));
if (!FunctionName.empty()) {
// Process single function
``````````
</details>
https://github.com/llvm/llvm-project/pull/190507
More information about the llvm-commits mailing list