[llvm] 673d725 - Reland "[llvm-ir2vec] Adding Inst Embeddings Map API to ir2vec python bindings (#180140)" (#184196)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 5 09:03:19 PST 2026


Author: Nishant Sachdeva
Date: 2026-03-05T22:33:15+05:30
New Revision: 673d725c3bcf5b551e42ff96d00b77b8f3b9adcb

URL: https://github.com/llvm/llvm-project/commit/673d725c3bcf5b551e42ff96d00b77b8f3b9adcb
DIFF: https://github.com/llvm/llvm-project/commit/673d725c3bcf5b551e42ff96d00b77b8f3b9adcb.diff

LOG: Reland "[llvm-ir2vec] Adding Inst Embeddings Map API to ir2vec python bindings (#180140)" (#184196)

Relanding change from https://github.com/llvm/llvm-project/pull/180140

- Returns an instruction embedding map for the given function name:
 `getInstEmbMap(funcName) -> Map<Inst string, Embedding>`

- Refactors IR2VecTool methods to create the embedder object through a
separate shared call (`createIR2VecEmbedder`)

Added: 
    

Modified: 
    llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
    llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
    llvm/tools/llvm-ir2vec/lib/Utils.cpp
    llvm/tools/llvm-ir2vec/lib/Utils.h

Removed: 
    


################################################################################
diff  --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 502f8a2411aa8..bb29d33dc8ca6 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -43,6 +43,18 @@
             print(f"  BB: {bb_name}")
             print(f"    Embedding: {emb.tolist()}")
 
+    # Test getInstEmbMap
+    print("\n=== Instruction Embeddings ===")
+
+    # Test valid function names in sorted order
+    for func_name in sorted(["add", "multiply", "conditional"]):
+        inst_emb_map = tool.getInstEmbMap(func_name)
+        print(f"Function: {func_name}")
+        for inst_str in sorted(inst_emb_map.keys()):
+            emb = inst_emb_map[inst_str]
+            print(f"  Inst: {inst_str}")
+            print(f"    Embedding: {emb.tolist()}")
+
 # CHECK: SUCCESS: Tool initialized
 # CHECK: Tool type: IR2VecTool
 # CHECK: === Function Embeddings ===
@@ -75,3 +87,29 @@
 # CHECK: Function: multiply
 # CHECK:   BB: entry
 # CHECK-NEXT:     Embedding: [50.0, 52.0, 54.0]
+# CHECK: === Instruction Embeddings ===
+# CHECK: Function: add
+# CHECK:   Inst: %sum = add i32 %a, %b
+# CHECK-NEXT:     Embedding: [37.0, 38.0, 39.0]
+# CHECK:   Inst: ret i32 %sum
+# CHECK-NEXT:     Embedding: [1.0, 2.0, 3.0]
+# CHECK: Function: conditional
+# CHECK:   Inst: %cmp = icmp sgt i32 %n, 0
+# CHECK-NEXT:     Embedding: [157.20000000298023, 158.20000000298023, 159.20000000298023]
+# CHECK:   Inst: %neg_val = sub i32 %n, 10
+# CHECK-NEXT:     Embedding: [43.0, 44.0, 45.0]
+# CHECK:   Inst: %pos_val = add i32 %n, 10
+# CHECK-NEXT:     Embedding: [37.0, 38.0, 39.0]
+# CHECK:   Inst: %result = phi i32 [ %pos_val, %positive ], [ %neg_val, %negative ]
+# CHECK-NEXT:     Embedding: [163.0, 164.0, 165.0]
+# CHECK:   Inst: br i1 %cmp, label %positive, label %negative
+# CHECK-NEXT:     Embedding: [4.0, 5.0, 6.0]
+# CHECK:   Inst: br label %exit
+# CHECK-NEXT:     Embedding: [4.0, 5.0, 6.0]
+# CHECK:   Inst: ret i32 %result
+# CHECK-NEXT:     Embedding: [1.0, 2.0, 3.0]
+# CHECK: Function: multiply
+# CHECK:   Inst: %prod = mul i32 %x, %y
+# CHECK-NEXT:     Embedding: [49.0, 50.0, 51.0]
+# CHECK:   Inst: ret i32 %prod
+# CHECK-NEXT:     Embedding: [1.0, 2.0, 3.0]

diff  --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index e334997b31394..e4ddaf9c14e5a 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -150,6 +150,41 @@ class PyIR2VecTool {
 
     return NbBBEmbMap;
   }
+
+  nb::dict getInstEmbMap(const std::string &FuncName) {
+    const Function *F = M->getFunction(FuncName);
+
+    if (!F)
+      throw nb::value_error(
+          ("Function '" + FuncName + "' not found in module").c_str());
+
+    auto ToolInstEmbMap = Tool->getInstEmbeddingsMap(*F, OutputEmbeddingMode);
+
+    if (!ToolInstEmbMap)
+      throw nb::value_error(toString(ToolInstEmbMap.takeError()).c_str());
+
+    nb::dict NbInstEmbMap;
+
+    for (const auto &[InstPtr, InstEmb] : *ToolInstEmbMap) {
+      auto InstEmbVec = InstEmb.getData();
+      double *NbInstEmbVec = new double[InstEmbVec.size()];
+      std::copy(InstEmbVec.begin(), InstEmbVec.end(), NbInstEmbVec);
+
+      auto NbArray = nb::ndarray<nb::numpy, double>(
+          NbInstEmbVec, {InstEmbVec.size()},
+          nb::capsule(NbInstEmbVec, [](void *P) noexcept {
+            delete[] static_cast<double *>(P);
+          }));
+
+      std::string InstStr;
+      raw_string_ostream OS(InstStr);
+      InstPtr->print(OS);
+
+      NbInstEmbMap[nb::str(OS.str().c_str())] = NbArray;
+    }
+
+    return NbInstEmbMap;
+  }
 };
 
 } // namespace
