[llvm] [llvm-ir2vec] Adding Inst Embeddings Map API to ir2vec python bindings (PR #180140)

Nishant Sachdeva via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 11 22:00:56 PST 2026


https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/180140

>From 1ab726cfbb4059c2f42a40591d441ad2427c41e5 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Fri, 6 Feb 2026 13:45:57 +0530
Subject: [PATCH 1/5] Adding Inst Embeddings Map API to ir2vec python bindings
 - returns a nested map indexed by functions

---
 .../llvm-ir2vec/bindings/ir2vec-bindings.py   | 39 ++++++++++++++
 llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp  | 51 +++++++++++++++++++
 llvm/tools/llvm-ir2vec/lib/Utils.cpp          | 45 ++++++++++++++++
 llvm/tools/llvm-ir2vec/lib/Utils.h            |  9 ++++
 4 files changed, 144 insertions(+)

diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 502f8a2411aa8..78c322d6d9265 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -43,6 +43,22 @@
             print(f"  BB: {bb_name}")
             print(f"    Embedding: {emb.tolist()}")
 
+    # Test getInstEmbMap
+    print("\n=== Instruction Embeddings ===")
+    inst_emb_map = tool.getInstEmbMap()
+
+    # Sorting by function name, then instruction string for deterministic output
+    inst_sorted = []
+    for func_name in sorted(inst_emb_map.keys()):
+        func_inst_map = inst_emb_map[func_name]
+        for inst_str in sorted(func_inst_map.keys()):
+            emb = func_inst_map[inst_str]
+            inst_sorted.append((inst_str, emb))
+    
+    for inst_str, emb in inst_sorted:
+        print(f"Inst: {inst_str}")
+        print(f"  Embedding: {emb.tolist()}")
+
 # CHECK: SUCCESS: Tool initialized
 # CHECK: Tool type: IR2VecTool
 # CHECK: === Function Embeddings ===
@@ -75,3 +91,26 @@
 # CHECK: Function: multiply
 # CHECK:   BB: entry
 # CHECK-NEXT:     Embedding: [50.0, 52.0, 54.0]
+# CHECK: === Instruction Embeddings ===
+# CHECK: Inst: %sum = add i32 %a, %b
+# CHECK-NEXT:   Embedding: [37.0, 38.0, 39.0]
+# CHECK: Inst: ret i32 %sum
+# CHECK-NEXT:   Embedding: [1.0, 2.0, 3.0]
+# CHECK: Inst: %cmp = icmp sgt i32 %n, 0
+# CHECK-NEXT:   Embedding: [157.20000000298023, 158.20000000298023, 159.20000000298023]
+# CHECK: Inst: %neg_val = sub i32 %n, 10
+# CHECK-NEXT:   Embedding: [43.0, 44.0, 45.0]
+# CHECK: Inst: %pos_val = add i32 %n, 10
+# CHECK-NEXT:   Embedding: [37.0, 38.0, 39.0]
+# CHECK: Inst: %result = phi i32 [ %pos_val, %positive ], [ %neg_val, %negative ]
+# CHECK-NEXT:   Embedding: [163.0, 164.0, 165.0]
+# CHECK: Inst: br i1 %cmp, label %positive, label %negative
+# CHECK-NEXT:   Embedding: [4.0, 5.0, 6.0]
+# CHECK: Inst: br label %exit
+# CHECK-NEXT:   Embedding: [4.0, 5.0, 6.0]
+# CHECK: Inst: ret i32 %result
+# CHECK-NEXT:   Embedding: [1.0, 2.0, 3.0]
+# CHECK: Inst: %prod = mul i32 %x, %y
+# CHECK-NEXT:   Embedding: [49.0, 50.0, 51.0]
+# CHECK: Inst: ret i32 %prod
+# CHECK-NEXT:   Embedding: [1.0, 2.0, 3.0]
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index e334997b31394..dc47c279313bd 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -150,6 +150,41 @@ class PyIR2VecTool {
 
     return NbBBEmbMap;
   }
+
+  nb::dict getInstEmbMap() {
+    auto ToolInstEmbMap = Tool->getFuncInstEmbMap(OutputEmbeddingMode);
+
+    if (!ToolInstEmbMap)
+      throw nb::value_error(toString(ToolInstEmbMap.takeError()).c_str());
+
+    nb::dict NbInstEmbMap;
+
+    for (const auto &[FuncPtr, InstMap] : *ToolInstEmbMap) {
+      nb::dict NbFuncInstMap;
+
+      for (const auto &[InstPtr, InstEmb] : InstMap) {
+        auto InstEmbVec = InstEmb.getData();
+        double *NbInstEmbVec = new double[InstEmbVec.size()];
+        std::copy(InstEmbVec.begin(), InstEmbVec.end(), NbInstEmbVec);
+
+        auto NbArray = nb::ndarray<nb::numpy, double>(
+            NbInstEmbVec, {InstEmbVec.size()},
+            nb::capsule(NbInstEmbVec, [](void *P) noexcept {
+              delete[] static_cast<double *>(P);
+            }));
+
+        std::string InstStr;
+        raw_string_ostream OS(InstStr);
+        InstPtr->print(OS);
+
+        NbFuncInstMap[nb::str(OS.str().c_str())] = NbArray;
+      }
+
+      NbInstEmbMap[nb::str(FuncPtr->getName().str().c_str())] = NbFuncInstMap;
+    }
+
+    return NbInstEmbMap;
+  }
 };
 
 } // namespace
