[llvm] [llvm-ir2vec] adding BB-embedding map API to ir2vec python bindings (PR #177172)

Nishant Sachdeva via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 3 23:53:46 PST 2026


https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/177172

>From f10e3f14257611993118c6b950f3a0c1789d1912 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 21 Jan 2026 18:20:16 +0530
Subject: [PATCH 1/7] Adding getFuncEmbMap functionality to ir2vec python
 bindings

---
 llvm/test/tools/llvm-ir2vec/Inputs/input.ll   | 24 +++++++++++++
 .../llvm-ir2vec/bindings/ir2vec-bindings.py   | 17 +++++++++
 llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp  | 35 +++++++++++++++++--
 llvm/tools/llvm-ir2vec/lib/Utils.cpp          | 35 +++++++++++++++++++
 llvm/tools/llvm-ir2vec/lib/Utils.h            |  9 +++++
 5 files changed, 118 insertions(+), 2 deletions(-)

diff --git a/llvm/test/tools/llvm-ir2vec/Inputs/input.ll b/llvm/test/tools/llvm-ir2vec/Inputs/input.ll
index 1ca881faa0141..93e77be51b8e9 100644
--- a/llvm/test/tools/llvm-ir2vec/Inputs/input.ll
+++ b/llvm/test/tools/llvm-ir2vec/Inputs/input.ll
@@ -3,3 +3,27 @@ entry:
   %sum = add i32 %a, %b
   ret i32 %sum
 }
+
+define i32 @multiply(i32 %x, i32 %y) {
+entry:
+  %prod = mul i32 %x, %y
+  ret i32 %prod
+}
+
+define i32 @conditional(i32 %n) {
+entry:
+  %cmp = icmp sgt i32 %n, 0
+  br i1 %cmp, label %positive, label %negative
+
+positive:
+  %pos_val = add i32 %n, 10
+  br label %exit
+
+negative:
+  %neg_val = sub i32 %n, 10
+  br label %exit
+
+exit:
+  %result = phi i32 [ %pos_val, %positive ], [ %neg_val, %negative ]
+  ret i32 %result
+}
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 4038277667a47..b84499786e939 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -12,5 +12,22 @@
     print("SUCCESS: Tool initialized")
     print(f"Tool type: {type(tool).__name__}")
 
+    # Test getFuncEmbMap
+    func_emb_map = tool.getFuncEmbMap()
+    print(f"Number of functions: {len(func_emb_map)}")
+
+    # Check that all three functions are present
+    expected_funcs = ["add", "multiply", "conditional"]
+    for func_name in expected_funcs:
+        if func_name in func_emb_map:
+            emb = func_emb_map[func_name]
+            print(f"Function '{func_name}': embedding shape = {emb.shape}")
+        else:
+            print(f"ERROR: Function '{func_name}' not found")
+
 # CHECK: SUCCESS: Tool initialized
 # CHECK: Tool type: IR2VecTool
+# CHECK: Number of functions: 3
+# CHECK: Function 'add': embedding shape =
+# CHECK: Function 'multiply': embedding shape =
+# CHECK: Function 'conditional': embedding shape =
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 530adee8e052e..346faf879c855 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -42,12 +42,18 @@ class PyIR2VecTool {
   std::unique_ptr<LLVMContext> Ctx;
   std::unique_ptr<Module> M;
   std::unique_ptr<IR2VecTool> Tool;
+  IR2VecKind EmbKind;
 
 public:
   PyIR2VecTool(const std::string &Filename, const std::string &Mode,
                const std::string &VocabPath) {
-    if (Mode != "sym" && Mode != "fa")
+    EmbKind = [](const std::string &Mode) -> IR2VecKind {
+      if (Mode == "sym")
+        return IR2VecKind::Symbolic;
+      if (Mode == "fa")
+        return IR2VecKind::FlowAware;
       throw nb::value_error("Invalid mode. Use 'sym' or 'fa'");
+    }(Mode);
 
     if (VocabPath.empty())
       throw nb::value_error("Empty Vocab Path not allowed");
@@ -62,6 +68,27 @@ class PyIR2VecTool {
                                 .c_str());
     }
   }
