[llvm] Adding getFuncNames API to ir2vec python bindings (PR #180473)
Nishant Sachdeva via llvm-commits
llvm-commits at lists.llvm.org
Sun Feb 8 23:00:15 PST 2026
https://github.com/nishant-sachdeva created https://github.com/llvm/llvm-project/pull/180473
None
>From 4eca51ae104e4c4426b937e1a75e0f70a9ef9ad3 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Fri, 6 Feb 2026 13:20:27 +0530
Subject: [PATCH 1/9] Adding BB Embeddings Map API - returns a nested map
indexed by functions
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 32 ++++++++++++++
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 39 +++++++++++++++-
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 44 +++++++++++++++++++
llvm/tools/llvm-ir2vec/lib/Utils.h | 8 ++++
4 files changed, 122 insertions(+), 1 deletion(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index a0d61e4808292..6461b73738c7d 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -31,6 +31,25 @@
print(f"Function: {func_name}")
print(f" Embedding: {func_emb.tolist()}")
+ # Test getBBEmbMap
+ print("\n=== Basic Block Embeddings ===")
+ bb_emb_map = tool.getBBEmbMap()
+
+ # Sorting by function name, then BB name, then embedding values for deterministic output
+ bb_sorted = []
+ for func_name in sorted(bb_emb_map.keys()):
+ func_bb_map = bb_emb_map[func_name]
+ for bb_name in sorted(func_bb_map.keys()):
+ emb = func_bb_map[bb_name]
+ bb_sorted.append((bb_name, emb))
+
+ # Sort the flattened list by BB name, then embedding values
+ bb_sorted = sorted(bb_sorted, key=lambda x: (x[0], tuple(x[1].tolist())))
+
+ for bb_name, emb in bb_sorted:
+ print(f"BB: {bb_name}")
+ print(f" Embedding: {emb.tolist()}")
+
# CHECK: SUCCESS: Tool initialized
# CHECK: Tool type: IR2VecTool
# CHECK: === Function Embeddings ===
@@ -47,3 +66,16 @@
# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
# CHECK: Function: conditional
# CHECK-NEXT: Embedding: [413.20000000298023, 421.20000000298023, 429.20000000298023]
+# CHECK: === Basic Block Embeddings ===
+# CHECK: BB: entry
+# CHECK-NEXT: Embedding: [38.0, 40.0, 42.0]
+# CHECK: BB: entry
+# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
+# CHECK: BB: entry
+# CHECK-NEXT: Embedding: [161.20000000298023, 163.20000000298023, 165.20000000298023]
+# CHECK: BB: exit
+# CHECK-NEXT: Embedding: [164.0, 166.0, 168.0]
+# CHECK: BB: negative
+# CHECK-NEXT: Embedding: [47.0, 49.0, 51.0]
+# CHECK: BB: positive
+# CHECK-NEXT: Embedding: [41.0, 43.0, 45.0]
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 5032a053ce7b6..1210d7775186f 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -119,6 +119,37 @@ class PyIR2VecTool {
return NbArray;
}
+
+ nb::dict getBBEmbMap() {
+ auto ToolFuncBBEmbMap = Tool->getFuncBBEmbMap(OutputEmbeddingMode);
+
+ if (!ToolFuncBBEmbMap)
+ throw nb::value_error(toString(ToolFuncBBEmbMap.takeError()).c_str());
+
+ nb::dict NbFuncBBEmbMap;
+
+ for (const auto &[FuncPtr, BBMap] : *ToolFuncBBEmbMap) {
+ nb::dict NbFuncBBMap;
+
+ for (const auto &[BBPtr, BBEmb] : BBMap) {
+ auto BBEmbVec = BBEmb.getData();
+ double *NbBBEmbVec = new double[BBEmbVec.size()];
+ std::copy(BBEmbVec.begin(), BBEmbVec.end(), NbBBEmbVec);
+
+ auto NbArray = nb::ndarray<nb::numpy, double>(
+ NbBBEmbVec, {BBEmbVec.size()},
+ nb::capsule(NbBBEmbVec, [](void *P) noexcept {
+ delete[] static_cast<double *>(P);
+ }));
+
+ NbFuncBBMap[nb::str(BBPtr->getName().str().c_str())] = NbArray;
+ }
+
+ NbFuncBBEmbMap[nb::str(FuncPtr->getName().str().c_str())] = NbFuncBBMap;
+ }
+
+ return NbFuncBBEmbMap;
+ }
};
} // namespace
@@ -137,7 +168,13 @@ NB_MODULE(ir2vec, m) {
.def("getFuncEmb", &PyIR2VecTool::getFuncEmb, nb::arg("funcName"),
"Generate embedding for a single function by name\n"
"Args: funcName (str) - IR-Name of the function\n"
- "Returns: ndarray[float64] - Function embedding vector");
+ "Returns: ndarray[float64] - Function embedding vector")
+ .def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap,
+ "Generate embeddings for all basic blocks in the module\n"
+ "Returns: dict[str, dict[str, ndarray[float64]]] - Nested "
+ "dictionary mapping "
+ "function names to dictionaries of basic block names to embedding "
+ "vectors");
m.def(
"initEmbedding",
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 50b71d6c134d6..7d2c8d14985e0 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -191,6 +191,50 @@ IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
return Result;
}
+Expected<BBEmbeddingsMap> IR2VecTool::getBBEmbMap(const Function &F,
+ IR2VecKind Kind) const {
+ if (!Vocab || !Vocab->isValid())
+ return createStringError(
+ errc::invalid_argument,
+ "Vocabulary is not valid. IR2VecTool not initialized.");
+
+ BBEmbeddingsMap Result;
+
+ if (F.isDeclaration())
+ return createStringError(errc::invalid_argument,
+ "Function is a declaration.");
+
+ auto Emb = Embedder::create(Kind, F, *Vocab);
+ if (!Emb)
+ return createStringError(errc::invalid_argument,
+ "Failed to create embedder for function '%s'.",
+ F.getName().str().c_str());
+
+ for (const BasicBlock &BB : F)
+ Result.try_emplace(&BB, Emb->getBBVector(BB));
+
+ return Result;
+}
+
+Expected<FuncBBEmbMap> IR2VecTool::getFuncBBEmbMap(IR2VecKind Kind) const {
+ if (!Vocab || !Vocab->isValid())
+ return createStringError(
+ errc::invalid_argument,
+ "Vocabulary is not valid. IR2VecTool not initialized.");
+
+ FuncBBEmbMap Result;
+
+ for (const Function &F : M.getFunctionDefs()) {
+ auto FuncBBVecs = getBBEmbMap(F, Kind);
+ if (!FuncBBVecs)
+ return FuncBBVecs.takeError();
+
+ Result.try_emplace(&F, std::move(*FuncBBVecs));
+ }
+
+ return Result;
+}
+
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
EmbeddingLevel Level) const {
if (!Vocab || !Vocab->isValid()) {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index d535f5fd5bb74..7797a7b2a77ed 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -74,6 +74,7 @@ struct TripletResult {
/// Entity mappings: [entity_name]
using EntityList = std::vector<std::string>;
using FuncEmbMap = DenseMap<const Function *, ir2vec::Embedding>;
+using FuncBBEmbMap = DenseMap<const Function *, ir2vec::BBEmbeddingsMap>;
namespace ir2vec {
@@ -121,6 +122,13 @@ class IR2VecTool {
/// Get embeddings for all functions in the module
Expected<FuncEmbMap> getFunctionEmbeddingsMap(IR2VecKind Kind) const;
+ /// Get embeddings for all basic blocks in a function
+ Expected<BBEmbeddingsMap> getBBEmbMap(const Function &F,
+ IR2VecKind Kind) const;
+
+ /// Get embeddings for all basic blocks in the module, organized by function
+ Expected<FuncBBEmbMap> getFuncBBEmbMap(IR2VecKind Kind) const;
+
/// Dump entity ID to string mappings
static void writeEntitiesToStream(raw_ostream &OS);
>From 40c13aefc03609ca9e4fe826e22b154a204848b9 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Fri, 6 Feb 2026 13:49:01 +0530
Subject: [PATCH 2/9] removing unnecessary sorting in BB Map tests
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 17 +++++------------
1 file changed, 5 insertions(+), 12 deletions(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 6461b73738c7d..74714fb92683e 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -35,20 +35,13 @@
print("\n=== Basic Block Embeddings ===")
bb_emb_map = tool.getBBEmbMap()
- # Sorting by function name, then BB name, then embedding values for deterministic output
- bb_sorted = []
+ # Sorting by function name, then BB name for deterministic output
for func_name in sorted(bb_emb_map.keys()):
func_bb_map = bb_emb_map[func_name]
for bb_name in sorted(func_bb_map.keys()):
emb = func_bb_map[bb_name]
- bb_sorted.append((bb_name, emb))
-
- # Sort the flattened list by BB name, then embedding values
- bb_sorted = sorted(bb_sorted, key=lambda x: (x[0], tuple(x[1].tolist())))
-
- for bb_name, emb in bb_sorted:
- print(f"BB: {bb_name}")
- print(f" Embedding: {emb.tolist()}")
+ print(f"BB: {bb_name}")
+ print(f" Embedding: {emb.tolist()}")
# CHECK: SUCCESS: Tool initialized
# CHECK: Tool type: IR2VecTool
@@ -70,8 +63,6 @@
# CHECK: BB: entry
# CHECK-NEXT: Embedding: [38.0, 40.0, 42.0]
# CHECK: BB: entry
-# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
-# CHECK: BB: entry
# CHECK-NEXT: Embedding: [161.20000000298023, 163.20000000298023, 165.20000000298023]
# CHECK: BB: exit
# CHECK-NEXT: Embedding: [164.0, 166.0, 168.0]
@@ -79,3 +70,5 @@
# CHECK-NEXT: Embedding: [47.0, 49.0, 51.0]
# CHECK: BB: positive
# CHECK-NEXT: Embedding: [41.0, 43.0, 45.0]
+# CHECK: BB: entry
+# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
\ No newline at end of file
>From eafac1b98e7b68a275b14e8c918382efed633abe Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Fri, 6 Feb 2026 13:49:45 +0530
Subject: [PATCH 3/9] nit commit - formatting fixup
---
llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 74714fb92683e..483f0b1478beb 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -71,4 +71,4 @@
# CHECK: BB: positive
# CHECK-NEXT: Embedding: [41.0, 43.0, 45.0]
# CHECK: BB: entry
-# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
\ No newline at end of file
+# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
>From e2aa0b5ea2c0f7a0b8a403fe461f86613fe2c5a2 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Mon, 9 Feb 2026 11:51:06 +0530
Subject: [PATCH 4/9] Changing BB Embeddings API to take funcName and return
limited output
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 43 ++++++++-------
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 54 +++++++++----------
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 23 +-------
llvm/tools/llvm-ir2vec/lib/Utils.h | 14 ++---
4 files changed, 56 insertions(+), 78 deletions(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 483f0b1478beb..502f8a2411aa8 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -33,15 +33,15 @@
# Test getBBEmbMap
print("\n=== Basic Block Embeddings ===")
- bb_emb_map = tool.getBBEmbMap()
- # Sorting by function name, then BB name for deterministic output
- for func_name in sorted(bb_emb_map.keys()):
- func_bb_map = bb_emb_map[func_name]
- for bb_name in sorted(func_bb_map.keys()):
- emb = func_bb_map[bb_name]
- print(f"BB: {bb_name}")
- print(f" Embedding: {emb.tolist()}")
+ # Test valid function names in sorted order
+ for func_name in sorted(["add", "multiply", "conditional"]):
+ bb_emb_map = tool.getBBEmbMap(func_name)
+ print(f"Function: {func_name}")
+ for bb_name in sorted(bb_emb_map.keys()):
+ emb = bb_emb_map[bb_name]
+ print(f" BB: {bb_name}")
+ print(f" Embedding: {emb.tolist()}")
# CHECK: SUCCESS: Tool initialized
# CHECK: Tool type: IR2VecTool
@@ -60,15 +60,18 @@
# CHECK: Function: conditional
# CHECK-NEXT: Embedding: [413.20000000298023, 421.20000000298023, 429.20000000298023]
# CHECK: === Basic Block Embeddings ===
-# CHECK: BB: entry
-# CHECK-NEXT: Embedding: [38.0, 40.0, 42.0]
-# CHECK: BB: entry
-# CHECK-NEXT: Embedding: [161.20000000298023, 163.20000000298023, 165.20000000298023]
-# CHECK: BB: exit
-# CHECK-NEXT: Embedding: [164.0, 166.0, 168.0]
-# CHECK: BB: negative
-# CHECK-NEXT: Embedding: [47.0, 49.0, 51.0]
-# CHECK: BB: positive
-# CHECK-NEXT: Embedding: [41.0, 43.0, 45.0]
-# CHECK: BB: entry
-# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
+# CHECK: Function: add
+# CHECK: BB: entry
+# CHECK-NEXT: Embedding: [38.0, 40.0, 42.0]
+# CHECK: Function: conditional
+# CHECK: BB: entry
+# CHECK-NEXT: Embedding: [161.20000000298023, 163.20000000298023, 165.20000000298023]
+# CHECK: BB: exit
+# CHECK-NEXT: Embedding: [164.0, 166.0, 168.0]
+# CHECK: BB: negative
+# CHECK-NEXT: Embedding: [47.0, 49.0, 51.0]
+# CHECK: BB: positive
+# CHECK-NEXT: Embedding: [41.0, 43.0, 45.0]
+# CHECK: Function: multiply
+# CHECK: BB: entry
+# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 1210d7775186f..e334997b31394 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -120,35 +120,35 @@ class PyIR2VecTool {
return NbArray;
}
- nb::dict getBBEmbMap() {
- auto ToolFuncBBEmbMap = Tool->getFuncBBEmbMap(OutputEmbeddingMode);
+ nb::dict getBBEmbMap(const std::string &FuncName) {
+ const Function *F = M->getFunction(FuncName);
- if (!ToolFuncBBEmbMap)
- throw nb::value_error(toString(ToolFuncBBEmbMap.takeError()).c_str());
+ if (!F)
+ throw nb::value_error(
+ ("Function '" + FuncName + "' not found in module").c_str());
- nb::dict NbFuncBBEmbMap;
+ auto ToolBBEmbMap = Tool->getBBEmbeddingsMap(*F, OutputEmbeddingMode);
- for (const auto &[FuncPtr, BBMap] : *ToolFuncBBEmbMap) {
- nb::dict NbFuncBBMap;
+ if (!ToolBBEmbMap)
+ throw nb::value_error(toString(ToolBBEmbMap.takeError()).c_str());
- for (const auto &[BBPtr, BBEmb] : BBMap) {
- auto BBEmbVec = BBEmb.getData();
- double *NbBBEmbVec = new double[BBEmbVec.size()];
- std::copy(BBEmbVec.begin(), BBEmbVec.end(), NbBBEmbVec);
+ nb::dict NbBBEmbMap;
- auto NbArray = nb::ndarray<nb::numpy, double>(
- NbBBEmbVec, {BBEmbVec.size()},
- nb::capsule(NbBBEmbVec, [](void *P) noexcept {
- delete[] static_cast<double *>(P);
- }));
+ for (const auto &[BBPtr, BBEmb] : *ToolBBEmbMap) {
+ auto BBEmbVec = BBEmb.getData();
+ double *NbBBEmbVec = new double[BBEmbVec.size()];
+ std::copy(BBEmbVec.begin(), BBEmbVec.end(), NbBBEmbVec);
- NbFuncBBMap[nb::str(BBPtr->getName().str().c_str())] = NbArray;
- }
+ auto NbArray = nb::ndarray<nb::numpy, double>(
+ NbBBEmbVec, {BBEmbVec.size()},
+ nb::capsule(NbBBEmbVec, [](void *P) noexcept {
+ delete[] static_cast<double *>(P);
+ }));
- NbFuncBBEmbMap[nb::str(FuncPtr->getName().str().c_str())] = NbFuncBBMap;
+ NbBBEmbMap[nb::str(BBPtr->getName().str().c_str())] = NbArray;
}
- return NbFuncBBEmbMap;
+ return NbBBEmbMap;
}
};
@@ -164,18 +164,16 @@ NB_MODULE(ir2vec, m) {
.def("getFuncEmbMap", &PyIR2VecTool::getFuncEmbMap,
"Generate function-level embeddings for all functions\n"
"Returns: dict[str, ndarray[float64]] - "
- "{function_name: embedding}")
+ "{function_name: embedding vector}")
.def("getFuncEmb", &PyIR2VecTool::getFuncEmb, nb::arg("funcName"),
"Generate embedding for a single function by name\n"
"Args: funcName (str) - IR-Name of the function\n"
"Returns: ndarray[float64] - Function embedding vector")
- .def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap,
- "Generate embeddings for all basic blocks in the module\n"
- "Returns: dict[str, dict[str, ndarray[float64]]] - Nested "
- "dictionary mapping "
- "function names to dictionaries of basic block names to embedding "
- "vectors");
-
+ .def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap, nb::arg("funcName"),
+ "Generate embeddings for all basic blocks in a function\n"
+ "Args: funcName (str) - IR-Name of the function\n"
+ "Returns: dict[str, ndarray[float64]] - "
+ "{basic_block_name: embedding vector}");
m.def(
"initEmbedding",
[](const std::string &filename, const std::string &mode,
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 7d2c8d14985e0..70f501c116407 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -191,8 +191,8 @@ IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
return Result;
}
-Expected<BBEmbeddingsMap> IR2VecTool::getBBEmbMap(const Function &F,
- IR2VecKind Kind) const {
+Expected<BBEmbeddingsMap>
+IR2VecTool::getBBEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
if (!Vocab || !Vocab->isValid())
return createStringError(
errc::invalid_argument,
@@ -216,25 +216,6 @@ Expected<BBEmbeddingsMap> IR2VecTool::getBBEmbMap(const Function &F,
return Result;
}
-Expected<FuncBBEmbMap> IR2VecTool::getFuncBBEmbMap(IR2VecKind Kind) const {
- if (!Vocab || !Vocab->isValid())
- return createStringError(
- errc::invalid_argument,
- "Vocabulary is not valid. IR2VecTool not initialized.");
-
- FuncBBEmbMap Result;
-
- for (const Function &F : M.getFunctionDefs()) {
- auto FuncBBVecs = getBBEmbMap(F, Kind);
- if (!FuncBBVecs)
- return FuncBBVecs.takeError();
-
- Result.try_emplace(&F, std::move(*FuncBBVecs));
- }
-
- return Result;
-}
-
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
EmbeddingLevel Level) const {
if (!Vocab || !Vocab->isValid()) {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index 7797a7b2a77ed..04ec4a74b1e24 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -74,7 +74,6 @@ struct TripletResult {
/// Entity mappings: [entity_name]
using EntityList = std::vector<std::string>;
using FuncEmbMap = DenseMap<const Function *, ir2vec::Embedding>;
-using FuncBBEmbMap = DenseMap<const Function *, ir2vec::BBEmbeddingsMap>;
namespace ir2vec {
@@ -115,6 +114,9 @@ class IR2VecTool {
/// Returns EntityList containing all entity strings
static EntityList collectEntityMappings();
+ /// Dump entity ID to string mappings
+ static void writeEntitiesToStream(raw_ostream &OS);
+
// Get embedding for a single function
Expected<Embedding> getFunctionEmbedding(const Function &F,
IR2VecKind Kind) const;
@@ -123,14 +125,8 @@ class IR2VecTool {
Expected<FuncEmbMap> getFunctionEmbeddingsMap(IR2VecKind Kind) const;
/// Get embeddings for all basic blocks in a function
- Expected<BBEmbeddingsMap> getBBEmbMap(const Function &F,
- IR2VecKind Kind) const;
-
- /// Get embeddings for all basic blocks in the module, organized by function
- Expected<FuncBBEmbMap> getFuncBBEmbMap(IR2VecKind Kind) const;
-
- /// Dump entity ID to string mappings
- static void writeEntitiesToStream(raw_ostream &OS);
+ Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
+ IR2VecKind Kind) const;
/// Generate embeddings for the entire module
void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;
>From 14c5258c7a40b451da9515820a1fdd0d0c1839f8 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Fri, 6 Feb 2026 13:45:57 +0530
Subject: [PATCH 5/9] Adding Inst Embeddings Map API to ir2vec python bindings
- returns a nested map indexed by functions
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 39 ++++++++++++++
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 51 +++++++++++++++++++
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 45 ++++++++++++++++
llvm/tools/llvm-ir2vec/lib/Utils.h | 9 ++++
4 files changed, 144 insertions(+)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 502f8a2411aa8..78c322d6d9265 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -43,6 +43,22 @@
print(f" BB: {bb_name}")
print(f" Embedding: {emb.tolist()}")
+ # Test getInstEmbMap
+ print("\n=== Instruction Embeddings ===")
+ inst_emb_map = tool.getInstEmbMap()
+
+ # Sorting by function name, then instruction string for deterministic output
+ inst_sorted = []
+ for func_name in sorted(inst_emb_map.keys()):
+ func_inst_map = inst_emb_map[func_name]
+ for inst_str in sorted(func_inst_map.keys()):
+ emb = func_inst_map[inst_str]
+ inst_sorted.append((inst_str, emb))
+
+ for inst_str, emb in inst_sorted:
+ print(f"Inst: {inst_str}")
+ print(f" Embedding: {emb.tolist()}")
+
# CHECK: SUCCESS: Tool initialized
# CHECK: Tool type: IR2VecTool
# CHECK: === Function Embeddings ===
@@ -75,3 +91,26 @@
# CHECK: Function: multiply
# CHECK: BB: entry
# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
+# CHECK: === Instruction Embeddings ===
+# CHECK: Inst: %sum = add i32 %a, %b
+# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
+# CHECK: Inst: ret i32 %sum
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
+# CHECK: Inst: %cmp = icmp sgt i32 %n, 0
+# CHECK-NEXT: Embedding: [157.20000000298023, 158.20000000298023, 159.20000000298023]
+# CHECK: Inst: %neg_val = sub i32 %n, 10
+# CHECK-NEXT: Embedding: [43.0, 44.0, 45.0]
+# CHECK: Inst: %pos_val = add i32 %n, 10
+# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
+# CHECK: Inst: %result = phi i32 [ %pos_val, %positive ], [ %neg_val, %negative ]
+# CHECK-NEXT: Embedding: [163.0, 164.0, 165.0]
+# CHECK: Inst: br i1 %cmp, label %positive, label %negative
+# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
+# CHECK: Inst: br label %exit
+# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
+# CHECK: Inst: ret i32 %result
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
+# CHECK: Inst: %prod = mul i32 %x, %y
+# CHECK-NEXT: Embedding: [49.0, 50.0, 51.0]
+# CHECK: Inst: ret i32 %prod
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index e334997b31394..dc47c279313bd 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -150,6 +150,41 @@ class PyIR2VecTool {
return NbBBEmbMap;
}
+
+ nb::dict getInstEmbMap() {
+ auto ToolInstEmbMap = Tool->getFuncInstEmbMap(OutputEmbeddingMode);
+
+ if (!ToolInstEmbMap)
+ throw nb::value_error(toString(ToolInstEmbMap.takeError()).c_str());
+
+ nb::dict NbInstEmbMap;
+
+ for (const auto &[FuncPtr, InstMap] : *ToolInstEmbMap) {
+ nb::dict NbFuncInstMap;
+
+ for (const auto &[InstPtr, InstEmb] : InstMap) {
+ auto InstEmbVec = InstEmb.getData();
+ double *NbInstEmbVec = new double[InstEmbVec.size()];
+ std::copy(InstEmbVec.begin(), InstEmbVec.end(), NbInstEmbVec);
+
+ auto NbArray = nb::ndarray<nb::numpy, double>(
+ NbInstEmbVec, {InstEmbVec.size()},
+ nb::capsule(NbInstEmbVec, [](void *P) noexcept {
+ delete[] static_cast<double *>(P);
+ }));
+
+ std::string InstStr;
+ raw_string_ostream OS(InstStr);
+ InstPtr->print(OS);
+
+ NbFuncInstMap[nb::str(OS.str().c_str())] = NbArray;
+ }
+
+ NbInstEmbMap[nb::str(FuncPtr->getName().str().c_str())] = NbFuncInstMap;
+ }
+
+ return NbInstEmbMap;
+ }
};
} // namespace
@@ -169,11 +204,27 @@ NB_MODULE(ir2vec, m) {
"Generate embedding for a single function by name\n"
"Args: funcName (str) - IR-Name of the function\n"
"Returns: ndarray[float64] - Function embedding vector")
+<<<<<<< HEAD
.def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap, nb::arg("funcName"),
"Generate embeddings for all basic blocks in a function\n"
"Args: funcName (str) - IR-Name of the function\n"
"Returns: dict[str, ndarray[float64]] - "
"{basic_block_name: embedding vector}");
+=======
+ .def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap,
+ "Generate embeddings for all basic blocks in the module\n"
+ "Returns: dict[str, dict[str, ndarray[float64]]] - Nested "
+ "dictionary mapping "
+ "function names to dictionaries of basic block names to embedding "
+ "vectors")
+ .def("getInstEmbMap", &PyIR2VecTool::getInstEmbMap,
+ "Generate embeddings for all instructions in the module\n"
+ "Returns: dict[str, dict[str, ndarray[float64]]] - Nested "
+ "dictionary mapping "
+ "function names to dictionaries of instruction strings to embedding "
+ "vectors");
+
+>>>>>>> b259e51c5fdd (Adding Inst Embeddings Map API to ir2vec python bindings - returns a nested map indexed by functions)
m.def(
"initEmbedding",
[](const std::string &filename, const std::string &mode,
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 70f501c116407..4215a6073a179 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -216,6 +216,51 @@ IR2VecTool::getBBEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
return Result;
}
+Expected<InstEmbeddingsMap>
+IR2VecTool::getInstEmbMap(const Function &F, IR2VecKind Kind) const {
+ if (!Vocab || !Vocab->isValid())
+ return createStringError(
+ errc::invalid_argument,
+ "Vocabulary is not valid. IR2VecTool not initialized.");
+
+ InstEmbeddingsMap Result;
+
+ if (F.isDeclaration())
+ return createStringError(errc::invalid_argument,
+ "Function is a declaration.");
+
+ auto Emb = Embedder::create(Kind, F, *Vocab);
+ if (!Emb)
+ return createStringError(errc::invalid_argument,
+ "Failed to create embedder for function '%s'.",
+ F.getName().str().c_str());
+
+ for (const Instruction &I : instructions(F))
+ Result.try_emplace(&I, Emb->getInstVector(I));
+
+ return Result;
+}
+
+Expected<FuncInstEmbMap>
+IR2VecTool::getFuncInstEmbMap(IR2VecKind Kind) const {
+ if (!Vocab || !Vocab->isValid())
+ return createStringError(
+ errc::invalid_argument,
+ "Vocabulary is not valid. IR2VecTool not initialized.");
+
+ FuncInstEmbMap Result;
+
+ for (const Function &F : M.getFunctionDefs()) {
+ auto FuncInstVecs = getInstEmbMap(F, Kind);
+ if (!FuncInstVecs)
+ return FuncInstVecs.takeError();
+
+ Result.try_emplace(&F, std::move(*FuncInstVecs));
+ }
+
+ return Result;
+}
+
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
EmbeddingLevel Level) const {
if (!Vocab || !Vocab->isValid()) {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index 04ec4a74b1e24..806cb5406b122 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -74,6 +74,7 @@ struct TripletResult {
/// Entity mappings: [entity_name]
using EntityList = std::vector<std::string>;
using FuncEmbMap = DenseMap<const Function *, ir2vec::Embedding>;
+using FuncInstEmbMap = DenseMap<const Function *, ir2vec::InstEmbeddingsMap>;
namespace ir2vec {
@@ -127,6 +128,14 @@ class IR2VecTool {
/// Get embeddings for all basic blocks in a function
Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
IR2VecKind Kind) const;
+ /// Get embeddings for all instructions in a function
+ Expected<InstEmbeddingsMap> getInstEmbMap(const Function &F, IR2VecKind Kind) const;
+
+ /// Get embeddings for all instructions in the module, organized by function
+ Expected<FuncInstEmbMap> getFuncInstEmbMap(IR2VecKind Kind) const;
+
+ /// Dump entity ID to string mappings
+ static void writeEntitiesToStream(raw_ostream &OS);
/// Generate embeddings for the entire module
void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;
>From 652cec5b43c04a1f14f8bf63cb71a13c2f458b51 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Fri, 6 Feb 2026 13:59:01 +0530
Subject: [PATCH 6/9] nit commit - formatting fixup
---
llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py | 2 +-
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 3 +--
llvm/tools/llvm-ir2vec/lib/Utils.h | 3 ++-
3 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 78c322d6d9265..52ad69c7e04d6 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -54,7 +54,7 @@
for inst_str in sorted(func_inst_map.keys()):
emb = func_inst_map[inst_str]
inst_sorted.append((inst_str, emb))
-
+
for inst_str, emb in inst_sorted:
print(f"Inst: {inst_str}")
print(f" Embedding: {emb.tolist()}")
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 4215a6073a179..739cfed4bb497 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -241,8 +241,7 @@ IR2VecTool::getInstEmbMap(const Function &F, IR2VecKind Kind) const {
return Result;
}
-Expected<FuncInstEmbMap>
-IR2VecTool::getFuncInstEmbMap(IR2VecKind Kind) const {
+Expected<FuncInstEmbMap> IR2VecTool::getFuncInstEmbMap(IR2VecKind Kind) const {
if (!Vocab || !Vocab->isValid())
return createStringError(
errc::invalid_argument,
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index 806cb5406b122..e01b4d1baa6ee 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -129,7 +129,8 @@ class IR2VecTool {
Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
IR2VecKind Kind) const;
/// Get embeddings for all instructions in a function
- Expected<InstEmbeddingsMap> getInstEmbMap(const Function &F, IR2VecKind Kind) const;
+ Expected<InstEmbeddingsMap> getInstEmbMap(const Function &F,
+ IR2VecKind Kind) const;
/// Get embeddings for all instructions in the module, organized by function
Expected<FuncInstEmbMap> getFuncInstEmbMap(IR2VecKind Kind) const;
>From 74d6bf76c1bd045be4b05bba090ee50eb285f02b Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Mon, 9 Feb 2026 12:14:28 +0530
Subject: [PATCH 7/9] Changing Inst Embedding Map API to limit to a specific
function
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 69 +++++++++----------
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 66 ++++++++----------
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 21 +-----
llvm/tools/llvm-ir2vec/lib/Utils.h | 9 +--
4 files changed, 64 insertions(+), 101 deletions(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 52ad69c7e04d6..f4906c33e3619 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -45,19 +45,15 @@
# Test getInstEmbMap
print("\n=== Instruction Embeddings ===")
- inst_emb_map = tool.getInstEmbMap()
-
- # Sorting by function name, then instruction string for deterministic output
- inst_sorted = []
- for func_name in sorted(inst_emb_map.keys()):
- func_inst_map = inst_emb_map[func_name]
- for inst_str in sorted(func_inst_map.keys()):
- emb = func_inst_map[inst_str]
- inst_sorted.append((inst_str, emb))
-
- for inst_str, emb in inst_sorted:
- print(f"Inst: {inst_str}")
- print(f" Embedding: {emb.tolist()}")
+
+ # Test valid function names in sorted order
+ for func_name in sorted(["add", "multiply", "conditional"]):
+ inst_emb_map = tool.getInstEmbMap(func_name)
+ print(f"Function: {func_name}")
+ for inst_str in sorted(inst_emb_map.keys()):
+ emb = inst_emb_map[inst_str]
+ print(f" Inst: {inst_str}")
+ print(f" Embedding: {emb.tolist()}")
# CHECK: SUCCESS: Tool initialized
# CHECK: Tool type: IR2VecTool
@@ -92,25 +88,28 @@
# CHECK: BB: entry
# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
# CHECK: === Instruction Embeddings ===
-# CHECK: Inst: %sum = add i32 %a, %b
-# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
-# CHECK: Inst: ret i32 %sum
-# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
-# CHECK: Inst: %cmp = icmp sgt i32 %n, 0
-# CHECK-NEXT: Embedding: [157.20000000298023, 158.20000000298023, 159.20000000298023]
-# CHECK: Inst: %neg_val = sub i32 %n, 10
-# CHECK-NEXT: Embedding: [43.0, 44.0, 45.0]
-# CHECK: Inst: %pos_val = add i32 %n, 10
-# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
-# CHECK: Inst: %result = phi i32 [ %pos_val, %positive ], [ %neg_val, %negative ]
-# CHECK-NEXT: Embedding: [163.0, 164.0, 165.0]
-# CHECK: Inst: br i1 %cmp, label %positive, label %negative
-# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
-# CHECK: Inst: br label %exit
-# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
-# CHECK: Inst: ret i32 %result
-# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
-# CHECK: Inst: %prod = mul i32 %x, %y
-# CHECK-NEXT: Embedding: [49.0, 50.0, 51.0]
-# CHECK: Inst: ret i32 %prod
-# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
+# CHECK: Function: add
+# CHECK: Inst: %sum = add i32 %a, %b
+# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
+# CHECK: Inst: ret i32 %sum
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
+# CHECK: Function: conditional
+# CHECK: Inst: %cmp = icmp sgt i32 %n, 0
+# CHECK-NEXT: Embedding: [157.20000000298023, 158.20000000298023, 159.20000000298023]
+# CHECK: Inst: %neg_val = sub i32 %n, 10
+# CHECK-NEXT: Embedding: [43.0, 44.0, 45.0]
+# CHECK: Inst: %pos_val = add i32 %n, 10
+# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
+# CHECK: Inst: %result = phi i32 [ %pos_val, %positive ], [ %neg_val, %negative ]
+# CHECK-NEXT: Embedding: [163.0, 164.0, 165.0]
+# CHECK: Inst: br i1 %cmp, label %positive, label %negative
+# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
+# CHECK: Inst: br label %exit
+# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
+# CHECK: Inst: ret i32 %result
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
+# CHECK: Function: multiply
+# CHECK: Inst: %prod = mul i32 %x, %y
+# CHECK-NEXT: Embedding: [49.0, 50.0, 51.0]
+# CHECK: Inst: ret i32 %prod
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index dc47c279313bd..e4ddaf9c14e5a 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -151,36 +151,36 @@ class PyIR2VecTool {
return NbBBEmbMap;
}
- nb::dict getInstEmbMap() {
- auto ToolInstEmbMap = Tool->getFuncInstEmbMap(OutputEmbeddingMode);
+ nb::dict getInstEmbMap(const std::string &FuncName) {
+ const Function *F = M->getFunction(FuncName);
+
+ if (!F)
+ throw nb::value_error(
+ ("Function '" + FuncName + "' not found in module").c_str());
+
+ auto ToolInstEmbMap = Tool->getInstEmbeddingsMap(*F, OutputEmbeddingMode);
if (!ToolInstEmbMap)
throw nb::value_error(toString(ToolInstEmbMap.takeError()).c_str());
nb::dict NbInstEmbMap;
- for (const auto &[FuncPtr, InstMap] : *ToolInstEmbMap) {
- nb::dict NbFuncInstMap;
-
- for (const auto &[InstPtr, InstEmb] : InstMap) {
- auto InstEmbVec = InstEmb.getData();
- double *NbInstEmbVec = new double[InstEmbVec.size()];
- std::copy(InstEmbVec.begin(), InstEmbVec.end(), NbInstEmbVec);
+ for (const auto &[InstPtr, InstEmb] : *ToolInstEmbMap) {
+ auto InstEmbVec = InstEmb.getData();
+ double *NbInstEmbVec = new double[InstEmbVec.size()];
+ std::copy(InstEmbVec.begin(), InstEmbVec.end(), NbInstEmbVec);
- auto NbArray = nb::ndarray<nb::numpy, double>(
- NbInstEmbVec, {InstEmbVec.size()},
- nb::capsule(NbInstEmbVec, [](void *P) noexcept {
- delete[] static_cast<double *>(P);
- }));
-
- std::string InstStr;
- raw_string_ostream OS(InstStr);
- InstPtr->print(OS);
+ auto NbArray = nb::ndarray<nb::numpy, double>(
+ NbInstEmbVec, {InstEmbVec.size()},
+ nb::capsule(NbInstEmbVec, [](void *P) noexcept {
+ delete[] static_cast<double *>(P);
+ }));
- NbFuncInstMap[nb::str(OS.str().c_str())] = NbArray;
- }
+ std::string InstStr;
+ raw_string_ostream OS(InstStr);
+ InstPtr->print(OS);
- NbInstEmbMap[nb::str(FuncPtr->getName().str().c_str())] = NbFuncInstMap;
+ NbInstEmbMap[nb::str(OS.str().c_str())] = NbArray;
}
return NbInstEmbMap;
@@ -204,27 +204,17 @@ NB_MODULE(ir2vec, m) {
"Generate embedding for a single function by name\n"
"Args: funcName (str) - IR-Name of the function\n"
"Returns: ndarray[float64] - Function embedding vector")
-<<<<<<< HEAD
.def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap, nb::arg("funcName"),
"Generate embeddings for all basic blocks in a function\n"
"Args: funcName (str) - IR-Name of the function\n"
"Returns: dict[str, ndarray[float64]] - "
- "{basic_block_name: embedding vector}");
-=======
- .def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap,
- "Generate embeddings for all basic blocks in the module\n"
- "Returns: dict[str, dict[str, ndarray[float64]]] - Nested "
- "dictionary mapping "
- "function names to dictionaries of basic block names to embedding "
- "vectors")
- .def("getInstEmbMap", &PyIR2VecTool::getInstEmbMap,
- "Generate embeddings for all instructions in the module\n"
- "Returns: dict[str, dict[str, ndarray[float64]]] - Nested "
- "dictionary mapping "
- "function names to dictionaries of instruction strings to embedding "
- "vectors");
-
->>>>>>> b259e51c5fdd (Adding Inst Embeddings Map API to ir2vec python bindings - returns a nested map indexed by functions)
+ "{basic_block_name: embedding vector}")
+ .def("getInstEmbMap", &PyIR2VecTool::getInstEmbMap, nb::arg("funcName"),
+ "Generate embeddings for all instructions in a function\n"
+ "Args: funcName (str) - IR-Name of the function\n"
+ "Returns: dict[str, ndarray[float64]] - "
+ "{instruction_string: embedding_vector}");
+
m.def(
"initEmbedding",
[](const std::string &filename, const std::string &mode,
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 739cfed4bb497..4d1223b6ab1a5 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -217,7 +217,7 @@ IR2VecTool::getBBEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
}
Expected<InstEmbeddingsMap>
-IR2VecTool::getInstEmbMap(const Function &F, IR2VecKind Kind) const {
+IR2VecTool::getInstEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
if (!Vocab || !Vocab->isValid())
return createStringError(
errc::invalid_argument,
@@ -241,25 +241,6 @@ IR2VecTool::getInstEmbMap(const Function &F, IR2VecKind Kind) const {
return Result;
}
-Expected<FuncInstEmbMap> IR2VecTool::getFuncInstEmbMap(IR2VecKind Kind) const {
- if (!Vocab || !Vocab->isValid())
- return createStringError(
- errc::invalid_argument,
- "Vocabulary is not valid. IR2VecTool not initialized.");
-
- FuncInstEmbMap Result;
-
- for (const Function &F : M.getFunctionDefs()) {
- auto FuncInstVecs = getInstEmbMap(F, Kind);
- if (!FuncInstVecs)
- return FuncInstVecs.takeError();
-
- Result.try_emplace(&F, std::move(*FuncInstVecs));
- }
-
- return Result;
-}
-
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
EmbeddingLevel Level) const {
if (!Vocab || !Vocab->isValid()) {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index e01b4d1baa6ee..639f37e74466b 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -74,7 +74,6 @@ struct TripletResult {
/// Entity mappings: [entity_name]
using EntityList = std::vector<std::string>;
using FuncEmbMap = DenseMap<const Function *, ir2vec::Embedding>;
-using FuncInstEmbMap = DenseMap<const Function *, ir2vec::InstEmbeddingsMap>;
namespace ir2vec {
@@ -129,15 +128,9 @@ class IR2VecTool {
Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
IR2VecKind Kind) const;
/// Get embeddings for all instructions in a function
- Expected<InstEmbeddingsMap> getInstEmbMap(const Function &F,
+ Expected<InstEmbeddingsMap> getInstEmbeddingsMap(const Function &F,
IR2VecKind Kind) const;
- /// Get embeddings for all instructions in the module, organized by function
- Expected<FuncInstEmbMap> getFuncInstEmbMap(IR2VecKind Kind) const;
-
- /// Dump entity ID to string mappings
- static void writeEntitiesToStream(raw_ostream &OS);
-
/// Generate embeddings for the entire module
void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;
>From 94e821fb4b3bb5c143ff003da80a817bc9804b3d Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Mon, 9 Feb 2026 12:28:01 +0530
Subject: [PATCH 8/9] nit commit - formatting fixup
---
llvm/tools/llvm-ir2vec/lib/Utils.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index 639f37e74466b..face23895cff6 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -129,7 +129,7 @@ class IR2VecTool {
IR2VecKind Kind) const;
/// Get embeddings for all instructions in a function
Expected<InstEmbeddingsMap> getInstEmbeddingsMap(const Function &F,
- IR2VecKind Kind) const;
+ IR2VecKind Kind) const;
/// Generate embeddings for the entire module
void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;
>From 77c050ce7fdf3b4e0a672df112ce7c202bb4fe5a Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Mon, 9 Feb 2026 12:27:13 +0530
Subject: [PATCH 9/9] Adding getFuncNames API to ir2vec python bindings
---
.../tools/llvm-ir2vec/bindings/ir2vec-bindings.py | 10 ++++++++++
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 11 +++++++++++
2 files changed, 21 insertions(+)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index f4906c33e3619..d71baa0abef3d 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -12,6 +12,12 @@
print("SUCCESS: Tool initialized")
print(f"Tool type: {type(tool).__name__}")
+ # Test getFuncNames
+ print("\n=== Function Names ===")
+ func_names = tool.getFuncNames()
+ for func_name in sorted(func_names):
+ print(f"Function: {func_name}")
+
# Test getFuncEmbMap
print("\n=== Function Embeddings ===")
func_emb_map = tool.getFuncEmbMap()
@@ -57,6 +63,10 @@
# CHECK: SUCCESS: Tool initialized
# CHECK: Tool type: IR2VecTool
+# CHECK: === Function Names ===
+# CHECK: Function: add
+# CHECK: Function: conditional
+# CHECK: Function: multiply
# CHECK: === Function Embeddings ===
# CHECK: Function: add
# CHECK-NEXT: Embedding: [38.0, 40.0, 42.0]
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index e4ddaf9c14e5a..df372aedb9b63 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -70,6 +70,14 @@ class PyIR2VecTool {
}
}
+ nb::list getFuncNames() {
+ nb::list NbFuncNames;
+ for (const Function &F : M->getFunctionDefs()) {
+ NbFuncNames.append(nb::str(F.getName().str().c_str()));
+ }
+ return NbFuncNames;
+ }
+
nb::dict getFuncEmbMap() {
auto ToolFuncEmbMap = Tool->getFunctionEmbeddingsMap(OutputEmbeddingMode);
@@ -196,6 +204,9 @@ NB_MODULE(ir2vec, m) {
.def(nb::init<const std::string &, const std::string &,
const std::string &>(),
nb::arg("filename"), nb::arg("mode"), nb::arg("vocabPath"))
+ .def("getFuncNames", &PyIR2VecTool::getFuncNames,
+ "Get list of all defined functions in the module\n"
+ "Returns: list[str] - Function names")
.def("getFuncEmbMap", &PyIR2VecTool::getFuncEmbMap,
"Generate function-level embeddings for all functions\n"
"Returns: dict[str, ndarray[float64]] - "
More information about the llvm-commits
mailing list