[llvm] [llvm-ir2vec] adding BB-embedding map API to ir2vec python bindings (PR #177172)
Nishant Sachdeva via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 5 02:34:58 PST 2026
https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/177172
>From 36d5a625aae2cb549648f09130345b651ae89bba Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Thu, 5 Feb 2026 16:03:01 +0530
Subject: [PATCH] Adding FuncEmb API to ir2vec python bindings
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 16 ++++++++++
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 31 ++++++++++++++++++-
2 files changed, 46 insertions(+), 1 deletion(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index a209a47cba42e..a0d61e4808292 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -22,6 +22,15 @@
print(f"Function: {func_name}")
print(f" Embedding: {emb.tolist()}")
+ # Test getFuncEmb for individual functions
+ print("\n=== Single Function Embeddings ===")
+
+ # Test valid function names
+ for func_name in ["add", "multiply", "conditional"]:
+ func_emb = tool.getFuncEmb(func_name)
+ print(f"Function: {func_name}")
+ print(f" Embedding: {func_emb.tolist()}")
+
# CHECK: SUCCESS: Tool initialized
# CHECK: Tool type: IR2VecTool
# CHECK: === Function Embeddings ===
@@ -31,3 +40,10 @@
# CHECK-NEXT: Embedding: [413.20000000298023, 421.20000000298023, 429.20000000298023]
# CHECK: Function: multiply
# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
+# CHECK: === Single Function Embeddings ===
+# CHECK: Function: add
+# CHECK-NEXT: Embedding: [38.0, 40.0, 42.0]
+# CHECK: Function: multiply
+# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
+# CHECK: Function: conditional
+# CHECK-NEXT: Embedding: [413.20000000298023, 421.20000000298023, 429.20000000298023]
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 08a0844b44eef..a7bed68a6703d 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -94,6 +94,30 @@ class PyIR2VecTool {
return NBFuncEmbMap;
}
+
+  /// Compute the embedding of a single function, looked up by name.
+  ///
+  /// \param FuncName Name handed to getFunction() on the module held in M.
+  /// \returns A one-dimensional NumPy float64 array containing a copy of the
+  ///          embedding values; a capsule attached to the array frees the
+  ///          copied buffer.
+  /// \throws nb::value_error if getFunction() finds no function named
+  ///         \p FuncName, or if getFunctionEmbedding() reports an error.
+  nb::ndarray<nb::numpy, double> getFuncEmb(const std::string &FuncName) {
+    const Function *FuncPtr = M->getFunction(FuncName);
+
+    if (!FuncPtr)
+      throw nb::value_error(
+          ("Function '" + FuncName + "' not found in module").c_str());
+
+    auto ToolFuncEmb =
+        Tool->getFunctionEmbedding(*FuncPtr, OutputEmbeddingMode);
+
+    if (!ToolFuncEmb)
+      throw nb::value_error(toString(ToolFuncEmb.takeError()).c_str());
+
+    // Bind by const reference instead of plain `auto`: the values are copied
+    // into the heap buffer below anyway, so an extra by-value copy of the
+    // whole container here would be pure overhead. `const auto &` is valid
+    // whether getData() returns a reference or a temporary.
+    const auto &FuncEmbVec = ToolFuncEmb->getData();
+    double *NBFuncEmbVec = new double[FuncEmbVec.size()];
+    std::copy(FuncEmbVec.begin(), FuncEmbVec.end(), NBFuncEmbVec);
+
+    // The capsule is passed as the array's owner; its callback releases the
+    // buffer allocated above with the matching delete[].
+    auto NbArray = nb::ndarray<nb::numpy, double>(
+        NBFuncEmbVec, {FuncEmbVec.size()},
+        nb::capsule(NBFuncEmbVec, [](void *P) noexcept {
+          delete[] static_cast<double *>(P);
+        }));
+
+    return NbArray;
+  }
};
} // namespace
@@ -108,7 +132,12 @@ NB_MODULE(ir2vec, m) {
.def("getFuncEmbMap", &PyIR2VecTool::getFuncEmbMap,
"Generate function-level embeddings for all functions\n"
"Returns: dict[str, ndarray[float64]] - "
- "{function_name: embedding}");
+ "{function_name: embedding}")
+ .def("getFuncEmb", &PyIR2VecTool::getFuncEmb,
+ nb::arg("funcName"),
+ "Generate embedding for a single function by name\n"
+ "Args: funcName (str) - IR-Name of the function\n"
+ "Returns: ndarray[float64] - Function embedding vector");
m.def(
"initEmbedding",
More information about the llvm-commits
mailing list