[llvm] [IR2Vec] Introducing python bindings for IR2Vec (PR #173194)

Nishant Sachdeva via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 23 11:18:32 PST 2025


https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/173194

>From 958aafd1475e75351b38cd850ee5441b431c427d Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Sun, 21 Dec 2025 01:25:33 +0530
Subject: [PATCH 1/5] Work Commit - Separating all tool implementation from cli
 file

---
 llvm/tools/llvm-ir2vec/CMakeLists.txt         |   1 +
 llvm/tools/llvm-ir2vec/emb-tool.cpp           | 421 ++++++++++++++++++
 .../llvm-ir2vec/{llvm-ir2vec.h => emb-tool.h} |   8 +-
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp        | 372 +---------------
 4 files changed, 427 insertions(+), 375 deletions(-)
 create mode 100644 llvm/tools/llvm-ir2vec/emb-tool.cpp
 rename llvm/tools/llvm-ir2vec/{llvm-ir2vec.h => emb-tool.h} (98%)

diff --git a/llvm/tools/llvm-ir2vec/CMakeLists.txt b/llvm/tools/llvm-ir2vec/CMakeLists.txt
index 2bb6686392907..9d5db8663fb38 100644
--- a/llvm/tools/llvm-ir2vec/CMakeLists.txt
+++ b/llvm/tools/llvm-ir2vec/CMakeLists.txt
@@ -19,6 +19,7 @@ set(LLVM_LINK_COMPONENTS
 
 add_llvm_tool(llvm-ir2vec
   llvm-ir2vec.cpp
+  emb-tool.cpp
   
   DEPENDS
   intrinsics_gen
diff --git a/llvm/tools/llvm-ir2vec/emb-tool.cpp b/llvm/tools/llvm-ir2vec/emb-tool.cpp
new file mode 100644
index 0000000000000..891b26f8ef763
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/emb-tool.cpp
@@ -0,0 +1,421 @@
+//===- emb-tool.cpp - IR2Vec/MIR2Vec Embedding Generation Tool -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file implements the IR2VecTool and MIR2VecTool classes for
+/// IR2Vec/MIR2Vec embedding generation.
+///
+//===----------------------------------------------------------------------===//
+
+#include "emb-tool.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Analysis/IR2Vec.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "llvm/CodeGen/MIR2Vec.h"
+#include "llvm/CodeGen/MIRParser/MIRParser.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/Support/WithColor.h"
+#include "llvm/Target/TargetMachine.h"
+
+#define DEBUG_TYPE "ir2vec"
+
+namespace llvm {
+
+namespace ir2vec {
+
+bool IR2VecTool::initializeVocabulary() {
+  // Register and run the IR2Vec vocabulary analysis
+  // The vocabulary file path is specified via --ir2vec-vocab-path global
+  // option
+  MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+  MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
+  // This will throw an error if vocab is not found or invalid
+  Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
+  return Vocab->isValid();
+}
+
+TripletResult IR2VecTool::generateTriplets(const Function &F) const {
+  if (F.isDeclaration())
+    return {};
+
+  TripletResult Result;
+  Result.MaxRelation = 0;
+
+  unsigned MaxRelation = NextRelation;
+  unsigned PrevOpcode = 0;
+  bool HasPrevOpcode = false;
+
+  for (const BasicBlock &BB : F) {
+    for (const auto &I : BB.instructionsWithoutDebug()) {
+      unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
+      unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
+
+      // Add "Next" relationship with previous instruction
+      if (HasPrevOpcode) {
+        Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
+        LLVM_DEBUG(dbgs() << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1)
+                          << '\t'
+                          << Vocabulary::getVocabKeyForOpcode(Opcode + 1)
+                          << '\t' << "Next\n");
+      }
+
+      // Add "Type" relationship
+      Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
+      LLVM_DEBUG(
+          dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                 << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
+                 << '\t' << "Type\n");
+
+      // Add "Arg" relationships
+      unsigned ArgIndex = 0;
+      for (const Use &U : I.operands()) {
+        unsigned OperandID = Vocabulary::getIndex(*U.get());
+        unsigned RelationID = ArgRelation + ArgIndex;
+        Result.Triplets.push_back({Opcode, OperandID, RelationID});
+
+        LLVM_DEBUG({
+          StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
+              Vocabulary::getOperandKind(U.get()));
+          dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                 << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
+        });
+
+        ++ArgIndex;
+      }
+      // Only update MaxRelation if there were operands
+      if (ArgIndex > 0)
+        MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
+      PrevOpcode = Opcode;
+      HasPrevOpcode = true;
+    }
+  }
+
+  Result.MaxRelation = MaxRelation;
+  return Result;
+}
+
+TripletResult IR2VecTool::generateTriplets() const {
+  TripletResult Result;
+  Result.MaxRelation = NextRelation;
+
+  for (const Function &F : M.getFunctionDefs()) {
+    TripletResult FuncResult = generateTriplets(F);
+    Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+    Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+                           FuncResult.Triplets.end());
+  }
+
+  return Result;
+}
+
+void IR2VecTool::writeTripletsToStream(raw_ostream &OS) const {
+  auto Result = generateTriplets();
+  OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+  for (const auto &T : Result.Triplets)
+    OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+}
+
+EntityList IR2VecTool::collectEntityMappings() {
+  auto EntityLen = Vocabulary::getCanonicalSize();
+  EntityList Result;
+  for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
+    Result.push_back(Vocabulary::getStringKey(EntityID).str());
+  return Result;
+}
+
+void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
+  auto Entities = collectEntityMappings();
+  OS << Entities.size() << "\n";
+  for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+    OS << Entities[EntityID] << '\t' << EntityID << '\n';
+}
+
+void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
+                                         EmbeddingLevel Level) const {
+  if (!Vocab || !Vocab->isValid()) { // Null-check first, like the overload.
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+    return;
+  }
+
+  for (const Function &F : M.getFunctionDefs())
+    writeEmbeddingsToStream(F, OS, Level);
+}
+
+void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
+                                         EmbeddingLevel Level) const {
+  if (!Vocab || !Vocab->isValid()) {
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+    return;
+  }
+  if (F.isDeclaration()) {
+    OS << "Function " << F.getName() << " is a declaration, skipping.\n";
+    return;
+  }
+
+  // Create embedder for this function
+  auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+  if (!Emb) {
+    WithColor::error(errs(), ToolName)
+        << "Failed to create embedder for function " << F.getName() << "\n";
+    return;
+  }
+
+  OS << "Function: " << F.getName() << "\n";
+
+  // Generate embeddings based on the specified level
+  switch (Level) {
+  case FunctionLevel:
+    Emb->getFunctionVector().print(OS);
+    break;
+  case BasicBlockLevel:
+    for (const BasicBlock &BB : F) {
+      OS << BB.getName() << ":";
+      Emb->getBBVector(BB).print(OS);
+    }
+    break;
+  case InstructionLevel:
+    for (const Instruction &I : instructions(F)) {
+      OS << I;
+      Emb->getInstVector(I).print(OS);
+    }
+    break;
+  }
+}
+
+} // namespace ir2vec
+
+namespace mir2vec {
+
+bool MIR2VecTool::initializeVocabulary(const Module &M) {
+  MIR2VecVocabProvider Provider(MMI);
+  auto VocabOrErr = Provider.getVocabulary(M);
+  if (!VocabOrErr) {
+    WithColor::error(errs(), ToolName)
+        << "Failed to load MIR2Vec vocabulary - "
+        << toString(VocabOrErr.takeError()) << "\n";
+    return false;
+  }
+  Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
+  return true;
+}
+
+bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
+  for (const Function &F : M.getFunctionDefs()) {
+    MachineFunction *MF = MMI.getMachineFunction(F);
+    if (!MF)
+      continue;
+
+    const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
+    const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
+    const MachineRegisterInfo &MRI = MF->getRegInfo();
+
+    auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
+    if (!VocabOrErr) {
+      WithColor::error(errs(), ToolName)
+          << "Failed to create dummy vocabulary - "
+          << toString(VocabOrErr.takeError()) << "\n";
+      return false;
+    }
+    Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
+    return true;
+  }
+
+  WithColor::error(errs(), ToolName)
+      << "No machine functions found to initialize vocabulary\n";
+  return false;
+}
+
+TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
+  TripletResult Result;
+  Result.MaxRelation = MIRNextRelation;
+
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName)
+        << "MIR Vocabulary must be initialized for triplet generation.\n";
+    return Result;
+  }
+
+  unsigned PrevOpcode = 0;
+  bool HasPrevOpcode = false;
+  for (const MachineBasicBlock &MBB : MF) {
+    for (const MachineInstr &MI : MBB) {
+      // Skip debug instructions
+      if (MI.isDebugInstr())
+        continue;
+
+      // Get opcode entity ID
+      unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
+
+      // Add "Next" relationship with previous instruction
+      if (HasPrevOpcode) {
+        Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
+        LLVM_DEBUG(dbgs() << Vocab->getStringKey(PrevOpcode) << '\t'
+                          << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
+      }
+
+      // Add "Arg" relationships for operands
+      unsigned ArgIndex = 0;
+      for (const MachineOperand &MO : MI.operands()) {
+        auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
+        unsigned RelationID = MIRArgRelation + ArgIndex;
+        Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
+        LLVM_DEBUG({
+          std::string OperandStr = Vocab->getStringKey(OperandID);
+          dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr << '\t'
+                 << "Arg" << ArgIndex << '\n';
+        });
+
+        ++ArgIndex;
+      }
+
+      // Update MaxRelation if there were operands
+      if (ArgIndex > 0)
+        Result.MaxRelation =
+            std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
+
+      PrevOpcode = OpcodeID;
+      HasPrevOpcode = true;
+    }
+  }
+
+  return Result;
+}
+
+TripletResult MIR2VecTool::generateTriplets(const Module &M) const {
+  TripletResult Result;
+  Result.MaxRelation = MIRNextRelation;
+
+  for (const Function &F : M.getFunctionDefs()) {
+    MachineFunction *MF = MMI.getMachineFunction(F);
+    if (!MF) {
+      WithColor::warning(errs(), ToolName)
+          << "No MachineFunction for " << F.getName() << "\n";
+      continue;
+    }
+
+    TripletResult FuncResult = generateTriplets(*MF);
+    Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+    Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+                           FuncResult.Triplets.end());
+  }
+
+  return Result;
+}
+
+void MIR2VecTool::writeTripletsToStream(const Module &M,
+                                        raw_ostream &OS) const {
+  auto Result = generateTriplets(M);
+  OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+  for (const auto &T : Result.Triplets)
+    OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+}
+
+EntityList MIR2VecTool::collectEntityMappings() const {
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary must be initialized for entity mappings.\n";
+    return {};
+  }
+
+  const unsigned EntityCount = Vocab->getCanonicalSize();
+  EntityList Result;
+  for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
+    Result.push_back(Vocab->getStringKey(EntityID));
+
+  return Result;
+}
+
+void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
+  auto Entities = collectEntityMappings();
+  if (Entities.empty())
+    return;
+
+  OS << Entities.size() << "\n";
+  for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+    OS << Entities[EntityID] << '\t' << EntityID << '\n';
+}
+
+void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
+                                          EmbeddingLevel Level) const {
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+    return;
+  }
+
+  for (const Function &F : M.getFunctionDefs()) {
+    MachineFunction *MF = MMI.getMachineFunction(F);
+    if (!MF) {
+      WithColor::warning(errs(), ToolName)
+          << "No MachineFunction for " << F.getName() << "\n";
+      continue;
+    }
+
+    writeEmbeddingsToStream(*MF, OS, Level);
+  }
+}
+
+void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
+                                          EmbeddingLevel Level) const {
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+    return;
+  }
+
+  auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+  if (!Emb) {
+    WithColor::error(errs(), ToolName)
+        << "Failed to create embedder for " << MF.getName() << "\n";
+    return;
+  }
+
+  OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+
+  // Generate embeddings based on the specified level
+  switch (Level) {
+  case FunctionLevel:
+    OS << "Function vector: ";
+    Emb->getMFunctionVector().print(OS);
+    break;
+  case BasicBlockLevel:
+    OS << "Basic block vectors:\n";
+    for (const MachineBasicBlock &MBB : MF) {
+      OS << "MBB " << MBB.getName() << ": ";
+      Emb->getMBBVector(MBB).print(OS);
+    }
+    break;
+  case InstructionLevel:
+    OS << "Instruction vectors:\n";
+    for (const MachineBasicBlock &MBB : MF) {
+      for (const MachineInstr &MI : MBB) {
+        OS << MI << " -> ";
+        Emb->getMInstVector(MI).print(OS);
+      }
+    }
+    break;
+  }
+}
+
+} // namespace mir2vec
+
+} // namespace llvm
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h b/llvm/tools/llvm-ir2vec/emb-tool.h
similarity index 98%
rename from llvm/tools/llvm-ir2vec/llvm-ir2vec.h
rename to llvm/tools/llvm-ir2vec/emb-tool.h
index 566c362edbd22..009bcec60108b 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
+++ b/llvm/tools/llvm-ir2vec/emb-tool.h
@@ -12,8 +12,8 @@
 ///
 //===----------------------------------------------------------------------===//
 