+
+  nb::dict getFuncEmbMap() {
+    auto result = Tool->getFunctionEmbeddings(EmbKind);
+    nb::dict nb_result;
+
+    for (const auto &[func_ptr, embedding] : result) {
+      std::string func_name = func_ptr->getName().str();
+      auto data = embedding.getData();
+      size_t shape[1] = {data.size()};
+      double *data_ptr = new double[data.size()];
+      std::copy(data.data(), data.data() + data.size(), data_ptr);
+
+      auto nb_array = nb::ndarray<nb::numpy, double>(
+          data_ptr, {data.size()}, nb::capsule(data_ptr, [](void *p) noexcept {
+            delete[] static_cast<double *>(p);
+          }));
+      nb_result[nb::str(func_name.c_str())] = nb_array;
+    }
+
+    return nb_result;
+  }
 };
 
 } // namespace
@@ -72,7 +99,11 @@ NB_MODULE(ir2vec, m) {
   nb::class_<PyIR2VecTool>(m, "IR2VecTool")
       .def(nb::init<const std::string &, const std::string &,
                     const std::string &>(),
-           nb::arg("filename"), nb::arg("mode"), nb::arg("vocabPath"));
+           nb::arg("filename"), nb::arg("mode"), nb::arg("vocabPath"))
+      .def("getFuncEmbMap", &PyIR2VecTool::getFuncEmbMap,
+           "Generate function-level embeddings for all functions\n"
+           "Returns: dict[str, ndarray[float64]] - "
+           "{function_name: embedding}");
 
   m.def(
       "initEmbedding",
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 190d9259e45b3..4e8589885e019 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -151,6 +151,41 @@ void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
     OS << Entities[EntityID] << '\t' << EntityID << '\n';
 }
 
