[llvm] [llvm-ir2vec] Adding Inst Embeddings Map API to ir2vec python bindings (PR #180140)
Nishant Sachdeva via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 11 22:00:56 PST 2026
https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/180140
From 1ab726cfbb4059c2f42a40591d441ad2427c41e5 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Fri, 6 Feb 2026 13:45:57 +0530
Subject: [PATCH 1/5] Adding Inst Embeddings Map API to ir2vec python bindings - returns a nested map indexed by functions
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 39 ++++++++++++++
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 51 +++++++++++++++++++
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 45 ++++++++++++++++
llvm/tools/llvm-ir2vec/lib/Utils.h | 9 ++++
4 files changed, 144 insertions(+)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 502f8a2411aa8..78c322d6d9265 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -43,6 +43,22 @@
print(f" BB: {bb_name}")
print(f" Embedding: {emb.tolist()}")
+ # Test getInstEmbMap
+ print("\n=== Instruction Embeddings ===")
+ inst_emb_map = tool.getInstEmbMap()
+
+ # Sorting by function name, then instruction string for deterministic output
+ inst_sorted = []
+ for func_name in sorted(inst_emb_map.keys()):
+ func_inst_map = inst_emb_map[func_name]
+ for inst_str in sorted(func_inst_map.keys()):
+ emb = func_inst_map[inst_str]
+ inst_sorted.append((inst_str, emb))
+
+ for inst_str, emb in inst_sorted:
+ print(f"Inst: {inst_str}")
+ print(f" Embedding: {emb.tolist()}")
+
# CHECK: SUCCESS: Tool initialized
# CHECK: Tool type: IR2VecTool
# CHECK: === Function Embeddings ===
@@ -75,3 +91,26 @@
# CHECK: Function: multiply
# CHECK: BB: entry
# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
+# CHECK: === Instruction Embeddings ===
+# CHECK: Inst: %sum = add i32 %a, %b
+# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
+# CHECK: Inst: ret i32 %sum
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
+# CHECK: Inst: %cmp = icmp sgt i32 %n, 0
+# CHECK-NEXT: Embedding: [157.20000000298023, 158.20000000298023, 159.20000000298023]
+# CHECK: Inst: %neg_val = sub i32 %n, 10
+# CHECK-NEXT: Embedding: [43.0, 44.0, 45.0]
+# CHECK: Inst: %pos_val = add i32 %n, 10
+# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
+# CHECK: Inst: %result = phi i32 [ %pos_val, %positive ], [ %neg_val, %negative ]
+# CHECK-NEXT: Embedding: [163.0, 164.0, 165.0]
+# CHECK: Inst: br i1 %cmp, label %positive, label %negative
+# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
+# CHECK: Inst: br label %exit
+# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
+# CHECK: Inst: ret i32 %result
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
+# CHECK: Inst: %prod = mul i32 %x, %y
+# CHECK-NEXT: Embedding: [49.0, 50.0, 51.0]
+# CHECK: Inst: ret i32 %prod
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index e334997b31394..dc47c279313bd 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -150,6 +150,41 @@ class PyIR2VecTool {
return NbBBEmbMap;
}
+
+ nb::dict getInstEmbMap() {
+ auto ToolInstEmbMap = Tool->getFuncInstEmbMap(OutputEmbeddingMode);
+
+ if (!ToolInstEmbMap)
+ throw nb::value_error(toString(ToolInstEmbMap.takeError()).c_str());
+
+ nb::dict NbInstEmbMap;
+
+ for (const auto &[FuncPtr, InstMap] : *ToolInstEmbMap) {
+ nb::dict NbFuncInstMap;
+
+ for (const auto &[InstPtr, InstEmb] : InstMap) {
+ auto InstEmbVec = InstEmb.getData();
+ double *NbInstEmbVec = new double[InstEmbVec.size()];
+ std::copy(InstEmbVec.begin(), InstEmbVec.end(), NbInstEmbVec);
+
+ auto NbArray = nb::ndarray<nb::numpy, double>(
+ NbInstEmbVec, {InstEmbVec.size()},
+ nb::capsule(NbInstEmbVec, [](void *P) noexcept {
+ delete[] static_cast<double *>(P);
+ }));
+
+ std::string InstStr;
+ raw_string_ostream OS(InstStr);
+ InstPtr->print(OS);
+
+ NbFuncInstMap[nb::str(OS.str().c_str())] = NbArray;
+ }
+
+ NbInstEmbMap[nb::str(FuncPtr->getName().str().c_str())] = NbFuncInstMap;
+ }
+
+ return NbInstEmbMap;
+ }
};
} // namespace
@@ -169,11 +204,27 @@ NB_MODULE(ir2vec, m) {
"Generate embedding for a single function by name\n"
"Args: funcName (str) - IR-Name of the function\n"
"Returns: ndarray[float64] - Function embedding vector")
+<<<<<<< HEAD
.def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap, nb::arg("funcName"),
"Generate embeddings for all basic blocks in a function\n"
"Args: funcName (str) - IR-Name of the function\n"
"Returns: dict[str, ndarray[float64]] - "
"{basic_block_name: embedding vector}");
+=======
+ .def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap,
+ "Generate embeddings for all basic blocks in the module\n"
+ "Returns: dict[str, dict[str, ndarray[float64]]] - Nested "
+ "dictionary mapping "
+ "function names to dictionaries of basic block names to embedding "
+ "vectors")
+ .def("getInstEmbMap", &PyIR2VecTool::getInstEmbMap,
+ "Generate embeddings for all instructions in the module\n"
+ "Returns: dict[str, dict[str, ndarray[float64]]] - Nested "
+ "dictionary mapping "
+ "function names to dictionaries of instruction strings to embedding "
+ "vectors");
+
+>>>>>>> b259e51c5fdd (Adding Inst Embeddings Map API to ir2vec python bindings - returns a nested map indexed by functions)
m.def(
"initEmbedding",
[](const std::string &filename, const std::string &mode,
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 70f501c116407..4215a6073a179 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -216,6 +216,51 @@ IR2VecTool::getBBEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
return Result;
}
+Expected<InstEmbeddingsMap>
+IR2VecTool::getInstEmbMap(const Function &F, IR2VecKind Kind) const {
+ if (!Vocab || !Vocab->isValid())
+ return createStringError(
+ errc::invalid_argument,
+ "Vocabulary is not valid. IR2VecTool not initialized.");
+
+ InstEmbeddingsMap Result;
+
+ if (F.isDeclaration())
+ return createStringError(errc::invalid_argument,
+ "Function is a declaration.");
+
+ auto Emb = Embedder::create(Kind, F, *Vocab);
+ if (!Emb)
+ return createStringError(errc::invalid_argument,
+ "Failed to create embedder for function '%s'.",
+ F.getName().str().c_str());
+
+ for (const Instruction &I : instructions(F))
+ Result.try_emplace(&I, Emb->getInstVector(I));
+
+ return Result;
+}
+
+Expected<FuncInstEmbMap>
+IR2VecTool::getFuncInstEmbMap(IR2VecKind Kind) const {
+ if (!Vocab || !Vocab->isValid())
+ return createStringError(
+ errc::invalid_argument,
+ "Vocabulary is not valid. IR2VecTool not initialized.");
+
+ FuncInstEmbMap Result;
+
+ for (const Function &F : M.getFunctionDefs()) {
+ auto FuncInstVecs = getInstEmbMap(F, Kind);
+ if (!FuncInstVecs)
+ return FuncInstVecs.takeError();
+
+ Result.try_emplace(&F, std::move(*FuncInstVecs));
+ }
+
+ return Result;
+}
+
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
EmbeddingLevel Level) const {
if (!Vocab || !Vocab->isValid()) {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index 04ec4a74b1e24..806cb5406b122 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -74,6 +74,7 @@ struct TripletResult {
/// Entity mappings: [entity_name]
using EntityList = std::vector<std::string>;
using FuncEmbMap = DenseMap<const Function *, ir2vec::Embedding>;
+using FuncInstEmbMap = DenseMap<const Function *, ir2vec::InstEmbeddingsMap>;
namespace ir2vec {
@@ -127,6 +128,14 @@ class IR2VecTool {
/// Get embeddings for all basic blocks in a function
Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
IR2VecKind Kind) const;
+ /// Get embeddings for all instructions in a function
+ Expected<InstEmbeddingsMap> getInstEmbMap(const Function &F, IR2VecKind Kind) const;
+
+ /// Get embeddings for all instructions in the module, organized by function
+ Expected<FuncInstEmbMap> getFuncInstEmbMap(IR2VecKind Kind) const;
+
+ /// Dump entity ID to string mappings
+ static void writeEntitiesToStream(raw_ostream &OS);
/// Generate embeddings for the entire module
void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;
From 1288aadaacce0a4cccb14d26a48a14e901d02a53 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Fri, 6 Feb 2026 13:59:01 +0530
Subject: [PATCH 2/5] nit commit - formatting fixup
---
llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py | 2 +-
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 3 +--
llvm/tools/llvm-ir2vec/lib/Utils.h | 3 ++-
3 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 78c322d6d9265..52ad69c7e04d6 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -54,7 +54,7 @@
for inst_str in sorted(func_inst_map.keys()):
emb = func_inst_map[inst_str]
inst_sorted.append((inst_str, emb))
-
+
for inst_str, emb in inst_sorted:
print(f"Inst: {inst_str}")
print(f" Embedding: {emb.tolist()}")
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 4215a6073a179..739cfed4bb497 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -241,8 +241,7 @@ IR2VecTool::getInstEmbMap(const Function &F, IR2VecKind Kind) const {
return Result;
}
-Expected<FuncInstEmbMap>
-IR2VecTool::getFuncInstEmbMap(IR2VecKind Kind) const {
+Expected<FuncInstEmbMap> IR2VecTool::getFuncInstEmbMap(IR2VecKind Kind) const {
if (!Vocab || !Vocab->isValid())
return createStringError(
errc::invalid_argument,
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index 806cb5406b122..e01b4d1baa6ee 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -129,7 +129,8 @@ class IR2VecTool {
Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
IR2VecKind Kind) const;
/// Get embeddings for all instructions in a function
- Expected<InstEmbeddingsMap> getInstEmbMap(const Function &F, IR2VecKind Kind) const;
+ Expected<InstEmbeddingsMap> getInstEmbMap(const Function &F,
+ IR2VecKind Kind) const;
/// Get embeddings for all instructions in the module, organized by function
Expected<FuncInstEmbMap> getFuncInstEmbMap(IR2VecKind Kind) const;
From dabd7e3147ecd2dd6a86c4222fb2b2df9941c664 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Mon, 9 Feb 2026 12:14:28 +0530
Subject: [PATCH 3/5] Changing Inst Embedding Map API to limit to a specific function
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 69 +++++++++----------
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 66 ++++++++----------
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 21 +-----
llvm/tools/llvm-ir2vec/lib/Utils.h | 9 +--
4 files changed, 64 insertions(+), 101 deletions(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 52ad69c7e04d6..f4906c33e3619 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -45,19 +45,15 @@
# Test getInstEmbMap
print("\n=== Instruction Embeddings ===")
- inst_emb_map = tool.getInstEmbMap()
-
- # Sorting by function name, then instruction string for deterministic output
- inst_sorted = []
- for func_name in sorted(inst_emb_map.keys()):
- func_inst_map = inst_emb_map[func_name]
- for inst_str in sorted(func_inst_map.keys()):
- emb = func_inst_map[inst_str]
- inst_sorted.append((inst_str, emb))
-
- for inst_str, emb in inst_sorted:
- print(f"Inst: {inst_str}")
- print(f" Embedding: {emb.tolist()}")
+
+ # Test valid function names in sorted order
+ for func_name in sorted(["add", "multiply", "conditional"]):
+ inst_emb_map = tool.getInstEmbMap(func_name)
+ print(f"Function: {func_name}")
+ for inst_str in sorted(inst_emb_map.keys()):
+ emb = inst_emb_map[inst_str]
+ print(f" Inst: {inst_str}")
+ print(f" Embedding: {emb.tolist()}")
# CHECK: SUCCESS: Tool initialized
# CHECK: Tool type: IR2VecTool
@@ -92,25 +88,28 @@
# CHECK: BB: entry
# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
# CHECK: === Instruction Embeddings ===
-# CHECK: Inst: %sum = add i32 %a, %b
-# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
-# CHECK: Inst: ret i32 %sum
-# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
-# CHECK: Inst: %cmp = icmp sgt i32 %n, 0
-# CHECK-NEXT: Embedding: [157.20000000298023, 158.20000000298023, 159.20000000298023]
-# CHECK: Inst: %neg_val = sub i32 %n, 10
-# CHECK-NEXT: Embedding: [43.0, 44.0, 45.0]
-# CHECK: Inst: %pos_val = add i32 %n, 10
-# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
-# CHECK: Inst: %result = phi i32 [ %pos_val, %positive ], [ %neg_val, %negative ]
-# CHECK-NEXT: Embedding: [163.0, 164.0, 165.0]
-# CHECK: Inst: br i1 %cmp, label %positive, label %negative
-# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
-# CHECK: Inst: br label %exit
-# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
-# CHECK: Inst: ret i32 %result
-# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
-# CHECK: Inst: %prod = mul i32 %x, %y
-# CHECK-NEXT: Embedding: [49.0, 50.0, 51.0]
-# CHECK: Inst: ret i32 %prod
-# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
+# CHECK: Function: add
+# CHECK: Inst: %sum = add i32 %a, %b
+# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
+# CHECK: Inst: ret i32 %sum
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
+# CHECK: Function: conditional
+# CHECK: Inst: %cmp = icmp sgt i32 %n, 0
+# CHECK-NEXT: Embedding: [157.20000000298023, 158.20000000298023, 159.20000000298023]
+# CHECK: Inst: %neg_val = sub i32 %n, 10
+# CHECK-NEXT: Embedding: [43.0, 44.0, 45.0]
+# CHECK: Inst: %pos_val = add i32 %n, 10
+# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
+# CHECK: Inst: %result = phi i32 [ %pos_val, %positive ], [ %neg_val, %negative ]
+# CHECK-NEXT: Embedding: [163.0, 164.0, 165.0]
+# CHECK: Inst: br i1 %cmp, label %positive, label %negative
+# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
+# CHECK: Inst: br label %exit
+# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
+# CHECK: Inst: ret i32 %result
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
+# CHECK: Function: multiply
+# CHECK: Inst: %prod = mul i32 %x, %y
+# CHECK-NEXT: Embedding: [49.0, 50.0, 51.0]
+# CHECK: Inst: ret i32 %prod
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index dc47c279313bd..e4ddaf9c14e5a 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -151,36 +151,36 @@ class PyIR2VecTool {
return NbBBEmbMap;
}
- nb::dict getInstEmbMap() {
- auto ToolInstEmbMap = Tool->getFuncInstEmbMap(OutputEmbeddingMode);
+ nb::dict getInstEmbMap(const std::string &FuncName) {
+ const Function *F = M->getFunction(FuncName);
+
+ if (!F)
+ throw nb::value_error(
+ ("Function '" + FuncName + "' not found in module").c_str());
+
+ auto ToolInstEmbMap = Tool->getInstEmbeddingsMap(*F, OutputEmbeddingMode);
if (!ToolInstEmbMap)
throw nb::value_error(toString(ToolInstEmbMap.takeError()).c_str());
nb::dict NbInstEmbMap;
- for (const auto &[FuncPtr, InstMap] : *ToolInstEmbMap) {
- nb::dict NbFuncInstMap;
-
- for (const auto &[InstPtr, InstEmb] : InstMap) {
- auto InstEmbVec = InstEmb.getData();
- double *NbInstEmbVec = new double[InstEmbVec.size()];
- std::copy(InstEmbVec.begin(), InstEmbVec.end(), NbInstEmbVec);
+ for (const auto &[InstPtr, InstEmb] : *ToolInstEmbMap) {
+ auto InstEmbVec = InstEmb.getData();
+ double *NbInstEmbVec = new double[InstEmbVec.size()];
+ std::copy(InstEmbVec.begin(), InstEmbVec.end(), NbInstEmbVec);
- auto NbArray = nb::ndarray<nb::numpy, double>(
- NbInstEmbVec, {InstEmbVec.size()},
- nb::capsule(NbInstEmbVec, [](void *P) noexcept {
- delete[] static_cast<double *>(P);
- }));
-
- std::string InstStr;
- raw_string_ostream OS(InstStr);
- InstPtr->print(OS);
+ auto NbArray = nb::ndarray<nb::numpy, double>(
+ NbInstEmbVec, {InstEmbVec.size()},
+ nb::capsule(NbInstEmbVec, [](void *P) noexcept {
+ delete[] static_cast<double *>(P);
+ }));
- NbFuncInstMap[nb::str(OS.str().c_str())] = NbArray;
- }
+ std::string InstStr;
+ raw_string_ostream OS(InstStr);
+ InstPtr->print(OS);
- NbInstEmbMap[nb::str(FuncPtr->getName().str().c_str())] = NbFuncInstMap;
+ NbInstEmbMap[nb::str(OS.str().c_str())] = NbArray;
}
return NbInstEmbMap;
@@ -204,27 +204,17 @@ NB_MODULE(ir2vec, m) {
"Generate embedding for a single function by name\n"
"Args: funcName (str) - IR-Name of the function\n"
"Returns: ndarray[float64] - Function embedding vector")
-<<<<<<< HEAD
.def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap, nb::arg("funcName"),
"Generate embeddings for all basic blocks in a function\n"
"Args: funcName (str) - IR-Name of the function\n"
"Returns: dict[str, ndarray[float64]] - "
- "{basic_block_name: embedding vector}");
-=======
- .def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap,
- "Generate embeddings for all basic blocks in the module\n"
- "Returns: dict[str, dict[str, ndarray[float64]]] - Nested "
- "dictionary mapping "
- "function names to dictionaries of basic block names to embedding "
- "vectors")
- .def("getInstEmbMap", &PyIR2VecTool::getInstEmbMap,
- "Generate embeddings for all instructions in the module\n"
- "Returns: dict[str, dict[str, ndarray[float64]]] - Nested "
- "dictionary mapping "
- "function names to dictionaries of instruction strings to embedding "
- "vectors");
-
->>>>>>> b259e51c5fdd (Adding Inst Embeddings Map API to ir2vec python bindings - returns a nested map indexed by functions)
+ "{basic_block_name: embedding vector}")
+ .def("getInstEmbMap", &PyIR2VecTool::getInstEmbMap, nb::arg("funcName"),
+ "Generate embeddings for all instructions in a function\n"
+ "Args: funcName (str) - IR-Name of the function\n"
+ "Returns: dict[str, ndarray[float64]] - "
+ "{instruction_string: embedding_vector}");
+
m.def(
"initEmbedding",
[](const std::string &filename, const std::string &mode,
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 739cfed4bb497..4d1223b6ab1a5 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -217,7 +217,7 @@ IR2VecTool::getBBEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
}
Expected<InstEmbeddingsMap>
-IR2VecTool::getInstEmbMap(const Function &F, IR2VecKind Kind) const {
+IR2VecTool::getInstEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
if (!Vocab || !Vocab->isValid())
return createStringError(
errc::invalid_argument,
@@ -241,25 +241,6 @@ IR2VecTool::getInstEmbMap(const Function &F, IR2VecKind Kind) const {
return Result;
}
-Expected<FuncInstEmbMap> IR2VecTool::getFuncInstEmbMap(IR2VecKind Kind) const {
- if (!Vocab || !Vocab->isValid())
- return createStringError(
- errc::invalid_argument,
- "Vocabulary is not valid. IR2VecTool not initialized.");
-
- FuncInstEmbMap Result;
-
- for (const Function &F : M.getFunctionDefs()) {
- auto FuncInstVecs = getInstEmbMap(F, Kind);
- if (!FuncInstVecs)
- return FuncInstVecs.takeError();
-
- Result.try_emplace(&F, std::move(*FuncInstVecs));
- }
-
- return Result;
-}
-
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
EmbeddingLevel Level) const {
if (!Vocab || !Vocab->isValid()) {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index e01b4d1baa6ee..639f37e74466b 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -74,7 +74,6 @@ struct TripletResult {
/// Entity mappings: [entity_name]
using EntityList = std::vector<std::string>;
using FuncEmbMap = DenseMap<const Function *, ir2vec::Embedding>;
-using FuncInstEmbMap = DenseMap<const Function *, ir2vec::InstEmbeddingsMap>;
namespace ir2vec {
@@ -129,15 +128,9 @@ class IR2VecTool {
Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
IR2VecKind Kind) const;
/// Get embeddings for all instructions in a function
- Expected<InstEmbeddingsMap> getInstEmbMap(const Function &F,
+ Expected<InstEmbeddingsMap> getInstEmbeddingsMap(const Function &F,
IR2VecKind Kind) const;
- /// Get embeddings for all instructions in the module, organized by function
- Expected<FuncInstEmbMap> getFuncInstEmbMap(IR2VecKind Kind) const;
-
- /// Dump entity ID to string mappings
- static void writeEntitiesToStream(raw_ostream &OS);
-
/// Generate embeddings for the entire module
void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;
From 6faf88c6f9f7141508419475ce09db0d012bdea8 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Mon, 9 Feb 2026 12:28:01 +0530
Subject: [PATCH 4/5] nit commit - formatting fixup
---
llvm/tools/llvm-ir2vec/lib/Utils.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index 639f37e74466b..face23895cff6 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -129,7 +129,7 @@ class IR2VecTool {
IR2VecKind Kind) const;
/// Get embeddings for all instructions in a function
Expected<InstEmbeddingsMap> getInstEmbeddingsMap(const Function &F,
- IR2VecKind Kind) const;
+ IR2VecKind Kind) const;
/// Generate embeddings for the entire module
void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;
From bfa08ad70fca786aa048d2eb85e020b3ad50c303 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Mon, 9 Feb 2026 12:31:05 +0530
Subject: [PATCH 5/5] nit commit - python file formatting fixup
---
llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index f4906c33e3619..bb29d33dc8ca6 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -45,7 +45,7 @@
# Test getInstEmbMap
print("\n=== Instruction Embeddings ===")
-
+
# Test valid function names in sorted order
for func_name in sorted(["add", "multiply", "conditional"]):
inst_emb_map = tool.getInstEmbMap(func_name)
More information about the llvm-commits mailing list