-#ifndef LLVM_TOOLS_LLVM_IR2VEC_LLVM_IR2VEC_H
-#define LLVM_TOOLS_LLVM_IR2VEC_LLVM_IR2VEC_H
+#ifndef LLVM_TOOLS_LLVM_IR2VEC_EMB_TOOL_H
+#define LLVM_TOOLS_LLVM_IR2VEC_EMB_TOOL_H
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Analysis/IR2Vec.h"
@@ -44,7 +44,7 @@
 #define DEBUG_TYPE "ir2vec"
 
 namespace llvm {
-
+  
 /// Tool name for error reporting
 static const char *ToolName = "llvm-ir2vec";
 
@@ -198,4 +198,4 @@ struct MIRContext {
 
 } // namespace llvm
 
-#endif // LLVM_TOOLS_LLVM_IR2VEC_LLVM_IR2VEC_H
\ No newline at end of file
+#endif // LLVM_TOOLS_LLVM_IR2VEC_EMB_TOOL_H
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 6b70e09518fa7..a2b2f4e6a7aa8 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -54,7 +54,7 @@
 ///
 //===----------------------------------------------------------------------===//
 
-#include "llvm-ir2vec.h"
+#include "emb-tool.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Analysis/IR2Vec.h"
 #include "llvm/IR/BasicBlock.h"
@@ -147,167 +147,6 @@ static cl::opt<EmbeddingLevel>
 
 namespace ir2vec {
 
-bool IR2VecTool::initializeVocabulary() {
-  // Register and run the IR2Vec vocabulary analysis
-  // The vocabulary file path is specified via --ir2vec-vocab-path global
-  // option
-  MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
-  MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
-  // This will throw an error if vocab is not found or invalid
-  Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
-  return Vocab->isValid();
-}
-
-TripletResult IR2VecTool::generateTriplets(const Function &F) const {
-  if (F.isDeclaration())
-    return {};
-
-  TripletResult Result;
-  Result.MaxRelation = 0;
-
-  unsigned MaxRelation = NextRelation;
-  unsigned PrevOpcode = 0;
-  bool HasPrevOpcode = false;
-
-  for (const BasicBlock &BB : F) {
-    for (const auto &I : BB.instructionsWithoutDebug()) {
-      unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
-      unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
-
-      // Add "Next" relationship with previous instruction
-      if (HasPrevOpcode) {
-        Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
-        LLVM_DEBUG(dbgs() << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1)
-                          << '\t'
-                          << Vocabulary::getVocabKeyForOpcode(Opcode + 1)
-                          << '\t' << "Next\n");
-      }
-
-      // Add "Type" relationship
-      Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
-      LLVM_DEBUG(
-          dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                 << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
-                 << '\t' << "Type\n");
-
-      // Add "Arg" relationships
-      unsigned ArgIndex = 0;
-      for (const Use &U : I.operands()) {
-        unsigned OperandID = Vocabulary::getIndex(*U.get());
-        unsigned RelationID = ArgRelation + ArgIndex;
-        Result.Triplets.push_back({Opcode, OperandID, RelationID});
-
-        LLVM_DEBUG({
-          StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
-              Vocabulary::getOperandKind(U.get()));
-          dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                 << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
-        });
-
-        ++ArgIndex;
-      }
-      // Only update MaxRelation if there were operands
-      if (ArgIndex > 0)
-        MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
-      PrevOpcode = Opcode;
-      HasPrevOpcode = true;
-    }
-  }
-
-  Result.MaxRelation = MaxRelation;
-  return Result;
-}
-
-TripletResult IR2VecTool::generateTriplets() const {
-  TripletResult Result;
-  Result.MaxRelation = NextRelation;
-
-  for (const Function &F : M.getFunctionDefs()) {
-    TripletResult FuncResult = generateTriplets(F);
-    Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
-    Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
-                           FuncResult.Triplets.end());
-  }
-
-  return Result;
-}
-
-void IR2VecTool::writeTripletsToStream(raw_ostream &OS) const {
-  auto Result = generateTriplets();
-  OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
-  for (const auto &T : Result.Triplets)
-    OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
-}
-
-EntityList IR2VecTool::collectEntityMappings() {
-  auto EntityLen = Vocabulary::getCanonicalSize();
-  EntityList Result;
-  for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
-    Result.push_back(Vocabulary::getStringKey(EntityID).str());
-  return Result;
-}
-
-void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
-  auto Entities = collectEntityMappings();
-  OS << Entities.size() << "\n";
-  for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
-    OS << Entities[EntityID] << '\t' << EntityID << '\n';
-}
-
-void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
-                                         EmbeddingLevel Level) const {
-  if (!Vocab->isValid()) {
-    WithColor::error(errs(), ToolName)
-        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
-    return;
-  }
-
-  for (const Function &F : M.getFunctionDefs())
-    writeEmbeddingsToStream(F, OS, Level);
-}
-
-void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
-                                         EmbeddingLevel Level) const {
-  if (!Vocab || !Vocab->isValid()) {
-    WithColor::error(errs(), ToolName)
-        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
-    return;
-  }
-  if (F.isDeclaration()) {
-    OS << "Function " << F.getName() << " is a declaration, skipping.\n";
-    return;
-  }
-
-  // Create embedder for this function
-  auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
-  if (!Emb) {
-    WithColor::error(errs(), ToolName)
-        << "Failed to create embedder for function " << F.getName() << "\n";
-    return;
-  }
-
-  OS << "Function: " << F.getName() << "\n";
-
-  // Generate embeddings based on the specified level
-  switch (Level) {
-  case FunctionLevel:
-    Emb->getFunctionVector().print(OS);
-    break;
-  case BasicBlockLevel:
-    for (const BasicBlock &BB : F) {
-      OS << BB.getName() << ":";
-      Emb->getBBVector(BB).print(OS);
-    }
-    break;
-  case InstructionLevel:
-    for (const Instruction &I : instructions(F)) {
-      OS << I;
-      Emb->getInstVector(I).print(OS);
-    }
-    break;
-  }
-}
-
 /// Process the module and generate output based on selected subcommand
 Error processModule(Module &M, raw_ostream &OS) {
   IR2VecTool Tool(M);
@@ -341,215 +180,6 @@ Error processModule(Module &M, raw_ostream &OS) {
 
 namespace mir2vec {
 
-bool MIR2VecTool::initializeVocabulary(const Module &M) {
-  MIR2VecVocabProvider Provider(MMI);
-  auto VocabOrErr = Provider.getVocabulary(M);
-  if (!VocabOrErr) {
-    WithColor::error(errs(), ToolName)
-        << "Failed to load MIR2Vec vocabulary - "
-        << toString(VocabOrErr.takeError()) << "\n";
-    return false;
-  }
-  Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
-  return true;
-}
-
-bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
-  for (const Function &F : M.getFunctionDefs()) {
-    MachineFunction *MF = MMI.getMachineFunction(F);
-    if (!MF)
-      continue;
-
-    const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
-    const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
-    const MachineRegisterInfo &MRI = MF->getRegInfo();
-
-    auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
-    if (!VocabOrErr) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to create dummy vocabulary - "
-          << toString(VocabOrErr.takeError()) << "\n";
-      return false;
-    }
-    Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
-    return true;
-  }
-
-  WithColor::error(errs(), ToolName)
-      << "No machine functions found to initialize vocabulary\n";
-  return false;
-}
-
-TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
-  TripletResult Result;
-  Result.MaxRelation = MIRNextRelation;
-
-  if (!Vocab) {
-    WithColor::error(errs(), ToolName)
-        << "MIR Vocabulary must be initialized for triplet generation.\n";
-    return Result;
-  }
-
-  unsigned PrevOpcode = 0;
-  bool HasPrevOpcode = false;
-  for (const MachineBasicBlock &MBB : MF) {
-    for (const MachineInstr &MI : MBB) {
-      // Skip debug instructions
-      if (MI.isDebugInstr())
-        continue;
-
-      // Get opcode entity ID
-      unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
-
-      // Add "Next" relationship with previous instruction
-      if (HasPrevOpcode) {
-        Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
-        LLVM_DEBUG(dbgs() << Vocab->getStringKey(PrevOpcode) << '\t'
-                          << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
-      }
-
-      // Add "Arg" relationships for operands
-      unsigned ArgIndex = 0;
-      for (const MachineOperand &MO : MI.operands()) {
-        auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
-        unsigned RelationID = MIRArgRelation + ArgIndex;
-        Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
-        LLVM_DEBUG({
-          std::string OperandStr = Vocab->getStringKey(OperandID);
-          dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr << '\t'
-                 << "Arg" << ArgIndex << '\n';
-        });
-
-        ++ArgIndex;
-      }
-
-      // Update MaxRelation if there were operands
-      if (ArgIndex > 0)
-        Result.MaxRelation =
-            std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
-
-      PrevOpcode = OpcodeID;
-      HasPrevOpcode = true;
-    }
-  }
-
-  return Result;
-}
-
-TripletResult MIR2VecTool::generateTriplets(const Module &M) const {
-  TripletResult Result;
-  Result.MaxRelation = MIRNextRelation;
-
-  for (const Function &F : M.getFunctionDefs()) {
-    MachineFunction *MF = MMI.getMachineFunction(F);
-    if (!MF) {
-      WithColor::warning(errs(), ToolName)
-          << "No MachineFunction for " << F.getName() << "\n";
-      continue;
-    }
-
-    TripletResult FuncResult = generateTriplets(*MF);
-    Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
-    Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
-                           FuncResult.Triplets.end());
-  }
-
-  return Result;
-}
-
-void MIR2VecTool::writeTripletsToStream(const Module &M,
-                                        raw_ostream &OS) const {
-  auto Result = generateTriplets(M);
-  OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
-  for (const auto &T : Result.Triplets)
-    OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
-}
-
-EntityList MIR2VecTool::collectEntityMappings() const {
-  if (!Vocab) {
-    WithColor::error(errs(), ToolName)
-        << "Vocabulary must be initialized for entity mappings.\n";
-    return {};
-  }
-
-  const unsigned EntityCount = Vocab->getCanonicalSize();
-  EntityList Result;
-  for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
-    Result.push_back(Vocab->getStringKey(EntityID));
-
-  return Result;
-}
-
-void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
-  auto Entities = collectEntityMappings();
-  if (Entities.empty())
-    return;
-
-  OS << Entities.size() << "\n";
-  for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
-    OS << Entities[EntityID] << '\t' << EntityID << '\n';
-}
-
-void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
-                                          EmbeddingLevel Level) const {
-  if (!Vocab) {
-    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
-    return;
-  }
-
-  for (const Function &F : M.getFunctionDefs()) {
-    MachineFunction *MF = MMI.getMachineFunction(F);
-    if (!MF) {
-      WithColor::warning(errs(), ToolName)
-          << "No MachineFunction for " << F.getName() << "\n";
-      continue;
-    }
-
-    writeEmbeddingsToStream(*MF, OS, Level);
-  }
-}
-
-void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
-                                          EmbeddingLevel Level) const {
-  if (!Vocab) {
-    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
-    return;
-  }
-
-  auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
-  if (!Emb) {
-    WithColor::error(errs(), ToolName)
-        << "Failed to create embedder for " << MF.getName() << "\n";
-    return;
-  }
-
-  OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
-
-  // Generate embeddings based on the specified level
-  switch (Level) {
-  case FunctionLevel:
-    OS << "Function vector: ";
-    Emb->getMFunctionVector().print(OS);
-    break;
-  case BasicBlockLevel:
-    OS << "Basic block vectors:\n";
-    for (const MachineBasicBlock &MBB : MF) {
-      OS << "MBB " << MBB.getName() << ": ";
-      Emb->getMBBVector(MBB).print(OS);
-    }
-    break;
-  case InstructionLevel:
-    OS << "Instruction vectors:\n";
-    for (const MachineBasicBlock &MBB : MF) {
-      for (const MachineInstr &MI : MBB) {
-        OS << MI << " -> ";
-        Emb->getMInstVector(MI).print(OS);
-      }
-    }
-    break;
-  }
-}
-
 /// Setup MIR context from input file
 Error setupMIRContext(const std::string &InputFile, MIRContext &Ctx) {
   SMDiagnostic Err;

>From ac7cd83a0400a819716fb70d54d3ae399fa206dc Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Sun, 21 Dec 2025 22:32:35 +0530
Subject: [PATCH 2/5] nit commit - code formatting fixup

---
 llvm/tools/llvm-ir2vec/emb-tool.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/tools/llvm-ir2vec/emb-tool.h b/llvm/tools/llvm-ir2vec/emb-tool.h
index 009bcec60108b..24b8022ed7482 100644
--- a/llvm/tools/llvm-ir2vec/emb-tool.h
+++ b/llvm/tools/llvm-ir2vec/emb-tool.h
@@ -44,7 +44,7 @@
 #define DEBUG_TYPE "ir2vec"
 
 namespace llvm {
-  
+
 /// Tool name for error reporting
 static const char *ToolName = "llvm-ir2vec";
 

>From 959ab6047c0405d57073000b4c814e0d1840d10b Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Tue, 23 Dec 2025 16:00:00 +0530
Subject: [PATCH 3/5] Work Commit. Creating a lib/ structure, and using it for
 the first draft of the python bindings

---
 llvm/tools/llvm-ir2vec/CMakeLists.txt         |  29 ++---
 .../tools/llvm-ir2vec/bindings/CMakeLists.txt |  14 +++
 .../llvm-ir2vec/bindings/ir2vec_bindings.cpp  | 107 ++++++++++++++++++
 llvm/tools/llvm-ir2vec/lib/CMakeLists.txt     |  28 +++++
 llvm/tools/llvm-ir2vec/{ => lib}/emb-tool.cpp |   0
 llvm/tools/llvm-ir2vec/{ => lib}/emb-tool.h   |   0
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp        |   2 +-
 7 files changed, 159 insertions(+), 21 deletions(-)
 create mode 100644 llvm/tools/llvm-ir2vec/bindings/CMakeLists.txt
 create mode 100644 llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
 create mode 100644 llvm/tools/llvm-ir2vec/lib/CMakeLists.txt
 rename llvm/tools/llvm-ir2vec/{ => lib}/emb-tool.cpp (100%)
 rename llvm/tools/llvm-ir2vec/{ => lib}/emb-tool.h (100%)

diff --git a/llvm/tools/llvm-ir2vec/CMakeLists.txt b/llvm/tools/llvm-ir2vec/CMakeLists.txt
index 9d5db8663fb38..d5956cde6892d 100644
--- a/llvm/tools/llvm-ir2vec/CMakeLists.txt
+++ b/llvm/tools/llvm-ir2vec/CMakeLists.txt
@@ -1,26 +1,15 @@
-set(LLVM_LINK_COMPONENTS
-  # Core LLVM components for IR processing
-  Analysis
-  Core
-  IRReader
-  Support
-  
-  # Machine IR components (for -mode=mir)
-  CodeGen           
-  MIRParser         
-  
-  # Target initialization (required for MIR parsing)
-  AllTargetsAsmParsers
-  AllTargetsCodeGens
-  AllTargetsDescs
-  AllTargetsInfos
-  TargetParser
-  )
+# Build the library first
+add_subdirectory(lib)
 
+# Build the llvm-ir2vec executable
 add_llvm_tool(llvm-ir2vec
   llvm-ir2vec.cpp
-  emb-tool.cpp
   
   DEPENDS
   intrinsics_gen
-  )
+)
+
+# Link the executable against the library
+target_link_libraries(llvm-ir2vec PRIVATE EmbTool)
+
+add_subdirectory(bindings)
diff --git a/llvm/tools/llvm-ir2vec/bindings/CMakeLists.txt b/llvm/tools/llvm-ir2vec/bindings/CMakeLists.txt
new file mode 100644
index 0000000000000..86f4ffb35b3ca
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/bindings/CMakeLists.txt
@@ -0,0 +1,14 @@
+find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module)
+find_package(pybind11 CONFIG QUIET)
+
+if(NOT pybind11_FOUND)
+  message(WARNING "pybind11 not found - skipping py_ir2vec")
+  return()
+endif()
+
+pybind11_add_module(py_ir2vec
+  ir2vec_bindings.cpp
+)
+
+# Link against the EmbTool library
+target_link_libraries(py_ir2vec PRIVATE EmbTool)
diff --git a/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp b/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
new file mode 100644
index 0000000000000..d9d4b9f07e637
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
@@ -0,0 +1,107 @@
+//===- ir2vec_bindings.cpp - Python Bindings for IR2Vec ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../lib/emb-tool.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include <fstream>
+#include <memory>
+#include <string>
+
+namespace py = pybind11;
+using namespace llvm;
+using namespace llvm::ir2vec;
+
+namespace {
+
+bool fileNotValid(const std::string &Filename) {
+  std::ifstream F(Filename, std::ios_base::in | std::ios_base::binary);
+  return !F.good();
+}
+
+std::unique_ptr<Module> getLLVMIR(const std::string &Filename,
+                                  LLVMContext &Context) {
+  SMDiagnostic Err;
+  auto M = parseIRFile(Filename, Err, Context);
+  if (!M) {
+    Err.print(Filename.c_str(), outs());
+    throw std::runtime_error("Failed to parse IR file.");
+  }
+  return M;
+}
+
+class PyIR2VecTool {
+private:
+  std::unique_ptr<LLVMContext> Ctx;
+  std::unique_ptr<Module> M;
+  std::unique_ptr<IR2VecTool> Tool;
+
+public:
+  PyIR2VecTool(const std::string &Filename) {
+    if (fileNotValid(Filename))
+      throw std::runtime_error("Invalid file path");
+
+    Ctx = std::make_unique<LLVMContext>();
+    M = getLLVMIR(Filename, *Ctx);
+    Tool = std::make_unique<IR2VecTool>(*M);
+  }
+
+  // Get entity mappings (vocabulary)
+  std::vector<std::string> getEntityMappings() {
+    return IR2VecTool::collectEntityMappings();
+  }
+
+  // Generate triplets for vocabulary training
+  py::dict generateTriplets() {
+    auto result = Tool->generateTriplets();
+    py::list triplets_list;
+
+    for (const auto &t : result.Triplets) {
+      triplets_list.append(py::make_tuple(t.Head, t.Tail, t.Relation));
+    }
+
+    py::dict output;
+    output["max_relation"] = result.MaxRelation;
+    output["triplets"] = triplets_list;
+    return output;
+  }
+};
+
+} // anonymous namespace
+
+PYBIND11_MODULE(py_ir2vec, m) {
+  m.doc() = std::string("Python bindings for ") + ToolName + "\n\n" + ToolName +
+            " provides distributed representations for LLVM IR.\n\n"
+            "Example:\n"
+            "    >>> import py_ir2vec\n"
+            "    >>> tool = py_ir2vec." +
+            ToolName +
+            "Tool(\"test.ll\")\n"
+            "    >>> entities = tool.getEntityMappings()\n"
+            "    >>> triplets = tool.generateTriplets()";
+
+  // Main tool class
+  py::class_<PyIR2VecTool>(m, "IR2VecTool")
+      .def(py::init<const std::string &>(), py::arg("filename"),
+           "Initialize IR2Vec on an LLVM IR file\n"
+           "Args:\n"
+           "  filename: Path to LLVM IR file (.bc or .ll)")
+      .def("getEntityMappings", &PyIR2VecTool::getEntityMappings,
+           "Get entity mappings (vocabulary)\n"
+           "Returns: list[str] - list of entity names where index is entity_id")
+      .def("generateTriplets", &PyIR2VecTool::generateTriplets,
+           "Generate triplets for vocabulary training\n"
+           "Returns: dict with 'max_relation' and 'triplets' keys");
+}
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/lib/CMakeLists.txt b/llvm/tools/llvm-ir2vec/lib/CMakeLists.txt
new file mode 100644
index 0000000000000..f15761beb97df
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/lib/CMakeLists.txt
@@ -0,0 +1,28 @@
+set(LLVM_LINK_COMPONENTS
+  # Core LLVM components for IR processing
+  Analysis
+  Core
+  IRReader
+  Support
+  
+  # Machine IR components (for -mode=mir)
+  CodeGen           
+  MIRParser         
+  
+  # Target initialization (required for MIR parsing)
+  AllTargetsAsmParsers
+  AllTargetsCodeGens
+  AllTargetsDescs
+  AllTargetsInfos
+  TargetParser
+)
+
+add_llvm_library(EmbTool
+  emb-tool.cpp
+  
+  LINK_COMPONENTS
+  ${LLVM_LINK_COMPONENTS}
+  
+  DEPENDS
+  intrinsics_gen
+)
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/emb-tool.cpp b/llvm/tools/llvm-ir2vec/lib/emb-tool.cpp
similarity index 100%
rename from llvm/tools/llvm-ir2vec/emb-tool.cpp
rename to llvm/tools/llvm-ir2vec/lib/emb-tool.cpp
diff --git a/llvm/tools/llvm-ir2vec/emb-tool.h b/llvm/tools/llvm-ir2vec/lib/emb-tool.h
similarity index 100%
rename from llvm/tools/llvm-ir2vec/emb-tool.h
rename to llvm/tools/llvm-ir2vec/lib/emb-tool.h
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index a2b2f4e6a7aa8..541585d881638 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -54,7 +54,7 @@
 ///
 //===----------------------------------------------------------------------===//
 