+std::pair<const Function *, Embedding>
+IR2VecTool::getFunctionEmbedding(const Function &F, IR2VecKind Kind) const {
+  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+  if (F.isDeclaration())
+    return {nullptr, Embedding()};
+
+  auto Emb = Embedder::create(Kind, F, *Vocab);
+  if (!Emb) {
+    return {nullptr, Embedding()};
+  }
+
+  auto FuncVec = Emb->getFunctionVector();
+
+  return {&F, std::move(FuncVec)};
+}
+
+FuncEmbMap IR2VecTool::getFunctionEmbeddings(IR2VecKind Kind) const {
+  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+  FuncEmbMap Result;
+
+  for (const Function &F : M) {
+    if (F.isDeclaration())
+      continue;
+
+    auto Emb = getFunctionEmbedding(F, Kind);
+    if (Emb.first != nullptr) {
+      Result.try_emplace(Emb.first, std::move(Emb.second));
+    }
+  }
+
+  return Result;
+}
+
 void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
                                          EmbeddingLevel Level) const {
   if (!Vocab->isValid()) {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index d9715b03c3082..d115d9a26ca90 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -16,6 +16,7 @@
 #define LLVM_TOOLS_LLVM_IR2VEC_UTILS_UTILS_H
 
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/Analysis/IR2Vec.h"
 #include "llvm/CodeGen/MIR2Vec.h"
 #include "llvm/CodeGen/MIRParser/MIRParser.h"
@@ -72,6 +73,7 @@ struct TripletResult {
 
 /// Entity mappings: [entity_name]
 using EntityList = std::vector<std::string>;
+using FuncEmbMap = DenseMap<const Function *, ir2vec::Embedding>;
 
 namespace ir2vec {
 
@@ -112,6 +114,13 @@ class IR2VecTool {
   /// Returns EntityList containing all entity strings
   static EntityList collectEntityMappings();
 
+  // Get embedding for a single function
+  std::pair<const Function *, Embedding>
+  getFunctionEmbedding(const Function &F, IR2VecKind Kind) const;
+
+  /// Get embeddings for all functions in the module
+  FuncEmbMap getFunctionEmbeddings(IR2VecKind Kind) const;
+
   /// Dump entity ID to string mappings
   static void writeEntitiesToStream(raw_ostream &OS);
 

>From 90bb87e7f9943d9b5b384c9d2d65b63a8e5c2ed7 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 21 Jan 2026 19:36:48 +0530
Subject: [PATCH 2/7] Changing unit-test structure for function embeddings

---
 .../llvm-ir2vec/bindings/ir2vec-bindings.py   | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index b84499786e939..a209a47cba42e 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -13,21 +13,21 @@
     print(f"Tool type: {type(tool).__name__}")
 
     # Test getFuncEmbMap
+    print("\n=== Function Embeddings ===")
     func_emb_map = tool.getFuncEmbMap()
-    print(f"Number of functions: {len(func_emb_map)}")
 
-    # Check that all three functions are present
-    expected_funcs = ["add", "multiply", "conditional"]
-    for func_name in expected_funcs:
-        if func_name in func_emb_map:
-            emb = func_emb_map[func_name]
-            print(f"Function '{func_name}': embedding shape = {emb.shape}")
-        else:
-            print(f"ERROR: Function '{func_name}' not found")
+    # Sort the function names so the output order is deterministic
+    for func_name in sorted(func_emb_map.keys()):
+        emb = func_emb_map[func_name]
+        print(f"Function: {func_name}")
+        print(f"  Embedding: {emb.tolist()}")
 
 # CHECK: SUCCESS: Tool initialized
 # CHECK: Tool type: IR2VecTool
-# CHECK: Number of functions: 3
-# CHECK: Function 'add': embedding shape =
-# CHECK: Function 'multiply': embedding shape =
-# CHECK: Function 'conditional': embedding shape =
+# CHECK: === Function Embeddings ===
+# CHECK: Function: add
+# CHECK-NEXT:   Embedding: [38.0, 40.0, 42.0]
+# CHECK: Function: conditional
+# CHECK-NEXT:   Embedding: [413.20000000298023, 421.20000000298023, 429.20000000298023]
+# CHECK: Function: multiply
+# CHECK-NEXT:   Embedding: [50.0, 52.0, 54.0]

>From 269953157808b29253deefbf9a01cb9073d027fc Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Tue, 3 Feb 2026 13:16:31 +0530
Subject: [PATCH 3/7] Nit commits, formatting fixups, naming conventions

---
 llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 35 ++++++++---------
 llvm/tools/llvm-ir2vec/lib/Utils.cpp         | 40 ++++++++++++--------
 llvm/tools/llvm-ir2vec/lib/Utils.h           |  6 +--
 3 files changed, 45 insertions(+), 36 deletions(-)

diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 346faf879c855..73bfa9511afcc 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -13,6 +13,7 @@
 #include "llvm/Support/SourceMgr.h"
 
 #include <nanobind/nanobind.h>
+#include <nanobind/ndarray.h>
 #include <nanobind/stl/string.h>
 #include <nanobind/stl/unique_ptr.h>
 
@@ -42,12 +43,12 @@ class PyIR2VecTool {
   std::unique_ptr<LLVMContext> Ctx;
   std::unique_ptr<Module> M;
   std::unique_ptr<IR2VecTool> Tool;
-  IR2VecKind EmbKind;
+  IR2VecKind OutputEmbeddingMode;
 
 public:
   PyIR2VecTool(const std::string &Filename, const std::string &Mode,
                const std::string &VocabPath) {
-    EmbKind = [](const std::string &Mode) -> IR2VecKind {
+    OutputEmbeddingMode = [](const std::string &Mode) -> IR2VecKind {
       if (Mode == "sym")
         return IR2VecKind::Symbolic;
       if (Mode == "fa")
@@ -70,24 +71,24 @@ class PyIR2VecTool {
   }
 
   nb::dict getFuncEmbMap() {
-    auto result = Tool->getFunctionEmbeddings(EmbKind);
-    nb::dict nb_result;
-
-    for (const auto &[func_ptr, embedding] : result) {
-      std::string func_name = func_ptr->getName().str();
-      auto data = embedding.getData();
-      size_t shape[1] = {data.size()};
-      double *data_ptr = new double[data.size()];
-      std::copy(data.data(), data.data() + data.size(), data_ptr);
-
-      auto nb_array = nb::ndarray<nb::numpy, double>(
-          data_ptr, {data.size()}, nb::capsule(data_ptr, [](void *p) noexcept {
-            delete[] static_cast<double *>(p);
+    auto ToolFuncEmbMap = Tool->getFunctionEmbeddingsMap(OutputEmbeddingMode);
+    nb::dict NBFuncEmbMap;
+
+    for (const auto &[FuncPtr, FuncEmb] : ToolFuncEmbMap) {
+      std::string FuncName = FuncPtr->getName().str();
+      auto Data = FuncEmb.getData();
+      size_t Shape[1] = {Data.size()};
+      double *DataPtr = new double[Data.size()];
+      std::copy(Data.data(), Data.data() + Data.size(), DataPtr);
+
+      auto NbArray = nb::ndarray<nb::numpy, double>(
+          DataPtr, {Data.size()}, nb::capsule(DataPtr, [](void *P) noexcept {
+            delete[] static_cast<double *>(P);
           }));
-      nb_result[nb::str(func_name.c_str())] = nb_array;
+      NBFuncEmbMap[nb::str(FuncName.c_str())] = NbArray;
     }
 
-    return nb_result;
+    return NBFuncEmbMap;
   }
 };
 
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 4e8589885e019..6a772b6787f43 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -151,25 +151,32 @@ void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
     OS << Entities[EntityID] << '\t' << EntityID << '\n';
 }
 
-std::pair<const Function *, Embedding>
+std::optional<Embedding>
 IR2VecTool::getFunctionEmbedding(const Function &F, IR2VecKind Kind) const {
-  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+  if (!Vocab || !Vocab->isValid()) {
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary not initialized properly.\n";
+    return std::nullopt;
+  }
 
   if (F.isDeclaration())
-    return {nullptr, Embedding()};
+    return std::nullopt;
 
   auto Emb = Embedder::create(Kind, F, *Vocab);
   if (!Emb) {
-    return {nullptr, Embedding()};
+    WithColor::error(errs(), ToolName)
+        << "Failed to create embedder for function '" << F.getName() << "'.\n";
+    return std::nullopt;
   }
-
-  auto FuncVec = Emb->getFunctionVector();
-
-  return {&F, std::move(FuncVec)};
+  return Emb->getFunctionVector();
 }
 
-FuncEmbMap IR2VecTool::getFunctionEmbeddings(IR2VecKind Kind) const {
-  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+FuncEmbMap IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
+  if (!Vocab || !Vocab->isValid()) {
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary not initialized properly.\n";
+    return {};
+  }
 
   FuncEmbMap Result;
 
@@ -178,9 +185,8 @@ FuncEmbMap IR2VecTool::getFunctionEmbeddings(IR2VecKind Kind) const {
       continue;
 
     auto Emb = getFunctionEmbedding(F, Kind);
-    if (Emb.first != nullptr) {
-      Result.try_emplace(Emb.first, std::move(Emb.second));
-    }
+    if (Emb.has_value())
+      Result.try_emplace(&F, std::move(*Emb));
   }
 
   return Result;
@@ -188,7 +194,7 @@ FuncEmbMap IR2VecTool::getFunctionEmbeddings(IR2VecKind Kind) const {
 
 void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
                                          EmbeddingLevel Level) const {
-  if (!Vocab->isValid()) {
+  if (!Vocab || !Vocab->isValid()) {
     WithColor::error(errs(), ToolName)
         << "Vocabulary is not valid. IR2VecTool not initialized.\n";
     return;
@@ -396,7 +402,8 @@ void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
 void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
                                           EmbeddingLevel Level) const {
   if (!Vocab) {
-    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary not initialized properly.\n";
     return;
   }
 
@@ -415,7 +422,8 @@ void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
 void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
                                           EmbeddingLevel Level) const {
   if (!Vocab) {
-    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary not initialized properly.\n";
     return;
   }
 
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index d115d9a26ca90..da2332a06af29 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -115,11 +115,11 @@ class IR2VecTool {
   static EntityList collectEntityMappings();
 
   // Get embedding for a single function
-  std::pair<const Function *, Embedding>
-  getFunctionEmbedding(const Function &F, IR2VecKind Kind) const;
+  std::optional<Embedding> getFunctionEmbedding(const Function &F,
+                                                IR2VecKind Kind) const;
 
   /// Get embeddings for all functions in the module
-  FuncEmbMap getFunctionEmbeddings(IR2VecKind Kind) const;
+  FuncEmbMap getFunctionEmbeddingsMap(IR2VecKind Kind) const;
 
   /// Dump entity ID to string mappings
   static void writeEntitiesToStream(raw_ostream &OS);

>From f876a9e55b882dce8afe55a1594bb51e8990bb6c Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Tue, 3 Feb 2026 15:14:39 +0530
Subject: [PATCH 4/7] Vocab errors should be fatal errors

---
 llvm/tools/llvm-ir2vec/lib/Utils.cpp | 33 +++++++++++-----------------
 1 file changed, 13 insertions(+), 20 deletions(-)

diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 6a772b6787f43..ea81525ce99fe 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -153,11 +153,9 @@ void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
 
 std::optional<Embedding>
 IR2VecTool::getFunctionEmbedding(const Function &F, IR2VecKind Kind) const {
-  if (!Vocab || !Vocab->isValid()) {
-    WithColor::error(errs(), ToolName)
-        << "Vocabulary not initialized properly.\n";
-    return std::nullopt;
-  }
+  if (!Vocab || !Vocab->isValid())
+    return createStringError(errc::invalid_argument,
+                             "Failed to initialize IR2Vec vocabulary");
 
   if (F.isDeclaration())
     return std::nullopt;
@@ -172,11 +170,9 @@ IR2VecTool::getFunctionEmbedding(const Function &F, IR2VecKind Kind) const {
 }
 
 FuncEmbMap IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
-  if (!Vocab || !Vocab->isValid()) {
-    WithColor::error(errs(), ToolName)
-        << "Vocabulary not initialized properly.\n";
-    return {};
-  }
+  if (!Vocab || !Vocab->isValid())
+    return createStringError(errc::invalid_argument,
+                             "Failed to initialize IR2Vec vocabulary");
 
   FuncEmbMap Result;
 
@@ -194,11 +190,9 @@ FuncEmbMap IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
 
 void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
                                          EmbeddingLevel Level) const {
-  if (!Vocab || !Vocab->isValid()) {
-    WithColor::error(errs(), ToolName)
-        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
-    return;
-  }
+  if (!Vocab || !Vocab->isValid())
+    return createStringError(errc::invalid_argument,
+                             "Failed to initialize IR2Vec vocabulary");
 
   for (const Function &F : M.getFunctionDefs())
     writeEmbeddingsToStream(F, OS, Level);
@@ -206,11 +200,10 @@ void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
 
 void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
                                          EmbeddingLevel Level) const {
-  if (!Vocab || !Vocab->isValid()) {
-    WithColor::error(errs(), ToolName)
-        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
-    return;
-  }
+  if (!Vocab || !Vocab->isValid())
+    return createStringError(errc::invalid_argument,
+                             "Failed to initialize IR2Vec vocabulary");
+
   if (F.isDeclaration()) {
     OS << "Function " << F.getName() << " is a declaration, skipping.\n";
     return;

>From 12b8545f14b40a51abdfc1b2cce21bacc93523b1 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 4 Feb 2026 13:17:08 +0530
Subject: [PATCH 5/7] Nit changes, cleaning up debug artifacts

---
 llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 14 +++--
 llvm/tools/llvm-ir2vec/lib/Utils.cpp         | 63 +++++++++++---------
 llvm/tools/llvm-ir2vec/lib/Utils.h           |  6 +-
 3 files changed, 46 insertions(+), 37 deletions(-)

diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 73bfa9511afcc..6da633c906d34 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -72,19 +72,23 @@ class PyIR2VecTool {
 
   nb::dict getFuncEmbMap() {
     auto ToolFuncEmbMap = Tool->getFunctionEmbeddingsMap(OutputEmbeddingMode);
+
+    if (!ToolFuncEmbMap)
+      throw nb::value_error(toString(ToolFuncEmbMap.takeError()).c_str());
+
     nb::dict NBFuncEmbMap;
 
-    for (const auto &[FuncPtr, FuncEmb] : ToolFuncEmbMap) {
+    for (const auto &[FuncPtr, FuncEmb] : *ToolFuncEmbMap) {
       std::string FuncName = FuncPtr->getName().str();
       auto Data = FuncEmb.getData();
-      size_t Shape[1] = {Data.size()};
       double *DataPtr = new double[Data.size()];
-      std::copy(Data.data(), Data.data() + Data.size(), DataPtr);
+      std::copy(Data.begin(), Data.end(), DataPtr);
 
       auto NbArray = nb::ndarray<nb::numpy, double>(
-          DataPtr, {Data.size()}, nb::capsule(DataPtr, [](void *P) noexcept {
-            delete[] static_cast<double *>(P);
+          DataPtr, {Data.size()}, nb::capsule(DataPtr, [](void *p) noexcept {
+            delete[] static_cast<double *>(p);
           }));
+
       NBFuncEmbMap[nb::str(FuncName.c_str())] = NbArray;
     }
 
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index ea81525ce99fe..50b71d6c134d6 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -26,6 +26,7 @@
 #include "llvm/IR/Type.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/Errc.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/raw_ostream.h"
 
 #include "llvm/CodeGen/MIR2Vec.h"
@@ -151,38 +152,40 @@ void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
     OS << Entities[EntityID] << '\t' << EntityID << '\n';
 }
 
-std::optional<Embedding>
-IR2VecTool::getFunctionEmbedding(const Function &F, IR2VecKind Kind) const {
+Expected<Embedding> IR2VecTool::getFunctionEmbedding(const Function &F,
+                                                     IR2VecKind Kind) const {
   if (!Vocab || !Vocab->isValid())
-    return createStringError(errc::invalid_argument,
-                             "Failed to initialize IR2Vec vocabulary");
+    return createStringError(
+        errc::invalid_argument,
+        "Vocabulary is not valid. IR2VecTool not initialized.");
 
   if (F.isDeclaration())
-    return std::nullopt;
+    return createStringError(errc::invalid_argument,
+                             "Function is a declaration.");
 
   auto Emb = Embedder::create(Kind, F, *Vocab);
-  if (!Emb) {
-    WithColor::error(errs(), ToolName)
-        << "Failed to create embedder for function '" << F.getName() << "'.\n";
-    return std::nullopt;
-  }
+  if (!Emb)
+    return createStringError(errc::invalid_argument,
+                             "Failed to create embedder for function '%s'.",
+                             F.getName().str().c_str());
+
   return Emb->getFunctionVector();
 }
 
-FuncEmbMap IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
+Expected<FuncEmbMap>
+IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
   if (!Vocab || !Vocab->isValid())
-    return createStringError(errc::invalid_argument,
-                             "Failed to initialize IR2Vec vocabulary");
+    return createStringError(
+        errc::invalid_argument,
+        "Vocabulary is not valid. IR2VecTool not initialized.");
 
   FuncEmbMap Result;
 
-  for (const Function &F : M) {
-    if (F.isDeclaration())
-      continue;
-
+  for (const Function &F : M.getFunctionDefs()) {
     auto Emb = getFunctionEmbedding(F, Kind);
-    if (Emb.has_value())
-      Result.try_emplace(&F, std::move(*Emb));
+    if (!Emb)
+      return Emb.takeError();
+    Result.try_emplace(&F, std::move(*Emb));
   }
 
   return Result;
@@ -190,9 +193,11 @@ FuncEmbMap IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
 
 void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
                                          EmbeddingLevel Level) const {
-  if (!Vocab || !Vocab->isValid())
-    return createStringError(errc::invalid_argument,
-                             "Failed to initialize IR2Vec vocabulary");
+  if (!Vocab || !Vocab->isValid()) {
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+    return;
+  }
 
   for (const Function &F : M.getFunctionDefs())
     writeEmbeddingsToStream(F, OS, Level);
@@ -200,9 +205,11 @@ void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
 
 void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
                                          EmbeddingLevel Level) const {
-  if (!Vocab || !Vocab->isValid())
-    return createStringError(errc::invalid_argument,
-                             "Failed to initialize IR2Vec vocabulary");
+  if (!Vocab || !Vocab->isValid()) {
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+    return;
+  }
 
   if (F.isDeclaration()) {
     OS << "Function " << F.getName() << " is a declaration, skipping.\n";
@@ -395,8 +402,7 @@ void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
 void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
                                           EmbeddingLevel Level) const {
   if (!Vocab) {
-    WithColor::error(errs(), ToolName)
-        << "Vocabulary not initialized properly.\n";
+    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
     return;
   }
 
@@ -415,8 +421,7 @@ void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
 void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
                                           EmbeddingLevel Level) const {
   if (!Vocab) {
-    WithColor::error(errs(), ToolName)
-        << "Vocabulary not initialized properly.\n";
+    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
     return;
   }
 
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index da2332a06af29..d535f5fd5bb74 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -115,11 +115,11 @@ class IR2VecTool {
   static EntityList collectEntityMappings();
 
   // Get embedding for a single function
-  std::optional<Embedding> getFunctionEmbedding(const Function &F,
-                                                IR2VecKind Kind) const;
+  Expected<Embedding> getFunctionEmbedding(const Function &F,
+                                           IR2VecKind Kind) const;
 
   /// Get embeddings for all functions in the module
-  FuncEmbMap getFunctionEmbeddingsMap(IR2VecKind Kind) const;
+  Expected<FuncEmbMap> getFunctionEmbeddingsMap(IR2VecKind Kind) const;
 
   /// Dump entity ID to string mappings
   static void writeEntitiesToStream(raw_ostream &OS);

>From 87989d40fd4858d378702860b01a23cfd311e867 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 4 Feb 2026 13:23:19 +0530
Subject: [PATCH 6/7] Nit commit

---
 llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 6da633c906d34..015a57e622758 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -79,7 +79,6 @@ class PyIR2VecTool {
     nb::dict NBFuncEmbMap;
 
     for (const auto &[FuncPtr, FuncEmb] : *ToolFuncEmbMap) {
-      std::string FuncName = FuncPtr->getName().str();
       auto Data = FuncEmb.getData();
       double *DataPtr = new double[Data.size()];
       std::copy(Data.begin(), Data.end(), DataPtr);
@@ -89,7 +88,7 @@ class PyIR2VecTool {
             delete[] static_cast<double *>(p);
           }));
 
-      NBFuncEmbMap[nb::str(FuncName.c_str())] = NbArray;
+      NBFuncEmbMap[nb::str(FuncPtr->getName().str().c_str())] = NbArray;
     }
 
     return NBFuncEmbMap;

>From 5c296c1f89c93a3e4131ed93957277ef88bf4d0c Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 21 Jan 2026 19:32:36 +0530
Subject: [PATCH 7/7] adding BB embedding map API to ir2vec python bindings

---
 .../llvm-ir2vec/bindings/ir2vec-bindings.py   | 23 ++++++++++++
 llvm/tools/llvm-ir2vec/lib/Utils.cpp          | 35 +++++++++++++++++++
 llvm/tools/llvm-ir2vec/lib/Utils.h            |  6 ++++
 3 files changed, 64 insertions(+)

diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index a209a47cba42e..1173a8b1ed2d9 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -22,6 +22,16 @@
         print(f"Function: {func_name}")
         print(f"  Embedding: {emb.tolist()}")
 
+    # Test getBBEmbMap
+    print("\n=== Basic Block Embeddings ===")
+    bb_emb_list = tool.getBBEmbMap()
+
+    # Sort by (BB name, embedding): names like "entry" repeat across functions
+    bb_sorted = sorted(bb_emb_list, key=lambda x: (x[0], tuple(x[1].tolist())))
+    for bb_name, emb in bb_sorted:
+        print(f"BB: {bb_name}")
+        print(f"  Embedding: {emb.tolist()}")
+
 # CHECK: SUCCESS: Tool initialized
 # CHECK: Tool type: IR2VecTool
 # CHECK: === Function Embeddings ===
@@ -31,3 +41,16 @@
 # CHECK-NEXT:   Embedding: [413.20000000298023, 421.20000000298023, 429.20000000298023]
 # CHECK: Function: multiply
 # CHECK-NEXT:   Embedding: [50.0, 52.0, 54.0]
+# CHECK: === Basic Block Embeddings ===
+# CHECK: BB: entry
+# CHECK-NEXT:   Embedding: [38.0, 40.0, 42.0]
+# CHECK: BB: entry
+# CHECK-NEXT:   Embedding: [50.0, 52.0, 54.0]
+# CHECK: BB: entry
+# CHECK-NEXT:   Embedding: [161.20000000298023, 163.20000000298023, 165.20000000298023]
+# CHECK: BB: exit
+# CHECK-NEXT:   Embedding: [164.0, 166.0, 168.0]
+# CHECK: BB: negative
+# CHECK-NEXT:   Embedding: [47.0, 49.0, 51.0]
+# CHECK: BB: positive
+# CHECK-NEXT:   Embedding: [41.0, 43.0, 45.0]
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 50b71d6c134d6..76fc85b1afe2c 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -191,6 +191,41 @@ IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
   return Result;
 }
 
+BBEmbeddingsMap IR2VecTool::getBBEmbeddings(const Function &F,
+                                            IR2VecKind Kind) const {
+  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+  BBEmbeddingsMap Result;
+
+  if (F.isDeclaration())
+    return Result;
+
+  auto Emb = Embedder::create(Kind, F, *Vocab);
+  if (!Emb)
+    return Result;
+
+  for (const BasicBlock &BB : F)
+    Result.try_emplace(&BB, Emb->getBBVector(BB));
+
+  return Result;
+}
+
+BBEmbeddingsMap IR2VecTool::getBBEmbeddings(IR2VecKind Kind) const {
+  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+  BBEmbeddingsMap Result;
+
+  for (const Function &F : M) {
+    if (F.isDeclaration())
+      continue;
+
+    BBEmbeddingsMap FuncBBVecs = getBBEmbeddings(F, Kind);
+    Result.insert(FuncBBVecs.begin(), FuncBBVecs.end());
+  }
+
+  return Result;
+}
+
 void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
                                          EmbeddingLevel Level) const {
   if (!Vocab || !Vocab->isValid()) {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index d535f5fd5bb74..30d4a1a095ac1 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -121,6 +121,12 @@ class IR2VecTool {
   /// Get embeddings for all functions in the module
   Expected<FuncEmbMap> getFunctionEmbeddingsMap(IR2VecKind Kind) const;
 
+  /// Get embeddings for all basic blocks in a function
+  BBEmbeddingsMap getBBEmbeddings(const Function &F, IR2VecKind Kind) const;
+
+  /// Get embeddings for all basic blocks in the module
+  BBEmbeddingsMap getBBEmbeddings(IR2VecKind Kind) const;
+
   /// Dump entity ID to string mappings
   static void writeEntitiesToStream(raw_ostream &OS);
 



More information about the llvm-commits mailing list