@@ -169,11 +204,27 @@ NB_MODULE(ir2vec, m) {
            "Generate embedding for a single function by name\n"
            "Args: funcName (str) - IR-Name of the function\n"
            "Returns: ndarray[float64] - Function embedding vector")
+<<<<<<< HEAD
       .def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap, nb::arg("funcName"),
            "Generate embeddings for all basic blocks in a function\n"
            "Args: funcName (str) - IR-Name of the function\n"
            "Returns: dict[str, ndarray[float64]] - "
            "{basic_block_name: embedding vector}");
+=======
+      .def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap,
+           "Generate embeddings for all basic blocks in the module\n"
+           "Returns: dict[str, dict[str, ndarray[float64]]] - Nested "
+           "dictionary mapping "
+           "function names to dictionaries of basic block names to embedding "
+           "vectors")
+      .def("getInstEmbMap", &PyIR2VecTool::getInstEmbMap,
+           "Generate embeddings for all instructions in the module\n"
+           "Returns: dict[str, dict[str, ndarray[float64]]] - Nested "
+           "dictionary mapping "
+           "function names to dictionaries of instruction strings to embedding "
+           "vectors");
+
+>>>>>>> b259e51c5fdd (Adding Inst Embeddings Map API to ir2vec python bindings - returns a nested map indexed by functions)
   m.def(
       "initEmbedding",
       [](const std::string &filename, const std::string &mode,
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 70f501c116407..4215a6073a179 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -216,6 +216,51 @@ IR2VecTool::getBBEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
   return Result;
 }
 
+Expected<InstEmbeddingsMap>
+IR2VecTool::getInstEmbMap(const Function &F, IR2VecKind Kind) const {
+  if (!Vocab || !Vocab->isValid())
+    return createStringError(
+        errc::invalid_argument,
+        "Vocabulary is not valid. IR2VecTool not initialized.");
+
+  InstEmbeddingsMap Result;
+
+  if (F.isDeclaration())
+    return createStringError(errc::invalid_argument,
+                             "Function is a declaration.");
+
+  auto Emb = Embedder::create(Kind, F, *Vocab);
+  if (!Emb)
+    return createStringError(errc::invalid_argument,
+                             "Failed to create embedder for function '%s'.",
+                             F.getName().str().c_str());
+
+  for (const Instruction &I : instructions(F))
+    Result.try_emplace(&I, Emb->getInstVector(I));
+
+  return Result;
+}
+
+Expected<FuncInstEmbMap>
+IR2VecTool::getFuncInstEmbMap(IR2VecKind Kind) const {
+  if (!Vocab || !Vocab->isValid())
+    return createStringError(
+        errc::invalid_argument,
+        "Vocabulary is not valid. IR2VecTool not initialized.");
+
+  FuncInstEmbMap Result;
+
+  for (const Function &F : M.getFunctionDefs()) {
+    auto FuncInstVecs = getInstEmbMap(F, Kind);
+    if (!FuncInstVecs)
+      return FuncInstVecs.takeError();
+
+    Result.try_emplace(&F, std::move(*FuncInstVecs));
+  }
+
+  return Result;
+}
+
 void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
                                          EmbeddingLevel Level) const {
   if (!Vocab || !Vocab->isValid()) {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index 04ec4a74b1e24..806cb5406b122 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -74,6 +74,7 @@ struct TripletResult {
 /// Entity mappings: [entity_name]
 using EntityList = std::vector<std::string>;
 using FuncEmbMap = DenseMap<const Function *, ir2vec::Embedding>;
+using FuncInstEmbMap = DenseMap<const Function *, ir2vec::InstEmbeddingsMap>;
 
 namespace ir2vec {
 
@@ -127,6 +128,14 @@ class IR2VecTool {
   /// Get embeddings for all basic blocks in a function
   Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
                                                IR2VecKind Kind) const;
+  /// Get embeddings for all instructions in a function
+  Expected<InstEmbeddingsMap> getInstEmbMap(const Function &F, IR2VecKind Kind) const;
+
+  /// Get embeddings for all instructions in the module, organized by function
+  Expected<FuncInstEmbMap> getFuncInstEmbMap(IR2VecKind Kind) const;
+
+  /// Dump entity ID to string mappings
+  static void writeEntitiesToStream(raw_ostream &OS);
 
   /// Generate embeddings for the entire module
   void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;

>From 1288aadaacce0a4cccb14d26a48a14e901d02a53 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Fri, 6 Feb 2026 13:59:01 +0530
Subject: [PATCH 2/5] nit commit - formatting fixup

---
 llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py | 2 +-
 llvm/tools/llvm-ir2vec/lib/Utils.cpp                    | 3 +--
 llvm/tools/llvm-ir2vec/lib/Utils.h                      | 3 ++-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 78c322d6d9265..52ad69c7e04d6 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -54,7 +54,7 @@
         for inst_str in sorted(func_inst_map.keys()):
             emb = func_inst_map[inst_str]
             inst_sorted.append((inst_str, emb))
-    
+
     for inst_str, emb in inst_sorted:
         print(f"Inst: {inst_str}")
         print(f"  Embedding: {emb.tolist()}")
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 4215a6073a179..739cfed4bb497 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -241,8 +241,7 @@ IR2VecTool::getInstEmbMap(const Function &F, IR2VecKind Kind) const {
   return Result;
 }
 
-Expected<FuncInstEmbMap>
-IR2VecTool::getFuncInstEmbMap(IR2VecKind Kind) const {
+Expected<FuncInstEmbMap> IR2VecTool::getFuncInstEmbMap(IR2VecKind Kind) const {
   if (!Vocab || !Vocab->isValid())
     return createStringError(
         errc::invalid_argument,
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index 806cb5406b122..e01b4d1baa6ee 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -129,7 +129,8 @@ class IR2VecTool {
   Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
                                                IR2VecKind Kind) const;
   /// Get embeddings for all instructions in a function
-  Expected<InstEmbeddingsMap> getInstEmbMap(const Function &F, IR2VecKind Kind) const;
+  Expected<InstEmbeddingsMap> getInstEmbMap(const Function &F,
+                                            IR2VecKind Kind) const;
 
   /// Get embeddings for all instructions in the module, organized by function
   Expected<FuncInstEmbMap> getFuncInstEmbMap(IR2VecKind Kind) const;

>From dabd7e3147ecd2dd6a86c4222fb2b2df9941c664 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Mon, 9 Feb 2026 12:14:28 +0530
Subject: [PATCH 3/5] Changing Inst Embedding Map API to limit to a specific
 function

---
 .../llvm-ir2vec/bindings/ir2vec-bindings.py   | 69 +++++++++----------
 llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp  | 66 ++++++++----------
 llvm/tools/llvm-ir2vec/lib/Utils.cpp          | 21 +-----
 llvm/tools/llvm-ir2vec/lib/Utils.h            |  9 +--
 4 files changed, 64 insertions(+), 101 deletions(-)

diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 52ad69c7e04d6..f4906c33e3619 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -45,19 +45,15 @@
 
     # Test getInstEmbMap
     print("\n=== Instruction Embeddings ===")
-    inst_emb_map = tool.getInstEmbMap()
-
-    # Sorting by function name, then instruction string for deterministic output
-    inst_sorted = []
-    for func_name in sorted(inst_emb_map.keys()):
-        func_inst_map = inst_emb_map[func_name]
-        for inst_str in sorted(func_inst_map.keys()):
-            emb = func_inst_map[inst_str]
-            inst_sorted.append((inst_str, emb))
-
-    for inst_str, emb in inst_sorted:
-        print(f"Inst: {inst_str}")
-        print(f"  Embedding: {emb.tolist()}")
+    
+    # Test valid function names in sorted order
+    for func_name in sorted(["add", "multiply", "conditional"]):
+        inst_emb_map = tool.getInstEmbMap(func_name)
+        print(f"Function: {func_name}")
+        for inst_str in sorted(inst_emb_map.keys()):
+            emb = inst_emb_map[inst_str]
+            print(f"  Inst: {inst_str}")
+            print(f"    Embedding: {emb.tolist()}")
 
 # CHECK: SUCCESS: Tool initialized
 # CHECK: Tool type: IR2VecTool
@@ -92,25 +88,28 @@
 # CHECK:   BB: entry
 # CHECK-NEXT:     Embedding: [50.0, 52.0, 54.0]
 # CHECK: === Instruction Embeddings ===
-# CHECK: Inst: %sum = add i32 %a, %b
-# CHECK-NEXT:   Embedding: [37.0, 38.0, 39.0]
-# CHECK: Inst: ret i32 %sum
-# CHECK-NEXT:   Embedding: [1.0, 2.0, 3.0]
-# CHECK: Inst: %cmp = icmp sgt i32 %n, 0
-# CHECK-NEXT:   Embedding: [157.20000000298023, 158.20000000298023, 159.20000000298023]
-# CHECK: Inst: %neg_val = sub i32 %n, 10
-# CHECK-NEXT:   Embedding: [43.0, 44.0, 45.0]
-# CHECK: Inst: %pos_val = add i32 %n, 10
-# CHECK-NEXT:   Embedding: [37.0, 38.0, 39.0]
-# CHECK: Inst: %result = phi i32 [ %pos_val, %positive ], [ %neg_val, %negative ]
-# CHECK-NEXT:   Embedding: [163.0, 164.0, 165.0]
-# CHECK: Inst: br i1 %cmp, label %positive, label %negative
-# CHECK-NEXT:   Embedding: [4.0, 5.0, 6.0]
-# CHECK: Inst: br label %exit
-# CHECK-NEXT:   Embedding: [4.0, 5.0, 6.0]
-# CHECK: Inst: ret i32 %result
-# CHECK-NEXT:   Embedding: [1.0, 2.0, 3.0]
-# CHECK: Inst: %prod = mul i32 %x, %y
-# CHECK-NEXT:   Embedding: [49.0, 50.0, 51.0]
-# CHECK: Inst: ret i32 %prod
-# CHECK-NEXT:   Embedding: [1.0, 2.0, 3.0]
+# CHECK: Function: add
+# CHECK:   Inst: %sum = add i32 %a, %b
+# CHECK-NEXT:     Embedding: [37.0, 38.0, 39.0]
+# CHECK:   Inst: ret i32 %sum
+# CHECK-NEXT:     Embedding: [1.0, 2.0, 3.0]
+# CHECK: Function: conditional
+# CHECK:   Inst: %cmp = icmp sgt i32 %n, 0
+# CHECK-NEXT:     Embedding: [157.20000000298023, 158.20000000298023, 159.20000000298023]
+# CHECK:   Inst: %neg_val = sub i32 %n, 10
+# CHECK-NEXT:     Embedding: [43.0, 44.0, 45.0]
+# CHECK:   Inst: %pos_val = add i32 %n, 10
+# CHECK-NEXT:     Embedding: [37.0, 38.0, 39.0]
+# CHECK:   Inst: %result = phi i32 [ %pos_val, %positive ], [ %neg_val, %negative ]
+# CHECK-NEXT:     Embedding: [163.0, 164.0, 165.0]
+# CHECK:   Inst: br i1 %cmp, label %positive, label %negative
+# CHECK-NEXT:     Embedding: [4.0, 5.0, 6.0]
+# CHECK:   Inst: br label %exit
+# CHECK-NEXT:     Embedding: [4.0, 5.0, 6.0]
+# CHECK:   Inst: ret i32 %result
+# CHECK-NEXT:     Embedding: [1.0, 2.0, 3.0]
+# CHECK: Function: multiply
+# CHECK:   Inst: %prod = mul i32 %x, %y
+# CHECK-NEXT:     Embedding: [49.0, 50.0, 51.0]
+# CHECK:   Inst: ret i32 %prod
+# CHECK-NEXT:     Embedding: [1.0, 2.0, 3.0]
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index dc47c279313bd..e4ddaf9c14e5a 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -151,36 +151,36 @@ class PyIR2VecTool {
     return NbBBEmbMap;
   }
 
-  nb::dict getInstEmbMap() {
-    auto ToolInstEmbMap = Tool->getFuncInstEmbMap(OutputEmbeddingMode);
+  nb::dict getInstEmbMap(const std::string &FuncName) {
+    const Function *F = M->getFunction(FuncName);
+
+    if (!F)
+      throw nb::value_error(
+          ("Function '" + FuncName + "' not found in module").c_str());
+
+    auto ToolInstEmbMap = Tool->getInstEmbeddingsMap(*F, OutputEmbeddingMode);
 
     if (!ToolInstEmbMap)
       throw nb::value_error(toString(ToolInstEmbMap.takeError()).c_str());
 
     nb::dict NbInstEmbMap;
 
-    for (const auto &[FuncPtr, InstMap] : *ToolInstEmbMap) {
-      nb::dict NbFuncInstMap;
-
-      for (const auto &[InstPtr, InstEmb] : InstMap) {
-        auto InstEmbVec = InstEmb.getData();
-        double *NbInstEmbVec = new double[InstEmbVec.size()];
-        std::copy(InstEmbVec.begin(), InstEmbVec.end(), NbInstEmbVec);
+    for (const auto &[InstPtr, InstEmb] : *ToolInstEmbMap) {
+      auto InstEmbVec = InstEmb.getData();
+      double *NbInstEmbVec = new double[InstEmbVec.size()];
+      std::copy(InstEmbVec.begin(), InstEmbVec.end(), NbInstEmbVec);
 
-        auto NbArray = nb::ndarray<nb::numpy, double>(
-            NbInstEmbVec, {InstEmbVec.size()},
-            nb::capsule(NbInstEmbVec, [](void *P) noexcept {
-              delete[] static_cast<double *>(P);
-            }));
-
-        std::string InstStr;
-        raw_string_ostream OS(InstStr);
-        InstPtr->print(OS);
+      auto NbArray = nb::ndarray<nb::numpy, double>(
+          NbInstEmbVec, {InstEmbVec.size()},
+          nb::capsule(NbInstEmbVec, [](void *P) noexcept {
+            delete[] static_cast<double *>(P);
+          }));
 
-        NbFuncInstMap[nb::str(OS.str().c_str())] = NbArray;
-      }
+      std::string InstStr;
+      raw_string_ostream OS(InstStr);
+      InstPtr->print(OS);
 
-      NbInstEmbMap[nb::str(FuncPtr->getName().str().c_str())] = NbFuncInstMap;
+      NbInstEmbMap[nb::str(OS.str().c_str())] = NbArray;
     }
 
     return NbInstEmbMap;
@@ -204,27 +204,17 @@ NB_MODULE(ir2vec, m) {
            "Generate embedding for a single function by name\n"
            "Args: funcName (str) - IR-Name of the function\n"
            "Returns: ndarray[float64] - Function embedding vector")
-<<<<<<< HEAD
       .def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap, nb::arg("funcName"),
            "Generate embeddings for all basic blocks in a function\n"
            "Args: funcName (str) - IR-Name of the function\n"
            "Returns: dict[str, ndarray[float64]] - "
-           "{basic_block_name: embedding vector}");
-=======
-      .def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap,
-           "Generate embeddings for all basic blocks in the module\n"
-           "Returns: dict[str, dict[str, ndarray[float64]]] - Nested "
-           "dictionary mapping "
-           "function names to dictionaries of basic block names to embedding "
-           "vectors")
-      .def("getInstEmbMap", &PyIR2VecTool::getInstEmbMap,
-           "Generate embeddings for all instructions in the module\n"
-           "Returns: dict[str, dict[str, ndarray[float64]]] - Nested "
-           "dictionary mapping "
-           "function names to dictionaries of instruction strings to embedding "
-           "vectors");
-
->>>>>>> b259e51c5fdd (Adding Inst Embeddings Map API to ir2vec python bindings - returns a nested map indexed by functions)
+           "{basic_block_name: embedding vector}")
+      .def("getInstEmbMap", &PyIR2VecTool::getInstEmbMap, nb::arg("funcName"),
+           "Generate embeddings for all instructions in a function\n"
+           "Args: funcName (str) - IR-Name of the function\n"
+           "Returns: dict[str, ndarray[float64]] - "
+           "{instruction_string: embedding_vector}");
+
   m.def(
       "initEmbedding",
       [](const std::string &filename, const std::string &mode,
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 739cfed4bb497..4d1223b6ab1a5 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -217,7 +217,7 @@ IR2VecTool::getBBEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
 }
 
 Expected<InstEmbeddingsMap>
-IR2VecTool::getInstEmbMap(const Function &F, IR2VecKind Kind) const {
+IR2VecTool::getInstEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
   if (!Vocab || !Vocab->isValid())
     return createStringError(
         errc::invalid_argument,
@@ -241,25 +241,6 @@ IR2VecTool::getInstEmbMap(const Function &F, IR2VecKind Kind) const {
   return Result;
 }
 
-Expected<FuncInstEmbMap> IR2VecTool::getFuncInstEmbMap(IR2VecKind Kind) const {
-  if (!Vocab || !Vocab->isValid())
-    return createStringError(
-        errc::invalid_argument,
-        "Vocabulary is not valid. IR2VecTool not initialized.");
-
-  FuncInstEmbMap Result;
-
-  for (const Function &F : M.getFunctionDefs()) {
-    auto FuncInstVecs = getInstEmbMap(F, Kind);
-    if (!FuncInstVecs)
-      return FuncInstVecs.takeError();
-
-    Result.try_emplace(&F, std::move(*FuncInstVecs));
-  }
-
-  return Result;
-}
-
 void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
                                          EmbeddingLevel Level) const {
   if (!Vocab || !Vocab->isValid()) {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index e01b4d1baa6ee..639f37e74466b 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -74,7 +74,6 @@ struct TripletResult {
 /// Entity mappings: [entity_name]
 using EntityList = std::vector<std::string>;
 using FuncEmbMap = DenseMap<const Function *, ir2vec::Embedding>;
-using FuncInstEmbMap = DenseMap<const Function *, ir2vec::InstEmbeddingsMap>;
 
 namespace ir2vec {
 
@@ -129,15 +128,9 @@ class IR2VecTool {
   Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
                                                IR2VecKind Kind) const;
   /// Get embeddings for all instructions in a function
-  Expected<InstEmbeddingsMap> getInstEmbMap(const Function &F,
+  Expected<InstEmbeddingsMap> getInstEmbeddingsMap(const Function &F,
                                             IR2VecKind Kind) const;
 
-  /// Get embeddings for all instructions in the module, organized by function
-  Expected<FuncInstEmbMap> getFuncInstEmbMap(IR2VecKind Kind) const;
-
-  /// Dump entity ID to string mappings
-  static void writeEntitiesToStream(raw_ostream &OS);
-
   /// Generate embeddings for the entire module
   void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;
 

>From 6faf88c6f9f7141508419475ce09db0d012bdea8 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Mon, 9 Feb 2026 12:28:01 +0530
Subject: [PATCH 4/5] nit commit - formatting fixup

---
 llvm/tools/llvm-ir2vec/lib/Utils.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index 639f37e74466b..face23895cff6 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -129,7 +129,7 @@ class IR2VecTool {
                                                IR2VecKind Kind) const;
   /// Get embeddings for all instructions in a function
   Expected<InstEmbeddingsMap> getInstEmbeddingsMap(const Function &F,
-                                            IR2VecKind Kind) const;
+                                                   IR2VecKind Kind) const;
 
   /// Generate embeddings for the entire module
   void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;

>From bfa08ad70fca786aa048d2eb85e020b3ad50c303 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Mon, 9 Feb 2026 12:31:05 +0530
Subject: [PATCH 5/5] nit commit - python file formatting fixup

---
 llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index f4906c33e3619..bb29d33dc8ca6 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -45,7 +45,7 @@
 
     # Test getInstEmbMap
     print("\n=== Instruction Embeddings ===")
-    
+
     # Test valid function names in sorted order
     for func_name in sorted(["add", "multiply", "conditional"]):
         inst_emb_map = tool.getInstEmbMap(func_name)



More information about the llvm-commits mailing list