-#include "emb-tool.h"
+#include "lib/emb-tool.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Analysis/IR2Vec.h"
 #include "llvm/IR/BasicBlock.h"

>From 7e0cc13677dd52f6843f6d586fde92704366fc2a Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 24 Dec 2025 00:38:48 +0530
Subject: [PATCH 4/5] Work commit - python bindings reworked to include a
 compulsory vocab override input. Added getTriplets function as well

---
 llvm/include/llvm/Analysis/IR2Vec.h           |   2 +-
 llvm/lib/Analysis/IR2Vec.cpp                  |  27 +++-
 .../llvm-ir2vec/bindings/ir2vec_bindings.cpp  | 116 ++++++++++++------
 3 files changed, 99 insertions(+), 46 deletions(-)

diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 5957a3743f22e..585dffe715be2 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -615,7 +615,7 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
   using VocabMap = std::map<std::string, ir2vec::Embedding>;
   std::optional<ir2vec::VocabStorage> Vocab;
 
-  Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab,
+  Error readVocabulary(StringRef EffectivePath, VocabMap &OpcVocab, VocabMap &TypeVocab,
                        VocabMap &ArgVocab);
   void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
                             VocabMap &ArgVocab);
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 85b5372c961c1..a47ced94ee676 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -62,6 +62,17 @@ cl::opt<IR2VecKind> IR2VecEmbeddingKind(
     cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
     cl::cat(IR2VecCategory));
 