@@ -173,7 +208,13 @@ NB_MODULE(ir2vec, m) {
            "Generate embeddings for all basic blocks in a function\n"
            "Args: funcName (str) - IR-Name of the function\n"
            "Returns: dict[str, ndarray[float64]] - "
-           "{basic_block_name: embedding vector}");
+           "{basic_block_name: embedding vector}")
+      .def("getInstEmbMap", &PyIR2VecTool::getInstEmbMap, nb::arg("funcName"),
+           "Generate embeddings for all instructions in a function\n"
+           "Args: funcName (str) - IR-Name of the function\n"
+           "Returns: dict[str, ndarray[float64]] - "
+           "{instruction_string: embedding_vector}");
+
   m.def(
       "initEmbedding",
       [](const std::string &filename, const std::string &mode,

diff  --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 70f501c116407..c508f8b8b5b46 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -152,8 +152,8 @@ void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
     OS << Entities[EntityID] << '\t' << EntityID << '\n';
 }
 
-Expected<Embedding> IR2VecTool::getFunctionEmbedding(const Function &F,
-                                                     IR2VecKind Kind) const {
+Expected<std::unique_ptr<Embedder>>
+IR2VecTool::createIR2VecEmbedder(const Function &F, IR2VecKind Kind) const {
   if (!Vocab || !Vocab->isValid())
     return createStringError(
         errc::invalid_argument,
@@ -169,16 +169,20 @@ Expected<Embedding> IR2VecTool::getFunctionEmbedding(const Function &F,
                              "Failed to create embedder for function '%s'.",
                              F.getName().str().c_str());
 
-  return Emb->getFunctionVector();
+  return std::move(Emb);
+}
+
+Expected<Embedding> IR2VecTool::getFunctionEmbedding(const Function &F,
+                                                     IR2VecKind Kind) const {
+  auto Emb = createIR2VecEmbedder(F, Kind);
+  if (!Emb)
+    return Emb.takeError();
+
+  return (*Emb)->getFunctionVector();
 }
 
 Expected<FuncEmbMap>
 IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
-  if (!Vocab || !Vocab->isValid())
-    return createStringError(
-        errc::invalid_argument,
-        "Vocabulary is not valid. IR2VecTool not initialized.");
-
   FuncEmbMap Result;
 
   for (const Function &F : M.getFunctionDefs()) {
@@ -193,61 +197,47 @@ IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
 
 Expected<BBEmbeddingsMap>
 IR2VecTool::getBBEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
-  if (!Vocab || !Vocab->isValid())
-    return createStringError(
-        errc::invalid_argument,
-        "Vocabulary is not valid. IR2VecTool not initialized.");
+  auto Emb = createIR2VecEmbedder(F, Kind);
+  if (!Emb)
+    return Emb.takeError();
 
   BBEmbeddingsMap Result;
 
-  if (F.isDeclaration())
-    return createStringError(errc::invalid_argument,
-                             "Function is a declaration.");
+  for (const BasicBlock &BB : F)
+    Result.try_emplace(&BB, (*Emb)->getBBVector(BB));
 
-  auto Emb = Embedder::create(Kind, F, *Vocab);
+  return Result;
+}
+
+Expected<InstEmbeddingsMap>
+IR2VecTool::getInstEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
+  auto Emb = createIR2VecEmbedder(F, Kind);
   if (!Emb)
-    return createStringError(errc::invalid_argument,
-                             "Failed to create embedder for function '%s'.",
-                             F.getName().str().c_str());
+    return Emb.takeError();
 
-  for (const BasicBlock &BB : F)
-    Result.try_emplace(&BB, Emb->getBBVector(BB));
+  InstEmbeddingsMap Result;
+
+  for (const Instruction &I : instructions(F))
+    Result.try_emplace(&I, (*Emb)->getInstVector(I));
 
   return Result;
 }
 
 void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
                                          EmbeddingLevel Level) const {
-  if (!Vocab || !Vocab->isValid()) {
-    WithColor::error(errs(), ToolName)
-        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
-    return;
-  }
-
   for (const Function &F : M.getFunctionDefs())
     writeEmbeddingsToStream(F, OS, Level);
 }
 
 void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
                                          EmbeddingLevel Level) const {
-  if (!Vocab || !Vocab->isValid()) {
-    WithColor::error(errs(), ToolName)
-        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
-    return;
-  }
-
-  if (F.isDeclaration()) {
-    OS << "Function " << F.getName() << " is a declaration, skipping.\n";
-    return;
-  }
-
-  // Create embedder for this function
-  auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
-  if (!Emb) {
+  auto IR2VecEmbedderObj = createIR2VecEmbedder(F, IR2VecEmbeddingKind);
+  if (!IR2VecEmbedderObj) {
     WithColor::error(errs(), ToolName)
-        << "Failed to create embedder for function " << F.getName() << "\n";
+        << toString(IR2VecEmbedderObj.takeError()) << "\n";
     return;
   }
+  auto Emb = std::move(*IR2VecEmbedderObj);
 
   OS << "Function: " << F.getName() << "\n";
 

diff  --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index 04ec4a74b1e24..ae2f931a90cf9 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -94,6 +94,10 @@ class IR2VecTool {
 public:
   explicit IR2VecTool(Module &M) : M(M) {}
 
+  /// Creates the embedding object for downstream embedding streaming
+  Expected<std::unique_ptr<Embedder>>
+  createIR2VecEmbedder(const Function &F, IR2VecKind Kind) const;
+
   /// Initialize the IR2Vec vocabulary from the specified file path.
   Error initializeVocabulary(StringRef VocabPath);
 
@@ -127,6 +131,9 @@ class IR2VecTool {
   /// Get embeddings for all basic blocks in a function
   Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
                                                IR2VecKind Kind) const;
+  /// Get embeddings for all instructions in a function
+  Expected<InstEmbeddingsMap> getInstEmbeddingsMap(const Function &F,
+                                                   IR2VecKind Kind) const;
 
   /// Generate embeddings for the entire module
   void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;


        


More information about the llvm-commits mailing list