[llvm] [llvm-ir2vec] adding BB-embedding map API to ir2vec python bindings (PR #177172)
Nishant Sachdeva via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 3 23:53:46 PST 2026
https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/177172
>From f10e3f14257611993118c6b950f3a0c1789d1912 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 21 Jan 2026 18:20:16 +0530
Subject: [PATCH 1/7] Adding getFuncEmbMap functionality to ir2vec python
bindings
---
llvm/test/tools/llvm-ir2vec/Inputs/input.ll | 24 +++++++++++++
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 17 +++++++++
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 35 +++++++++++++++++--
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 35 +++++++++++++++++++
llvm/tools/llvm-ir2vec/lib/Utils.h | 9 +++++
5 files changed, 118 insertions(+), 2 deletions(-)
diff --git a/llvm/test/tools/llvm-ir2vec/Inputs/input.ll b/llvm/test/tools/llvm-ir2vec/Inputs/input.ll
index 1ca881faa0141..93e77be51b8e9 100644
--- a/llvm/test/tools/llvm-ir2vec/Inputs/input.ll
+++ b/llvm/test/tools/llvm-ir2vec/Inputs/input.ll
@@ -3,3 +3,27 @@ entry:
%sum = add i32 %a, %b
ret i32 %sum
}
+
+define i32 @multiply(i32 %x, i32 %y) {
+entry:
+ %prod = mul i32 %x, %y
+ ret i32 %prod
+}
+
+define i32 @conditional(i32 %n) {
+entry:
+ %cmp = icmp sgt i32 %n, 0
+ br i1 %cmp, label %positive, label %negative
+
+positive:
+ %pos_val = add i32 %n, 10
+ br label %exit
+
+negative:
+ %neg_val = sub i32 %n, 10
+ br label %exit
+
+exit:
+ %result = phi i32 [ %pos_val, %positive ], [ %neg_val, %negative ]
+ ret i32 %result
+}
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 4038277667a47..b84499786e939 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -12,5 +12,22 @@
print("SUCCESS: Tool initialized")
print(f"Tool type: {type(tool).__name__}")
+ # Test getFuncEmbMap
+ func_emb_map = tool.getFuncEmbMap()
+ print(f"Number of functions: {len(func_emb_map)}")
+
+ # Check that all three functions are present
+ expected_funcs = ["add", "multiply", "conditional"]
+ for func_name in expected_funcs:
+ if func_name in func_emb_map:
+ emb = func_emb_map[func_name]
+ print(f"Function '{func_name}': embedding shape = {emb.shape}")
+ else:
+ print(f"ERROR: Function '{func_name}' not found")
+
# CHECK: SUCCESS: Tool initialized
# CHECK: Tool type: IR2VecTool
+# CHECK: Number of functions: 3
+# CHECK: Function 'add': embedding shape =
+# CHECK: Function 'multiply': embedding shape =
+# CHECK: Function 'conditional': embedding shape =
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 530adee8e052e..346faf879c855 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -42,12 +42,18 @@ class PyIR2VecTool {
std::unique_ptr<LLVMContext> Ctx;
std::unique_ptr<Module> M;
std::unique_ptr<IR2VecTool> Tool;
+ IR2VecKind EmbKind;
public:
PyIR2VecTool(const std::string &Filename, const std::string &Mode,
const std::string &VocabPath) {
- if (Mode != "sym" && Mode != "fa")
+ EmbKind = [](const std::string &Mode) -> IR2VecKind {
+ if (Mode == "sym")
+ return IR2VecKind::Symbolic;
+ if (Mode == "fa")
+ return IR2VecKind::FlowAware;
throw nb::value_error("Invalid mode. Use 'sym' or 'fa'");
+ }(Mode);
if (VocabPath.empty())
throw nb::value_error("Empty Vocab Path not allowed");
@@ -62,6 +68,27 @@ class PyIR2VecTool {
.c_str());
}
}
+
+ nb::dict getFuncEmbMap() {
+ auto result = Tool->getFunctionEmbeddings(EmbKind);
+ nb::dict nb_result;
+
+ for (const auto &[func_ptr, embedding] : result) {
+ std::string func_name = func_ptr->getName().str();
+ auto data = embedding.getData();
+ size_t shape[1] = {data.size()};
+ double *data_ptr = new double[data.size()];
+ std::copy(data.data(), data.data() + data.size(), data_ptr);
+
+ auto nb_array = nb::ndarray<nb::numpy, double>(
+ data_ptr, {data.size()}, nb::capsule(data_ptr, [](void *p) noexcept {
+ delete[] static_cast<double *>(p);
+ }));
+ nb_result[nb::str(func_name.c_str())] = nb_array;
+ }
+
+ return nb_result;
+ }
};
} // namespace
@@ -72,7 +99,11 @@ NB_MODULE(ir2vec, m) {
nb::class_<PyIR2VecTool>(m, "IR2VecTool")
.def(nb::init<const std::string &, const std::string &,
const std::string &>(),
- nb::arg("filename"), nb::arg("mode"), nb::arg("vocabPath"));
+ nb::arg("filename"), nb::arg("mode"), nb::arg("vocabPath"))
+ .def("getFuncEmbMap", &PyIR2VecTool::getFuncEmbMap,
+ "Generate function-level embeddings for all functions\n"
+ "Returns: dict[str, ndarray[float64]] - "
+ "{function_name: embedding}");
m.def(
"initEmbedding",
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 190d9259e45b3..4e8589885e019 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -151,6 +151,41 @@ void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
OS << Entities[EntityID] << '\t' << EntityID << '\n';
}
+std::pair<const Function *, Embedding>
+IR2VecTool::getFunctionEmbedding(const Function &F, IR2VecKind Kind) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+ if (F.isDeclaration())
+ return {nullptr, Embedding()};
+
+ auto Emb = Embedder::create(Kind, F, *Vocab);
+ if (!Emb) {
+ return {nullptr, Embedding()};
+ }
+
+ auto FuncVec = Emb->getFunctionVector();
+
+ return {&F, std::move(FuncVec)};
+}
+
+FuncEmbMap IR2VecTool::getFunctionEmbeddings(IR2VecKind Kind) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+ FuncEmbMap Result;
+
+ for (const Function &F : M) {
+ if (F.isDeclaration())
+ continue;
+
+ auto Emb = getFunctionEmbedding(F, Kind);
+ if (Emb.first != nullptr) {
+ Result.try_emplace(Emb.first, std::move(Emb.second));
+ }
+ }
+
+ return Result;
+}
+
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
EmbeddingLevel Level) const {
if (!Vocab->isValid()) {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index d9715b03c3082..d115d9a26ca90 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -16,6 +16,7 @@
#define LLVM_TOOLS_LLVM_IR2VEC_UTILS_UTILS_H
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/CodeGen/MIR2Vec.h"
#include "llvm/CodeGen/MIRParser/MIRParser.h"
@@ -72,6 +73,7 @@ struct TripletResult {
/// Entity mappings: [entity_name]
using EntityList = std::vector<std::string>;
+using FuncEmbMap = DenseMap<const Function *, ir2vec::Embedding>;
namespace ir2vec {
@@ -112,6 +114,13 @@ class IR2VecTool {
/// Returns EntityList containing all entity strings
static EntityList collectEntityMappings();
+ // Get embedding for a single function
+ std::pair<const Function *, Embedding>
+ getFunctionEmbedding(const Function &F, IR2VecKind Kind) const;
+
+ /// Get embeddings for all functions in the module
+ FuncEmbMap getFunctionEmbeddings(IR2VecKind Kind) const;
+
/// Dump entity ID to string mappings
static void writeEntitiesToStream(raw_ostream &OS);
>From 90bb87e7f9943d9b5b384c9d2d65b63a8e5c2ed7 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 21 Jan 2026 19:36:48 +0530
Subject: [PATCH 2/7] Changing unit-test structure for function embeddings
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 26 +++++++++----------
1 file changed, 13 insertions(+), 13 deletions(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index b84499786e939..a209a47cba42e 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -13,21 +13,21 @@
print(f"Tool type: {type(tool).__name__}")
# Test getFuncEmbMap
+ print("\n=== Function Embeddings ===")
func_emb_map = tool.getFuncEmbMap()
- print(f"Number of functions: {len(func_emb_map)}")
- # Check that all three functions are present
- expected_funcs = ["add", "multiply", "conditional"]
- for func_name in expected_funcs:
- if func_name in func_emb_map:
- emb = func_emb_map[func_name]
- print(f"Function '{func_name}': embedding shape = {emb.shape}")
- else:
- print(f"ERROR: Function '{func_name}' not found")
+ # Sorting the function names for fixed-ordered output
+ for func_name in sorted(func_emb_map.keys()):
+ emb = func_emb_map[func_name]
+ print(f"Function: {func_name}")
+ print(f" Embedding: {emb.tolist()}")
# CHECK: SUCCESS: Tool initialized
# CHECK: Tool type: IR2VecTool
-# CHECK: Number of functions: 3
-# CHECK: Function 'add': embedding shape =
-# CHECK: Function 'multiply': embedding shape =
-# CHECK: Function 'conditional': embedding shape =
+# CHECK: === Function Embeddings ===
+# CHECK: Function: add
+# CHECK-NEXT: Embedding: [38.0, 40.0, 42.0]
+# CHECK: Function: conditional
+# CHECK-NEXT: Embedding: [413.20000000298023, 421.20000000298023, 429.20000000298023]
+# CHECK: Function: multiply
+# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
>From 269953157808b29253deefbf9a01cb9073d027fc Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Tue, 3 Feb 2026 13:16:31 +0530
Subject: [PATCH 3/7] Address review nits: formatting fixups and naming conventions
---
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 35 ++++++++---------
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 40 ++++++++++++--------
llvm/tools/llvm-ir2vec/lib/Utils.h | 6 +--
3 files changed, 45 insertions(+), 36 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 346faf879c855..73bfa9511afcc 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -13,6 +13,7 @@
#include "llvm/Support/SourceMgr.h"
#include <nanobind/nanobind.h>
+#include <nanobind/ndarray.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unique_ptr.h>
@@ -42,12 +43,12 @@ class PyIR2VecTool {
std::unique_ptr<LLVMContext> Ctx;
std::unique_ptr<Module> M;
std::unique_ptr<IR2VecTool> Tool;
- IR2VecKind EmbKind;
+ IR2VecKind OutputEmbeddingMode;
public:
PyIR2VecTool(const std::string &Filename, const std::string &Mode,
const std::string &VocabPath) {
- EmbKind = [](const std::string &Mode) -> IR2VecKind {
+ OutputEmbeddingMode = [](const std::string &Mode) -> IR2VecKind {
if (Mode == "sym")
return IR2VecKind::Symbolic;
if (Mode == "fa")
@@ -70,24 +71,24 @@ class PyIR2VecTool {
}
nb::dict getFuncEmbMap() {
- auto result = Tool->getFunctionEmbeddings(EmbKind);
- nb::dict nb_result;
-
- for (const auto &[func_ptr, embedding] : result) {
- std::string func_name = func_ptr->getName().str();
- auto data = embedding.getData();
- size_t shape[1] = {data.size()};
- double *data_ptr = new double[data.size()];
- std::copy(data.data(), data.data() + data.size(), data_ptr);
-
- auto nb_array = nb::ndarray<nb::numpy, double>(
- data_ptr, {data.size()}, nb::capsule(data_ptr, [](void *p) noexcept {
- delete[] static_cast<double *>(p);
+ auto ToolFuncEmbMap = Tool->getFunctionEmbeddingsMap(OutputEmbeddingMode);
+ nb::dict NBFuncEmbMap;
+
+ for (const auto &[FuncPtr, FuncEmb] : ToolFuncEmbMap) {
+ std::string FuncName = FuncPtr->getName().str();
+ auto Data = FuncEmb.getData();
+ size_t Shape[1] = {Data.size()};
+ double *DataPtr = new double[Data.size()];
+ std::copy(Data.data(), Data.data() + Data.size(), DataPtr);
+
+ auto NbArray = nb::ndarray<nb::numpy, double>(
+ DataPtr, {Data.size()}, nb::capsule(DataPtr, [](void *P) noexcept {
+ delete[] static_cast<double *>(P);
}));
- nb_result[nb::str(func_name.c_str())] = nb_array;
+ NBFuncEmbMap[nb::str(FuncName.c_str())] = NbArray;
}
- return nb_result;
+ return NBFuncEmbMap;
}
};
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 4e8589885e019..6a772b6787f43 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -151,25 +151,32 @@ void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
OS << Entities[EntityID] << '\t' << EntityID << '\n';
}
-std::pair<const Function *, Embedding>
+std::optional<Embedding>
IR2VecTool::getFunctionEmbedding(const Function &F, IR2VecKind Kind) const {
- assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+ if (!Vocab || !Vocab->isValid()) {
+ WithColor::error(errs(), ToolName)
+ << "Vocabulary not initialized properly.\n";
+ return std::nullopt;
+ }
if (F.isDeclaration())
- return {nullptr, Embedding()};
+ return std::nullopt;
auto Emb = Embedder::create(Kind, F, *Vocab);
if (!Emb) {
- return {nullptr, Embedding()};
+ WithColor::error(errs(), ToolName)
+ << "Failed to create embedder for function '" << F.getName() << "'.\n";
+ return std::nullopt;
}
-
- auto FuncVec = Emb->getFunctionVector();
-
- return {&F, std::move(FuncVec)};
+ return Emb->getFunctionVector();
}
-FuncEmbMap IR2VecTool::getFunctionEmbeddings(IR2VecKind Kind) const {
- assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+FuncEmbMap IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
+ if (!Vocab || !Vocab->isValid()) {
+ WithColor::error(errs(), ToolName)
+ << "Vocabulary not initialized properly.\n";
+ return {};
+ }
FuncEmbMap Result;
@@ -178,9 +185,8 @@ FuncEmbMap IR2VecTool::getFunctionEmbeddings(IR2VecKind Kind) const {
continue;
auto Emb = getFunctionEmbedding(F, Kind);
- if (Emb.first != nullptr) {
- Result.try_emplace(Emb.first, std::move(Emb.second));
- }
+ if (Emb.has_value())
+ Result.try_emplace(&F, std::move(*Emb));
}
return Result;
@@ -188,7 +194,7 @@ FuncEmbMap IR2VecTool::getFunctionEmbeddings(IR2VecKind Kind) const {
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
EmbeddingLevel Level) const {
- if (!Vocab->isValid()) {
+ if (!Vocab || !Vocab->isValid()) {
WithColor::error(errs(), ToolName)
<< "Vocabulary is not valid. IR2VecTool not initialized.\n";
return;
@@ -396,7 +402,8 @@ void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
EmbeddingLevel Level) const {
if (!Vocab) {
- WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+ WithColor::error(errs(), ToolName)
+ << "Vocabulary not initialized properly.\n";
return;
}
@@ -415,7 +422,8 @@ void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
EmbeddingLevel Level) const {
if (!Vocab) {
- WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+ WithColor::error(errs(), ToolName)
+ << "Vocabulary not initialized properly.\n";
return;
}
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index d115d9a26ca90..da2332a06af29 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -115,11 +115,11 @@ class IR2VecTool {
static EntityList collectEntityMappings();
// Get embedding for a single function
- std::pair<const Function *, Embedding>
- getFunctionEmbedding(const Function &F, IR2VecKind Kind) const;
+ std::optional<Embedding> getFunctionEmbedding(const Function &F,
+ IR2VecKind Kind) const;
/// Get embeddings for all functions in the module
- FuncEmbMap getFunctionEmbeddings(IR2VecKind Kind) const;
+ FuncEmbMap getFunctionEmbeddingsMap(IR2VecKind Kind) const;
/// Dump entity ID to string mappings
static void writeEntitiesToStream(raw_ostream &OS);
>From f876a9e55b882dce8afe55a1594bb51e8990bb6c Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Tue, 3 Feb 2026 15:14:39 +0530
Subject: [PATCH 4/7] Report vocabulary errors via llvm::Error instead of printing them
---
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 33 +++++++++++-----------------
1 file changed, 13 insertions(+), 20 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 6a772b6787f43..ea81525ce99fe 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -153,11 +153,9 @@ void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
std::optional<Embedding>
IR2VecTool::getFunctionEmbedding(const Function &F, IR2VecKind Kind) const {
- if (!Vocab || !Vocab->isValid()) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary not initialized properly.\n";
- return std::nullopt;
- }
+ if (!Vocab || !Vocab->isValid())
+ return createStringError(errc::invalid_argument,
+ "Failed to initialize IR2Vec vocabulary");
if (F.isDeclaration())
return std::nullopt;
@@ -172,11 +170,9 @@ IR2VecTool::getFunctionEmbedding(const Function &F, IR2VecKind Kind) const {
}
FuncEmbMap IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
- if (!Vocab || !Vocab->isValid()) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary not initialized properly.\n";
- return {};
- }
+ if (!Vocab || !Vocab->isValid())
+ return createStringError(errc::invalid_argument,
+ "Failed to initialize IR2Vec vocabulary");
FuncEmbMap Result;
@@ -194,11 +190,9 @@ FuncEmbMap IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
EmbeddingLevel Level) const {
- if (!Vocab || !Vocab->isValid()) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary is not valid. IR2VecTool not initialized.\n";
- return;
- }
+ if (!Vocab || !Vocab->isValid())
+ return createStringError(errc::invalid_argument,
+ "Failed to initialize IR2Vec vocabulary");
for (const Function &F : M.getFunctionDefs())
writeEmbeddingsToStream(F, OS, Level);
@@ -206,11 +200,10 @@ void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
EmbeddingLevel Level) const {
- if (!Vocab || !Vocab->isValid()) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary is not valid. IR2VecTool not initialized.\n";
- return;
- }
+ if (!Vocab || !Vocab->isValid())
+ return createStringError(errc::invalid_argument,
+ "Failed to initialize IR2Vec vocabulary");
+
if (F.isDeclaration()) {
OS << "Function " << F.getName() << " is a declaration, skipping.\n";
return;
>From 12b8545f14b40a51abdfc1b2cce21bacc93523b1 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 4 Feb 2026 13:17:08 +0530
Subject: [PATCH 5/7] Return Expected from embedding getters; clean up debug artifacts
---
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 14 +++--
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 63 +++++++++++---------
llvm/tools/llvm-ir2vec/lib/Utils.h | 6 +-
3 files changed, 46 insertions(+), 37 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 73bfa9511afcc..6da633c906d34 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -72,19 +72,23 @@ class PyIR2VecTool {
nb::dict getFuncEmbMap() {
auto ToolFuncEmbMap = Tool->getFunctionEmbeddingsMap(OutputEmbeddingMode);
+
+ if (!ToolFuncEmbMap)
+ throw nb::value_error(toString(ToolFuncEmbMap.takeError()).c_str());
+
nb::dict NBFuncEmbMap;
- for (const auto &[FuncPtr, FuncEmb] : ToolFuncEmbMap) {
+ for (const auto &[FuncPtr, FuncEmb] : *ToolFuncEmbMap) {
std::string FuncName = FuncPtr->getName().str();
auto Data = FuncEmb.getData();
- size_t Shape[1] = {Data.size()};
double *DataPtr = new double[Data.size()];
- std::copy(Data.data(), Data.data() + Data.size(), DataPtr);
+ std::copy(Data.begin(), Data.end(), DataPtr);
auto NbArray = nb::ndarray<nb::numpy, double>(
- DataPtr, {Data.size()}, nb::capsule(DataPtr, [](void *P) noexcept {
- delete[] static_cast<double *>(P);
+ DataPtr, {Data.size()}, nb::capsule(DataPtr, [](void *p) noexcept {
+ delete[] static_cast<double *>(p);
}));
+
NBFuncEmbMap[nb::str(FuncName.c_str())] = NbArray;
}
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index ea81525ce99fe..50b71d6c134d6 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -26,6 +26,7 @@
#include "llvm/IR/Type.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"
+#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/CodeGen/MIR2Vec.h"
@@ -151,38 +152,40 @@ void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
OS << Entities[EntityID] << '\t' << EntityID << '\n';
}
-std::optional<Embedding>
-IR2VecTool::getFunctionEmbedding(const Function &F, IR2VecKind Kind) const {
+Expected<Embedding> IR2VecTool::getFunctionEmbedding(const Function &F,
+ IR2VecKind Kind) const {
if (!Vocab || !Vocab->isValid())
- return createStringError(errc::invalid_argument,
- "Failed to initialize IR2Vec vocabulary");
+ return createStringError(
+ errc::invalid_argument,
+ "Vocabulary is not valid. IR2VecTool not initialized.");
if (F.isDeclaration())
- return std::nullopt;
+ return createStringError(errc::invalid_argument,
+ "Function is a declaration.");
auto Emb = Embedder::create(Kind, F, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for function '" << F.getName() << "'.\n";
- return std::nullopt;
- }
+ if (!Emb)
+ return createStringError(errc::invalid_argument,
+ "Failed to create embedder for function '%s'.",
+ F.getName().str().c_str());
+
return Emb->getFunctionVector();
}
-FuncEmbMap IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
+Expected<FuncEmbMap>
+IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
if (!Vocab || !Vocab->isValid())
- return createStringError(errc::invalid_argument,
- "Failed to initialize IR2Vec vocabulary");
+ return createStringError(
+ errc::invalid_argument,
+ "Vocabulary is not valid. IR2VecTool not initialized.");
FuncEmbMap Result;
- for (const Function &F : M) {
- if (F.isDeclaration())
- continue;
-
+ for (const Function &F : M.getFunctionDefs()) {
auto Emb = getFunctionEmbedding(F, Kind);
- if (Emb.has_value())
- Result.try_emplace(&F, std::move(*Emb));
+ if (!Emb)
+ return Emb.takeError();
+ Result.try_emplace(&F, std::move(*Emb));
}
return Result;
@@ -190,9 +193,11 @@ FuncEmbMap IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
EmbeddingLevel Level) const {
- if (!Vocab || !Vocab->isValid())
- return createStringError(errc::invalid_argument,
- "Failed to initialize IR2Vec vocabulary");
+ if (!Vocab || !Vocab->isValid()) {
+ WithColor::error(errs(), ToolName)
+ << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+ return;
+ }
for (const Function &F : M.getFunctionDefs())
writeEmbeddingsToStream(F, OS, Level);
@@ -200,9 +205,11 @@ void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
EmbeddingLevel Level) const {
- if (!Vocab || !Vocab->isValid())
- return createStringError(errc::invalid_argument,
- "Failed to initialize IR2Vec vocabulary");
+ if (!Vocab || !Vocab->isValid()) {
+ WithColor::error(errs(), ToolName)
+ << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+ return;
+ }
if (F.isDeclaration()) {
OS << "Function " << F.getName() << " is a declaration, skipping.\n";
@@ -395,8 +402,7 @@ void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
EmbeddingLevel Level) const {
if (!Vocab) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary not initialized properly.\n";
+ WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
return;
}
@@ -415,8 +421,7 @@ void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
EmbeddingLevel Level) const {
if (!Vocab) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary not initialized properly.\n";
+ WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
return;
}
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index da2332a06af29..d535f5fd5bb74 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -115,11 +115,11 @@ class IR2VecTool {
static EntityList collectEntityMappings();
// Get embedding for a single function
- std::optional<Embedding> getFunctionEmbedding(const Function &F,
- IR2VecKind Kind) const;
+ Expected<Embedding> getFunctionEmbedding(const Function &F,
+ IR2VecKind Kind) const;
/// Get embeddings for all functions in the module
- FuncEmbMap getFunctionEmbeddingsMap(IR2VecKind Kind) const;
+ Expected<FuncEmbMap> getFunctionEmbeddingsMap(IR2VecKind Kind) const;
/// Dump entity ID to string mappings
static void writeEntitiesToStream(raw_ostream &OS);
>From 87989d40fd4858d378702860b01a23cfd311e867 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 4 Feb 2026 13:23:19 +0530
Subject: [PATCH 6/7] Inline FuncName temporary in getFuncEmbMap
---
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 6da633c906d34..015a57e622758 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -79,7 +79,6 @@ class PyIR2VecTool {
nb::dict NBFuncEmbMap;
for (const auto &[FuncPtr, FuncEmb] : *ToolFuncEmbMap) {
- std::string FuncName = FuncPtr->getName().str();
auto Data = FuncEmb.getData();
double *DataPtr = new double[Data.size()];
std::copy(Data.begin(), Data.end(), DataPtr);
@@ -89,7 +88,7 @@ class PyIR2VecTool {
delete[] static_cast<double *>(p);
}));
- NBFuncEmbMap[nb::str(FuncName.c_str())] = NbArray;
+ NBFuncEmbMap[nb::str(FuncPtr->getName().str().c_str())] = NbArray;
}
return NBFuncEmbMap;
>From 5c296c1f89c93a3e4131ed93957277ef88bf4d0c Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 21 Jan 2026 19:32:36 +0530
Subject: [PATCH 7/7] adding BB embedding map API to ir2vec python bindings
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 23 ++++++++++++
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 35 +++++++++++++++++++
llvm/tools/llvm-ir2vec/lib/Utils.h | 6 ++++
3 files changed, 64 insertions(+)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index a209a47cba42e..1173a8b1ed2d9 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -22,6 +22,16 @@
print(f"Function: {func_name}")
print(f" Embedding: {emb.tolist()}")
+ # Test getBBEmbMap
+ print("\n=== Basic Block Embeddings ===")
+ bb_emb_list = tool.getBBEmbMap()
+
+ # Sorting by BB name for deterministic output
+ bb_sorted = sorted(bb_emb_list, key=lambda x: (x[0], tuple(x[1].tolist())))
+ for bb_name, emb in bb_sorted:
+ print(f"BB: {bb_name}")
+ print(f" Embedding: {emb.tolist()}")
+
# CHECK: SUCCESS: Tool initialized
# CHECK: Tool type: IR2VecTool
# CHECK: === Function Embeddings ===
@@ -31,3 +41,16 @@
# CHECK-NEXT: Embedding: [413.20000000298023, 421.20000000298023, 429.20000000298023]
# CHECK: Function: multiply
# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
+# CHECK: === Basic Block Embeddings ===
+# CHECK: BB: entry
+# CHECK-NEXT: Embedding: [38.0, 40.0, 42.0]
+# CHECK: BB: entry
+# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
+# CHECK: BB: entry
+# CHECK-NEXT: Embedding: [161.20000000298023, 163.20000000298023, 165.20000000298023]
+# CHECK: BB: exit
+# CHECK-NEXT: Embedding: [164.0, 166.0, 168.0]
+# CHECK: BB: negative
+# CHECK-NEXT: Embedding: [47.0, 49.0, 51.0]
+# CHECK: BB: positive
+# CHECK-NEXT: Embedding: [41.0, 43.0, 45.0]
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 50b71d6c134d6..76fc85b1afe2c 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -191,6 +191,57 @@ IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
   return Result;
 }
 
+// Computes the embedding of every basic block in F. Declarations yield an
+// empty map; an unusable vocabulary or embedder is reported and also yields
+// an empty map, matching the error style of writeEmbeddingsToStream.
+BBEmbeddingsMap IR2VecTool::getBBEmbeddings(const Function &F,
+                                            IR2VecKind Kind) const {
+  BBEmbeddingsMap Result;
+
+  // Checked at runtime rather than asserted so release builds never
+  // dereference a null or invalid vocabulary.
+  if (!Vocab || !Vocab->isValid()) {
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+    return Result;
+  }
+
+  if (F.isDeclaration())
+    return Result;
+
+  auto Emb = Embedder::create(Kind, F, *Vocab);
+  if (!Emb) {
+    WithColor::error(errs(), ToolName)
+        << "Failed to create embedder for function '" << F.getName() << "'.\n";
+    return Result;
+  }
+
+  for (const BasicBlock &BB : F)
+    Result.try_emplace(&BB, Emb->getBBVector(BB));
+
+  return Result;
+}
+
+// Computes basic-block embeddings for every function definition in the
+// module, merged into one map keyed by basic block.
+BBEmbeddingsMap IR2VecTool::getBBEmbeddings(IR2VecKind Kind) const {
+  BBEmbeddingsMap Result;
+
+  if (!Vocab || !Vocab->isValid()) {
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+    return Result;
+  }
+
+  // getFunctionDefs() already skips declarations. Basic blocks are unique
+  // across functions, so keys never collide; move rather than copy.
+  for (const Function &F : M.getFunctionDefs())
+    for (auto &[BB, BBEmb] : getBBEmbeddings(F, Kind))
+      Result.try_emplace(BB, std::move(BBEmb));
+
+  return Result;
+}
+
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
EmbeddingLevel Level) const {
if (!Vocab || !Vocab->isValid()) {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index d535f5fd5bb74..30d4a1a095ac1 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -121,6 +121,12 @@ class IR2VecTool {
/// Get embeddings for all functions in the module
Expected<FuncEmbMap> getFunctionEmbeddingsMap(IR2VecKind Kind) const;
+ /// Get embeddings for all basic blocks in a function
+ BBEmbeddingsMap getBBEmbeddings(const Function &F, IR2VecKind Kind) const;
+
+ /// Get embeddings for all basic blocks in the module
+ BBEmbeddingsMap getBBEmbeddings(IR2VecKind Kind) const;
+
/// Dump entity ID to string mappings
static void writeEntitiesToStream(raw_ostream &OS);
More information about the llvm-commits
mailing list