+static std::optional<std::string> VocabOverride;
+void setIR2VecVocabPath(StringRef Path) {
+  if (Path.empty()) VocabOverride = std::nullopt;
+  else VocabOverride = Path.str();
+}
+
+StringRef getIR2VecVocabPath() {
+  return VocabOverride ? StringRef(*VocabOverride)
+                        : StringRef(VocabFile.getValue());
+}
+
 } // namespace ir2vec
 } // namespace llvm
 
@@ -482,12 +493,13 @@ VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) {
 
 // FIXME: Make this optional. We can avoid file reads
 // by auto-generating a default vocabulary during the build time.
-Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,
+Error IR2VecVocabAnalysis::readVocabulary(StringRef EffectivePath,
+                                          VocabMap &OpcVocab,
                                           VocabMap &TypeVocab,
                                           VocabMap &ArgVocab) {
-  auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
+  auto BufOrError = MemoryBuffer::getFileOrSTDIN(EffectivePath, /*IsText=*/true);
   if (!BufOrError)
-    return createFileError(VocabFile, BufOrError.getError());
+    return createFileError(EffectivePath.str(), BufOrError.getError());
 
   auto Content = BufOrError.get()->getBuffer();
 
@@ -615,8 +627,11 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
   if (Vocab.has_value())
     return Vocabulary(std::move(Vocab.value()));
 
+  StringRef EffectivePath =
+    VocabOverride ? StringRef(*VocabOverride) : VocabFile.getValue();
+
   // Otherwise, try to read from the vocabulary file.
-  if (VocabFile.empty()) {
+  if (EffectivePath.empty()) {
     // FIXME: Use default vocabulary
     Ctx->emitError("IR2Vec vocabulary file path not specified; You may need to "
                    "set it using --ir2vec-vocab-path");
@@ -624,7 +639,7 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
   }
 
   VocabMap OpcVocab, TypeVocab, ArgVocab;
-  if (auto Err = readVocabulary(OpcVocab, TypeVocab, ArgVocab)) {
+  if (auto Err = readVocabulary(EffectivePath, OpcVocab, TypeVocab, ArgVocab)) {
     emitError(std::move(Err), *Ctx);
     return Vocabulary();
   }
@@ -694,4 +709,4 @@ PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
     Entry.print(OS);
   }
   return PreservedAnalyses::all();
-}
+}
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp b/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
index d9d4b9f07e637..b4658fc60a3ed 100644
--- a/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
+++ b/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
@@ -15,6 +15,7 @@
 
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
+#include <pybind11/numpy.h>
 
 #include <fstream>
 #include <memory>
@@ -24,6 +25,13 @@ namespace py = pybind11;
 using namespace llvm;
 using namespace llvm::ir2vec;
 
+namespace llvm {
+namespace ir2vec {
+void setIR2VecVocabPath(StringRef Path);
+StringRef getIR2VecVocabPath();
+} // namespace ir2vec
+} // namespace llvm
+
 namespace {
 
 bool fileNotValid(const std::string &Filename) {
@@ -32,7 +40,7 @@ bool fileNotValid(const std::string &Filename) {
 }
 
 std::unique_ptr<Module> getLLVMIR(const std::string &Filename,
-                                  LLVMContext &Context) {
+                                   LLVMContext &Context) {
   SMDiagnostic Err;
   auto M = parseIRFile(Filename, Err, Context);
   if (!M) {
@@ -49,59 +57,89 @@ class PyIR2VecTool {
   std::unique_ptr<IR2VecTool> Tool;
 
 public:
-  PyIR2VecTool(const std::string &Filename) {
-    if (fileNotValid(Filename))
-      throw std::runtime_error("Invalid file path");
+  PyIR2VecTool(std::string Filename, std::string Mode,
+								std::string VocabOverride) {
+		if (fileNotValid(Filename))
+			throw std::runtime_error("Invalid file path");
 
-    Ctx = std::make_unique<LLVMContext>();
-    M = getLLVMIR(Filename, *Ctx);
-    Tool = std::make_unique<IR2VecTool>(*M);
-  }
+		if (Mode != "sym" && Mode != "fa")
+			throw std::runtime_error("Invalid mode. Use 'sym' or 'fa'");
 
-  // Get entity mappings (vocabulary)
-  std::vector<std::string> getEntityMappings() {
-    return IR2VecTool::collectEntityMappings();
+		if (VocabOverride.empty())
+			throw std::runtime_error("Error - Empty Vocab Path not allowed");
+
+		setIR2VecVocabPath(VocabOverride);
+
+		Ctx = std::make_unique<LLVMContext>();
+		M = getLLVMIR(Filename, *Ctx);
+		Tool = std::make_unique<IR2VecTool>(*M);
+
+		bool Ok = Tool->initializeVocabulary();
+		if (!Ok)
+			throw std::runtime_error("Failed to initialize IR2Vec vocabulary");
   }
 
-  // Generate triplets for vocabulary training
   py::dict generateTriplets() {
     auto result = Tool->generateTriplets();
     py::list triplets_list;
-
-    for (const auto &t : result.Triplets) {
+    for (const auto& t : result.Triplets) {
       triplets_list.append(py::make_tuple(t.Head, t.Tail, t.Relation));
     }
 
-    py::dict output;
-    output["max_relation"] = result.MaxRelation;
-    output["triplets"] = triplets_list;
-    return output;
+    return py::dict(
+      py::arg("max_relation") = result.MaxRelation,
+      py::arg("triplets") = triplets_list
+    );
+  }
+
+  EntityList collectEntityMappings() {
+    return IR2VecTool::collectEntityMappings();
   }
 };
 
-} // anonymous namespace
+} // namespace
 
 PYBIND11_MODULE(py_ir2vec, m) {
-  m.doc() = std::string("Python bindings for ") + ToolName + "\n\n" + ToolName +
-            " provides distributed representations for LLVM IR.\n\n"
-            "Example:\n"
-            "    >>> import py_ir2vec\n"
-            "    >>> tool = py_ir2vec." +
-            ToolName +
-            "Tool(\"test.ll\")\n"
-            "    >>> entities = tool.getEntityMappings()\n"
-            "    >>> triplets = tool.generateTriplets()";
+  m.doc() = R"pbdoc(
+  Python bindings for LLVM IR2Vec
+
+  IR2Vec provides distributed representations for LLVM IR.
+  Example:
+    >>> import py_ir2vec as M
+    >>> tool = M.initEmbedding(
+      filename="test.ll",
+      mode="sym",
+      vocab_override=vocab_path
+    )
+  )pbdoc";
 
   // Main tool class
   py::class_<PyIR2VecTool>(m, "IR2VecTool")
-      .def(py::init<const std::string &>(), py::arg("filename"),
-           "Initialize IR2Vec on an LLVM IR file\n"
-           "Args:\n"
-           "  filename: Path to LLVM IR file (.bc or .ll)")
-      .def("getEntityMappings", &PyIR2VecTool::getEntityMappings,
-           "Get entity mappings (vocabulary)\n"
-           "Returns: list[str] - list of entity names where index is entity_id")
-      .def("generateTriplets", &PyIR2VecTool::generateTriplets,
-           "Generate triplets for vocabulary training\n"
-           "Returns: dict with 'max_relation' and 'triplets' keys");
-}
\ No newline at end of file
+		.def(py::init<std::string, std::string, std::string>(),
+			py::arg("filename"), py::arg("mode"), py::arg("vocab_override"),
+			"Initialize IR2Vec on an LLVM IR file\n"
+			"Args:\n"
+			"  filename: Path to LLVM IR file (.bc or .ll)\n"
+			"  mode: 'sym' for symbolic (default) or 'fa' for flow-aware\n"
+			"  vocab_override: Path to vocabulary file")
+		.def("generateTriplets", &PyIR2VecTool::generateTriplets,
+			"Generate triplets for vocabulary training\n"
+			"Returns: TripletResult Dict with max_relation and list of triplets")
+    .def("getEntityMappings", &PyIR2VecTool::collectEntityMappings,
+			"Get entity mappings (vocabulary)\n"
+			"Returns: list[str] - list of entity names where index is entity_id");
+
+  m.def(
+		"initEmbedding",
+		[](std::string filename, std::string mode, std::string vocab_override) {
+			return std::make_unique<PyIR2VecTool>(filename, mode, vocab_override);
+		},
+		py::arg("filename"), py::arg("mode") = "sym",
+		py::arg("vocab_override"),
+		"Initialize IR2Vec on an LLVM IR file and return an IR2VecTool\n"
+		"Args:\n"
+		"  filename: Path to LLVM IR file (.bc or .ll)\n"
+		"  mode: 'sym' for symbolic (default) or 'fa' for flow-aware\n"
+		"  vocab_override: Path to vocabulary file\n"
+		"Returns: IR2VecTool instance");
+}

>From d6cad46dc84d6c85cc42b64c73df8bf90652b56d Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 24 Dec 2025 00:48:15 +0530
Subject: [PATCH 5/5] Nit commit - code formatting fixup

---
 llvm/include/llvm/Analysis/IR2Vec.h           |  4 +-
 llvm/lib/Analysis/IR2Vec.cpp                  | 13 +--
 .../llvm-ir2vec/bindings/ir2vec_bindings.cpp  | 90 +++++++++----------
 3 files changed, 54 insertions(+), 53 deletions(-)

diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 585dffe715be2..dcb9783bfa91c 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -615,8 +615,8 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
   using VocabMap = std::map<std::string, ir2vec::Embedding>;
   std::optional<ir2vec::VocabStorage> Vocab;
 
-  Error readVocabulary(StringRef EffectivePath, VocabMap &OpcVocab, VocabMap &TypeVocab,
-                       VocabMap &ArgVocab);
+  Error readVocabulary(StringRef EffectivePath, VocabMap &OpcVocab,
+                       VocabMap &TypeVocab, VocabMap &ArgVocab);
   void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
                             VocabMap &ArgVocab);
   void emitError(Error Err, LLVMContext &Ctx);
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index a47ced94ee676..16822a6828d8c 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -64,13 +64,15 @@ cl::opt<IR2VecKind> IR2VecEmbeddingKind(
 
 static std::optional<std::string> VocabOverride;
 void setIR2VecVocabPath(StringRef Path) {
-  if (Path.empty()) VocabOverride = std::nullopt;
-  else VocabOverride = Path.str();
+  if (Path.empty())
+    VocabOverride = std::nullopt;
+  else
+    VocabOverride = Path.str();
 }
 
 StringRef getIR2VecVocabPath() {
   return VocabOverride ? StringRef(*VocabOverride)
-                        : StringRef(VocabFile.getValue());
+                       : StringRef(VocabFile.getValue());
 }
 
 } // namespace ir2vec
@@ -497,7 +499,8 @@ Error IR2VecVocabAnalysis::readVocabulary(StringRef EffectivePath,
                                           VocabMap &OpcVocab,
                                           VocabMap &TypeVocab,
                                           VocabMap &ArgVocab) {
-  auto BufOrError = MemoryBuffer::getFileOrSTDIN(EffectivePath, /*IsText=*/true);
+  auto BufOrError =
+      MemoryBuffer::getFileOrSTDIN(EffectivePath, /*IsText=*/true);
   if (!BufOrError)
     return createFileError(EffectivePath.str(), BufOrError.getError());
 
@@ -628,7 +631,7 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
     return Vocabulary(std::move(Vocab.value()));
 
   StringRef EffectivePath =
-    VocabOverride ? StringRef(*VocabOverride) : VocabFile.getValue();
+      VocabOverride ? StringRef(*VocabOverride) : VocabFile.getValue();
 
   // Otherwise, try to read from the vocabulary file.
   if (EffectivePath.empty()) {
diff --git a/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp b/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
index b4658fc60a3ed..784c4f0fe37dd 100644
--- a/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
+++ b/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
@@ -13,9 +13,9 @@
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
-#include <pybind11/numpy.h>
 
 #include <fstream>
 #include <memory>
@@ -40,7 +40,7 @@ bool fileNotValid(const std::string &Filename) {
 }
 
 std::unique_ptr<Module> getLLVMIR(const std::string &Filename,
-                                   LLVMContext &Context) {
+                                  LLVMContext &Context) {
   SMDiagnostic Err;
   auto M = parseIRFile(Filename, Err, Context);
   if (!M) {
@@ -58,38 +58,36 @@ class PyIR2VecTool {
 
 public:
   PyIR2VecTool(std::string Filename, std::string Mode,
-								std::string VocabOverride) {
-		if (fileNotValid(Filename))
-			throw std::runtime_error("Invalid file path");
+               std::string VocabOverride) {
+    if (fileNotValid(Filename))
+      throw std::runtime_error("Invalid file path");
 
-		if (Mode != "sym" && Mode != "fa")
-			throw std::runtime_error("Invalid mode. Use 'sym' or 'fa'");
+    if (Mode != "sym" && Mode != "fa")
+      throw std::runtime_error("Invalid mode. Use 'sym' or 'fa'");
 
-		if (VocabOverride.empty())
-			throw std::runtime_error("Error - Empty Vocab Path not allowed");
+    if (VocabOverride.empty())
+      throw std::runtime_error("Error - Empty Vocab Path not allowed");
 
-		setIR2VecVocabPath(VocabOverride);
+    setIR2VecVocabPath(VocabOverride);
 
-		Ctx = std::make_unique<LLVMContext>();
-		M = getLLVMIR(Filename, *Ctx);
-		Tool = std::make_unique<IR2VecTool>(*M);
+    Ctx = std::make_unique<LLVMContext>();
+    M = getLLVMIR(Filename, *Ctx);
+    Tool = std::make_unique<IR2VecTool>(*M);
 
-		bool Ok = Tool->initializeVocabulary();
-		if (!Ok)
-			throw std::runtime_error("Failed to initialize IR2Vec vocabulary");
+    bool Ok = Tool->initializeVocabulary();
+    if (!Ok)
+      throw std::runtime_error("Failed to initialize IR2Vec vocabulary");
   }
 
   py::dict generateTriplets() {
     auto result = Tool->generateTriplets();
     py::list triplets_list;
-    for (const auto& t : result.Triplets) {
+    for (const auto &t : result.Triplets) {
       triplets_list.append(py::make_tuple(t.Head, t.Tail, t.Relation));
     }
 
-    return py::dict(
-      py::arg("max_relation") = result.MaxRelation,
-      py::arg("triplets") = triplets_list
-    );
+    return py::dict(py::arg("max_relation") = result.MaxRelation,
+                    py::arg("triplets") = triplets_list);
   }
 
   EntityList collectEntityMappings() {
@@ -115,31 +113,31 @@ PYBIND11_MODULE(py_ir2vec, m) {
 
   // Main tool class
   py::class_<PyIR2VecTool>(m, "IR2VecTool")
-		.def(py::init<std::string, std::string, std::string>(),
-			py::arg("filename"), py::arg("mode"), py::arg("vocab_override"),
-			"Initialize IR2Vec on an LLVM IR file\n"
-			"Args:\n"
-			"  filename: Path to LLVM IR file (.bc or .ll)\n"
-			"  mode: 'sym' for symbolic (default) or 'fa' for flow-aware\n"
-			"  vocab_override: Path to vocabulary file")
-		.def("generateTriplets", &PyIR2VecTool::generateTriplets,
-			"Generate triplets for vocabulary training\n"
-			"Returns: TripletResult Dict with max_relation and list of triplets")
-    .def("getEntityMappings", &PyIR2VecTool::collectEntityMappings,
-			"Get entity mappings (vocabulary)\n"
-			"Returns: list[str] - list of entity names where index is entity_id");
+      .def(py::init<std::string, std::string, std::string>(),
+           py::arg("filename"), py::arg("mode"), py::arg("vocab_override"),
+           "Initialize IR2Vec on an LLVM IR file\n"
+           "Args:\n"
+           "  filename: Path to LLVM IR file (.bc or .ll)\n"
+           "  mode: 'sym' for symbolic (default) or 'fa' for flow-aware\n"
+           "  vocab_override: Path to vocabulary file")
+      .def("generateTriplets", &PyIR2VecTool::generateTriplets,
+           "Generate triplets for vocabulary training\n"
+           "Returns: TripletResult Dict with max_relation and list of triplets")
+      .def(
+          "getEntityMappings", &PyIR2VecTool::collectEntityMappings,
+          "Get entity mappings (vocabulary)\n"
+          "Returns: list[str] - list of entity names where index is entity_id");
 
   m.def(
-		"initEmbedding",
-		[](std::string filename, std::string mode, std::string vocab_override) {
-			return std::make_unique<PyIR2VecTool>(filename, mode, vocab_override);
-		},
-		py::arg("filename"), py::arg("mode") = "sym",
-		py::arg("vocab_override"),
-		"Initialize IR2Vec on an LLVM IR file and return an IR2VecTool\n"
-		"Args:\n"
-		"  filename: Path to LLVM IR file (.bc or .ll)\n"
-		"  mode: 'sym' for symbolic (default) or 'fa' for flow-aware\n"
-		"  vocab_override: Path to vocabulary file\n"
-		"Returns: IR2VecTool instance");
+      "initEmbedding",
+      [](std::string filename, std::string mode, std::string vocab_override) {
+        return std::make_unique<PyIR2VecTool>(filename, mode, vocab_override);
+      },
+      py::arg("filename"), py::arg("mode") = "sym", py::arg("vocab_override"),
+      "Initialize IR2Vec on an LLVM IR file and return an IR2VecTool\n"
+      "Args:\n"
+      "  filename: Path to LLVM IR file (.bc or .ll)\n"
+      "  mode: 'sym' for symbolic (default) or 'fa' for flow-aware\n"
+      "  vocab_override: Path to vocabulary file\n"
+      "Returns: IR2VecTool instance");
 }



More information about the llvm-commits mailing list