[llvm] [IR2Vec] Introducing python bindings for IR2Vec (PR #173194)
Nishant Sachdeva via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 23 11:18:32 PST 2025
https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/173194
>From 958aafd1475e75351b38cd850ee5441b431c427d Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Sun, 21 Dec 2025 01:25:33 +0530
Subject: [PATCH 1/5] Work Commit - Separating all tool implementation from cli
file
---
llvm/tools/llvm-ir2vec/CMakeLists.txt | 1 +
llvm/tools/llvm-ir2vec/emb-tool.cpp | 421 ++++++++++++++++++
.../llvm-ir2vec/{llvm-ir2vec.h => emb-tool.h} | 8 +-
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 372 +---------------
4 files changed, 427 insertions(+), 375 deletions(-)
create mode 100644 llvm/tools/llvm-ir2vec/emb-tool.cpp
rename llvm/tools/llvm-ir2vec/{llvm-ir2vec.h => emb-tool.h} (98%)
diff --git a/llvm/tools/llvm-ir2vec/CMakeLists.txt b/llvm/tools/llvm-ir2vec/CMakeLists.txt
index 2bb6686392907..9d5db8663fb38 100644
--- a/llvm/tools/llvm-ir2vec/CMakeLists.txt
+++ b/llvm/tools/llvm-ir2vec/CMakeLists.txt
@@ -19,6 +19,7 @@ set(LLVM_LINK_COMPONENTS
add_llvm_tool(llvm-ir2vec
llvm-ir2vec.cpp
+ emb-tool.cpp
DEPENDS
intrinsics_gen
diff --git a/llvm/tools/llvm-ir2vec/emb-tool.cpp b/llvm/tools/llvm-ir2vec/emb-tool.cpp
new file mode 100644
index 0000000000000..891b26f8ef763
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/emb-tool.cpp
@@ -0,0 +1,421 @@
+//===- emb-tool.cpp - IR2Vec/MIR2Vec Embedding Generation Tool -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file implements the IR2VecTool and MIR2VecTool classes for
+/// IR2Vec/MIR2Vec embedding generation.
+///
+//===----------------------------------------------------------------------===//
+
+#include "emb-tool.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Analysis/IR2Vec.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "llvm/CodeGen/MIR2Vec.h"
+#include "llvm/CodeGen/MIRParser/MIRParser.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/Support/WithColor.h"
+#include "llvm/Target/TargetMachine.h"
+
+#define DEBUG_TYPE "ir2vec"
+
+namespace llvm {
+
+namespace ir2vec {
+
+bool IR2VecTool::initializeVocabulary() {
+ // Register and run the IR2Vec vocabulary analysis
+ // The vocabulary file path is specified via --ir2vec-vocab-path global
+ // option
+ MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+ MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
+ // This will throw an error if vocab is not found or invalid
+ Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
+ return Vocab->isValid();
+}
+
+TripletResult IR2VecTool::generateTriplets(const Function &F) const {
+ if (F.isDeclaration())
+ return {};
+
+ TripletResult Result;
+ Result.MaxRelation = 0;
+
+ unsigned MaxRelation = NextRelation;
+ unsigned PrevOpcode = 0;
+ bool HasPrevOpcode = false;
+
+ for (const BasicBlock &BB : F) {
+ for (const auto &I : BB.instructionsWithoutDebug()) {
+ unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
+ unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
+
+ // Add "Next" relationship with previous instruction
+ if (HasPrevOpcode) {
+ Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
+ LLVM_DEBUG(dbgs() << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1)
+ << '\t'
+ << Vocabulary::getVocabKeyForOpcode(Opcode + 1)
+ << '\t' << "Next\n");
+ }
+
+ // Add "Type" relationship
+ Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
+ LLVM_DEBUG(
+ dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+ << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
+ << '\t' << "Type\n");
+
+ // Add "Arg" relationships
+ unsigned ArgIndex = 0;
+ for (const Use &U : I.operands()) {
+ unsigned OperandID = Vocabulary::getIndex(*U.get());
+ unsigned RelationID = ArgRelation + ArgIndex;
+ Result.Triplets.push_back({Opcode, OperandID, RelationID});
+
+ LLVM_DEBUG({
+ StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
+ Vocabulary::getOperandKind(U.get()));
+ dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+ << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
+ });
+
+ ++ArgIndex;
+ }
+ // Only update MaxRelation if there were operands
+ if (ArgIndex > 0)
+ MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
+ PrevOpcode = Opcode;
+ HasPrevOpcode = true;
+ }
+ }
+
+ Result.MaxRelation = MaxRelation;
+ return Result;
+}
+
+TripletResult IR2VecTool::generateTriplets() const {
+ TripletResult Result;
+ Result.MaxRelation = NextRelation;
+
+ for (const Function &F : M.getFunctionDefs()) {
+ TripletResult FuncResult = generateTriplets(F);
+ Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+ Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+ FuncResult.Triplets.end());
+ }
+
+ return Result;
+}
+
+void IR2VecTool::writeTripletsToStream(raw_ostream &OS) const {
+ auto Result = generateTriplets();
+ OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+ for (const auto &T : Result.Triplets)
+ OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+}
+
+EntityList IR2VecTool::collectEntityMappings() {
+ auto EntityLen = Vocabulary::getCanonicalSize();
+ EntityList Result;
+ for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
+ Result.push_back(Vocabulary::getStringKey(EntityID).str());
+ return Result;
+}
+
+void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
+ auto Entities = collectEntityMappings();
+ OS << Entities.size() << "\n";
+ for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+ OS << Entities[EntityID] << '\t' << EntityID << '\n';
+}
+
+void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
+ EmbeddingLevel Level) const {
+ // Vocab is set by initializeVocabulary(); reject a null pointer before use.
+ if (!Vocab || !Vocab->isValid()) {
+ WithColor::error(errs(), ToolName)
+ << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+ return;
+ }
+ for (const Function &F : M.getFunctionDefs())
+ writeEmbeddingsToStream(F, OS, Level);
+}
+
+void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
+ EmbeddingLevel Level) const {
+ if (!Vocab || !Vocab->isValid()) {
+ WithColor::error(errs(), ToolName)
+ << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+ return;
+ }
+ if (F.isDeclaration()) {
+ OS << "Function " << F.getName() << " is a declaration, skipping.\n";
+ return;
+ }
+
+ // Create embedder for this function
+ auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+ if (!Emb) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to create embedder for function " << F.getName() << "\n";
+ return;
+ }
+
+ OS << "Function: " << F.getName() << "\n";
+
+ // Generate embeddings based on the specified level
+ switch (Level) {
+ case FunctionLevel:
+ Emb->getFunctionVector().print(OS);
+ break;
+ case BasicBlockLevel:
+ for (const BasicBlock &BB : F) {
+ OS << BB.getName() << ":";
+ Emb->getBBVector(BB).print(OS);
+ }
+ break;
+ case InstructionLevel:
+ for (const Instruction &I : instructions(F)) {
+ OS << I;
+ Emb->getInstVector(I).print(OS);
+ }
+ break;
+ }
+}
+
+} // namespace ir2vec
+
+namespace mir2vec {
+
+bool MIR2VecTool::initializeVocabulary(const Module &M) {
+ MIR2VecVocabProvider Provider(MMI);
+ auto VocabOrErr = Provider.getVocabulary(M);
+ if (!VocabOrErr) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to load MIR2Vec vocabulary - "
+ << toString(VocabOrErr.takeError()) << "\n";
+ return false;
+ }
+ Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
+ return true;
+}
+
+bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
+ for (const Function &F : M.getFunctionDefs()) {
+ MachineFunction *MF = MMI.getMachineFunction(F);
+ if (!MF)
+ continue;
+
+ const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
+ const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
+ const MachineRegisterInfo &MRI = MF->getRegInfo();
+
+ auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
+ if (!VocabOrErr) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to create dummy vocabulary - "
+ << toString(VocabOrErr.takeError()) << "\n";
+ return false;
+ }
+ Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
+ return true;
+ }
+
+ WithColor::error(errs(), ToolName)
+ << "No machine functions found to initialize vocabulary\n";
+ return false;
+}
+
+TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
+ TripletResult Result;
+ Result.MaxRelation = MIRNextRelation;
+
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName)
+ << "MIR Vocabulary must be initialized for triplet generation.\n";
+ return Result;
+ }
+
+ unsigned PrevOpcode = 0;
+ bool HasPrevOpcode = false;
+ for (const MachineBasicBlock &MBB : MF) {
+ for (const MachineInstr &MI : MBB) {
+ // Skip debug instructions
+ if (MI.isDebugInstr())
+ continue;
+
+ // Get opcode entity ID
+ unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
+
+ // Add "Next" relationship with previous instruction
+ if (HasPrevOpcode) {
+ Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
+ LLVM_DEBUG(dbgs() << Vocab->getStringKey(PrevOpcode) << '\t'
+ << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
+ }
+
+ // Add "Arg" relationships for operands
+ unsigned ArgIndex = 0;
+ for (const MachineOperand &MO : MI.operands()) {
+ auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
+ unsigned RelationID = MIRArgRelation + ArgIndex;
+ Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
+ LLVM_DEBUG({
+ std::string OperandStr = Vocab->getStringKey(OperandID);
+ dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr << '\t'
+ << "Arg" << ArgIndex << '\n';
+ });
+
+ ++ArgIndex;
+ }
+
+ // Update MaxRelation if there were operands
+ if (ArgIndex > 0)
+ Result.MaxRelation =
+ std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
+
+ PrevOpcode = OpcodeID;
+ HasPrevOpcode = true;
+ }
+ }
+
+ return Result;
+}
+
+TripletResult MIR2VecTool::generateTriplets(const Module &M) const {
+ TripletResult Result;
+ Result.MaxRelation = MIRNextRelation;
+
+ for (const Function &F : M.getFunctionDefs()) {
+ MachineFunction *MF = MMI.getMachineFunction(F);
+ if (!MF) {
+ WithColor::warning(errs(), ToolName)
+ << "No MachineFunction for " << F.getName() << "\n";
+ continue;
+ }
+
+ TripletResult FuncResult = generateTriplets(*MF);
+ Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+ Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+ FuncResult.Triplets.end());
+ }
+
+ return Result;
+}
+
+void MIR2VecTool::writeTripletsToStream(const Module &M,
+ raw_ostream &OS) const {
+ auto Result = generateTriplets(M);
+ OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+ for (const auto &T : Result.Triplets)
+ OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+}
+
+EntityList MIR2VecTool::collectEntityMappings() const {
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName)
+ << "Vocabulary must be initialized for entity mappings.\n";
+ return {};
+ }
+
+ const unsigned EntityCount = Vocab->getCanonicalSize();
+ EntityList Result;
+ for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
+ Result.push_back(Vocab->getStringKey(EntityID));
+
+ return Result;
+}
+
+void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
+ auto Entities = collectEntityMappings();
+ if (Entities.empty())
+ return;
+
+ OS << Entities.size() << "\n";
+ for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+ OS << Entities[EntityID] << '\t' << EntityID << '\n';
+}
+
+void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
+ EmbeddingLevel Level) const {
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+ return;
+ }
+
+ for (const Function &F : M.getFunctionDefs()) {
+ MachineFunction *MF = MMI.getMachineFunction(F);
+ if (!MF) {
+ WithColor::warning(errs(), ToolName)
+ << "No MachineFunction for " << F.getName() << "\n";
+ continue;
+ }
+
+ writeEmbeddingsToStream(*MF, OS, Level);
+ }
+}
+
+void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
+ EmbeddingLevel Level) const {
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+ return;
+ }
+
+ auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+ if (!Emb) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to create embedder for " << MF.getName() << "\n";
+ return;
+ }
+
+ OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+
+ // Generate embeddings based on the specified level
+ switch (Level) {
+ case FunctionLevel:
+ OS << "Function vector: ";
+ Emb->getMFunctionVector().print(OS);
+ break;
+ case BasicBlockLevel:
+ OS << "Basic block vectors:\n";
+ for (const MachineBasicBlock &MBB : MF) {
+ OS << "MBB " << MBB.getName() << ": ";
+ Emb->getMBBVector(MBB).print(OS);
+ }
+ break;
+ case InstructionLevel:
+ OS << "Instruction vectors:\n";
+ for (const MachineBasicBlock &MBB : MF) {
+ for (const MachineInstr &MI : MBB) {
+ OS << MI << " -> ";
+ Emb->getMInstVector(MI).print(OS);
+ }
+ }
+ break;
+ }
+}
+
+} // namespace mir2vec
+
+} // namespace llvm
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h b/llvm/tools/llvm-ir2vec/emb-tool.h
similarity index 98%
rename from llvm/tools/llvm-ir2vec/llvm-ir2vec.h
rename to llvm/tools/llvm-ir2vec/emb-tool.h
index 566c362edbd22..009bcec60108b 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
+++ b/llvm/tools/llvm-ir2vec/emb-tool.h
@@ -12,8 +12,8 @@
///
//===----------------------------------------------------------------------===//
-#ifndef LLVM_TOOLS_LLVM_IR2VEC_LLVM_IR2VEC_H
-#define LLVM_TOOLS_LLVM_IR2VEC_LLVM_IR2VEC_H
+#ifndef LLVM_TOOLS_LLVM_IR2VEC_EMB_TOOL_H
+#define LLVM_TOOLS_LLVM_IR2VEC_EMB_TOOL_H
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/IR2Vec.h"
@@ -44,7 +44,7 @@
#define DEBUG_TYPE "ir2vec"
namespace llvm {
-
+
/// Tool name for error reporting
static const char *ToolName = "llvm-ir2vec";
@@ -198,4 +198,4 @@ struct MIRContext {
} // namespace llvm
-#endif // LLVM_TOOLS_LLVM_IR2VEC_LLVM_IR2VEC_H
\ No newline at end of file
+#endif // LLVM_TOOLS_LLVM_IR2VEC_EMB_TOOL_H
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 6b70e09518fa7..a2b2f4e6a7aa8 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -54,7 +54,7 @@
///
//===----------------------------------------------------------------------===//
-#include "llvm-ir2vec.h"
+#include "emb-tool.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/IR/BasicBlock.h"
@@ -147,167 +147,6 @@ static cl::opt<EmbeddingLevel>
namespace ir2vec {
-bool IR2VecTool::initializeVocabulary() {
- // Register and run the IR2Vec vocabulary analysis
- // The vocabulary file path is specified via --ir2vec-vocab-path global
- // option
- MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
- MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
- // This will throw an error if vocab is not found or invalid
- Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
- return Vocab->isValid();
-}
-
-TripletResult IR2VecTool::generateTriplets(const Function &F) const {
- if (F.isDeclaration())
- return {};
-
- TripletResult Result;
- Result.MaxRelation = 0;
-
- unsigned MaxRelation = NextRelation;
- unsigned PrevOpcode = 0;
- bool HasPrevOpcode = false;
-
- for (const BasicBlock &BB : F) {
- for (const auto &I : BB.instructionsWithoutDebug()) {
- unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
- unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
-
- // Add "Next" relationship with previous instruction
- if (HasPrevOpcode) {
- Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
- LLVM_DEBUG(dbgs() << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1)
- << '\t'
- << Vocabulary::getVocabKeyForOpcode(Opcode + 1)
- << '\t' << "Next\n");
- }
-
- // Add "Type" relationship
- Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
- LLVM_DEBUG(
- dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
- << '\t' << "Type\n");
-
- // Add "Arg" relationships
- unsigned ArgIndex = 0;
- for (const Use &U : I.operands()) {
- unsigned OperandID = Vocabulary::getIndex(*U.get());
- unsigned RelationID = ArgRelation + ArgIndex;
- Result.Triplets.push_back({Opcode, OperandID, RelationID});
-
- LLVM_DEBUG({
- StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
- Vocabulary::getOperandKind(U.get()));
- dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
- });
-
- ++ArgIndex;
- }
- // Only update MaxRelation if there were operands
- if (ArgIndex > 0)
- MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
- PrevOpcode = Opcode;
- HasPrevOpcode = true;
- }
- }
-
- Result.MaxRelation = MaxRelation;
- return Result;
-}
-
-TripletResult IR2VecTool::generateTriplets() const {
- TripletResult Result;
- Result.MaxRelation = NextRelation;
-
- for (const Function &F : M.getFunctionDefs()) {
- TripletResult FuncResult = generateTriplets(F);
- Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
- Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
- FuncResult.Triplets.end());
- }
-
- return Result;
-}
-
-void IR2VecTool::writeTripletsToStream(raw_ostream &OS) const {
- auto Result = generateTriplets();
- OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
- for (const auto &T : Result.Triplets)
- OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
-}
-
-EntityList IR2VecTool::collectEntityMappings() {
- auto EntityLen = Vocabulary::getCanonicalSize();
- EntityList Result;
- for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
- Result.push_back(Vocabulary::getStringKey(EntityID).str());
- return Result;
-}
-
-void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
- auto Entities = collectEntityMappings();
- OS << Entities.size() << "\n";
- for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
- OS << Entities[EntityID] << '\t' << EntityID << '\n';
-}
-
-void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
- EmbeddingLevel Level) const {
- if (!Vocab->isValid()) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary is not valid. IR2VecTool not initialized.\n";
- return;
- }
-
- for (const Function &F : M.getFunctionDefs())
- writeEmbeddingsToStream(F, OS, Level);
-}
-
-void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
- EmbeddingLevel Level) const {
- if (!Vocab || !Vocab->isValid()) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary is not valid. IR2VecTool not initialized.\n";
- return;
- }
- if (F.isDeclaration()) {
- OS << "Function " << F.getName() << " is a declaration, skipping.\n";
- return;
- }
-
- // Create embedder for this function
- auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for function " << F.getName() << "\n";
- return;
- }
-
- OS << "Function: " << F.getName() << "\n";
-
- // Generate embeddings based on the specified level
- switch (Level) {
- case FunctionLevel:
- Emb->getFunctionVector().print(OS);
- break;
- case BasicBlockLevel:
- for (const BasicBlock &BB : F) {
- OS << BB.getName() << ":";
- Emb->getBBVector(BB).print(OS);
- }
- break;
- case InstructionLevel:
- for (const Instruction &I : instructions(F)) {
- OS << I;
- Emb->getInstVector(I).print(OS);
- }
- break;
- }
-}
-
/// Process the module and generate output based on selected subcommand
Error processModule(Module &M, raw_ostream &OS) {
IR2VecTool Tool(M);
@@ -341,215 +180,6 @@ Error processModule(Module &M, raw_ostream &OS) {
namespace mir2vec {
-bool MIR2VecTool::initializeVocabulary(const Module &M) {
- MIR2VecVocabProvider Provider(MMI);
- auto VocabOrErr = Provider.getVocabulary(M);
- if (!VocabOrErr) {
- WithColor::error(errs(), ToolName)
- << "Failed to load MIR2Vec vocabulary - "
- << toString(VocabOrErr.takeError()) << "\n";
- return false;
- }
- Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
- return true;
-}
-
-bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
- for (const Function &F : M.getFunctionDefs()) {
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF)
- continue;
-
- const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
- const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
- const MachineRegisterInfo &MRI = MF->getRegInfo();
-
- auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
- if (!VocabOrErr) {
- WithColor::error(errs(), ToolName)
- << "Failed to create dummy vocabulary - "
- << toString(VocabOrErr.takeError()) << "\n";
- return false;
- }
- Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
- return true;
- }
-
- WithColor::error(errs(), ToolName)
- << "No machine functions found to initialize vocabulary\n";
- return false;
-}
-
-TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
- TripletResult Result;
- Result.MaxRelation = MIRNextRelation;
-
- if (!Vocab) {
- WithColor::error(errs(), ToolName)
- << "MIR Vocabulary must be initialized for triplet generation.\n";
- return Result;
- }
-
- unsigned PrevOpcode = 0;
- bool HasPrevOpcode = false;
- for (const MachineBasicBlock &MBB : MF) {
- for (const MachineInstr &MI : MBB) {
- // Skip debug instructions
- if (MI.isDebugInstr())
- continue;
-
- // Get opcode entity ID
- unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
-
- // Add "Next" relationship with previous instruction
- if (HasPrevOpcode) {
- Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
- LLVM_DEBUG(dbgs() << Vocab->getStringKey(PrevOpcode) << '\t'
- << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
- }
-
- // Add "Arg" relationships for operands
- unsigned ArgIndex = 0;
- for (const MachineOperand &MO : MI.operands()) {
- auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
- unsigned RelationID = MIRArgRelation + ArgIndex;
- Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
- LLVM_DEBUG({
- std::string OperandStr = Vocab->getStringKey(OperandID);
- dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr << '\t'
- << "Arg" << ArgIndex << '\n';
- });
-
- ++ArgIndex;
- }
-
- // Update MaxRelation if there were operands
- if (ArgIndex > 0)
- Result.MaxRelation =
- std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
-
- PrevOpcode = OpcodeID;
- HasPrevOpcode = true;
- }
- }
-
- return Result;
-}
-
-TripletResult MIR2VecTool::generateTriplets(const Module &M) const {
- TripletResult Result;
- Result.MaxRelation = MIRNextRelation;
-
- for (const Function &F : M.getFunctionDefs()) {
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF) {
- WithColor::warning(errs(), ToolName)
- << "No MachineFunction for " << F.getName() << "\n";
- continue;
- }
-
- TripletResult FuncResult = generateTriplets(*MF);
- Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
- Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
- FuncResult.Triplets.end());
- }
-
- return Result;
-}
-
-void MIR2VecTool::writeTripletsToStream(const Module &M,
- raw_ostream &OS) const {
- auto Result = generateTriplets(M);
- OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
- for (const auto &T : Result.Triplets)
- OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
-}
-
-EntityList MIR2VecTool::collectEntityMappings() const {
- if (!Vocab) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary must be initialized for entity mappings.\n";
- return {};
- }
-
- const unsigned EntityCount = Vocab->getCanonicalSize();
- EntityList Result;
- for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
- Result.push_back(Vocab->getStringKey(EntityID));
-
- return Result;
-}
-
-void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
- auto Entities = collectEntityMappings();
- if (Entities.empty())
- return;
-
- OS << Entities.size() << "\n";
- for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
- OS << Entities[EntityID] << '\t' << EntityID << '\n';
-}
-
-void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
- EmbeddingLevel Level) const {
- if (!Vocab) {
- WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return;
- }
-
- for (const Function &F : M.getFunctionDefs()) {
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF) {
- WithColor::warning(errs(), ToolName)
- << "No MachineFunction for " << F.getName() << "\n";
- continue;
- }
-
- writeEmbeddingsToStream(*MF, OS, Level);
- }
-}
-
-void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
- EmbeddingLevel Level) const {
- if (!Vocab) {
- WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return;
- }
-
- auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for " << MF.getName() << "\n";
- return;
- }
-
- OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
-
- // Generate embeddings based on the specified level
- switch (Level) {
- case FunctionLevel:
- OS << "Function vector: ";
- Emb->getMFunctionVector().print(OS);
- break;
- case BasicBlockLevel:
- OS << "Basic block vectors:\n";
- for (const MachineBasicBlock &MBB : MF) {
- OS << "MBB " << MBB.getName() << ": ";
- Emb->getMBBVector(MBB).print(OS);
- }
- break;
- case InstructionLevel:
- OS << "Instruction vectors:\n";
- for (const MachineBasicBlock &MBB : MF) {
- for (const MachineInstr &MI : MBB) {
- OS << MI << " -> ";
- Emb->getMInstVector(MI).print(OS);
- }
- }
- break;
- }
-}
-
/// Setup MIR context from input file
Error setupMIRContext(const std::string &InputFile, MIRContext &Ctx) {
SMDiagnostic Err;
>From ac7cd83a0400a819716fb70d54d3ae399fa206dc Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Sun, 21 Dec 2025 22:32:35 +0530
Subject: [PATCH 2/5] nit commit - code formatting fixup
---
llvm/tools/llvm-ir2vec/emb-tool.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/tools/llvm-ir2vec/emb-tool.h b/llvm/tools/llvm-ir2vec/emb-tool.h
index 009bcec60108b..24b8022ed7482 100644
--- a/llvm/tools/llvm-ir2vec/emb-tool.h
+++ b/llvm/tools/llvm-ir2vec/emb-tool.h
@@ -44,7 +44,7 @@
#define DEBUG_TYPE "ir2vec"
namespace llvm {
-
+
/// Tool name for error reporting
static const char *ToolName = "llvm-ir2vec";
>From 959ab6047c0405d57073000b4c814e0d1840d10b Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Tue, 23 Dec 2025 16:00:00 +0530
Subject: [PATCH 3/5] Work Commit. Creating a lib/ structure, and using it for
the first draft of the python bindings
---
llvm/tools/llvm-ir2vec/CMakeLists.txt | 29 ++---
.../tools/llvm-ir2vec/bindings/CMakeLists.txt | 14 +++
.../llvm-ir2vec/bindings/ir2vec_bindings.cpp | 107 ++++++++++++++++++
llvm/tools/llvm-ir2vec/lib/CMakeLists.txt | 28 +++++
llvm/tools/llvm-ir2vec/{ => lib}/emb-tool.cpp | 0
llvm/tools/llvm-ir2vec/{ => lib}/emb-tool.h | 0
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 2 +-
7 files changed, 159 insertions(+), 21 deletions(-)
create mode 100644 llvm/tools/llvm-ir2vec/bindings/CMakeLists.txt
create mode 100644 llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
create mode 100644 llvm/tools/llvm-ir2vec/lib/CMakeLists.txt
rename llvm/tools/llvm-ir2vec/{ => lib}/emb-tool.cpp (100%)
rename llvm/tools/llvm-ir2vec/{ => lib}/emb-tool.h (100%)
diff --git a/llvm/tools/llvm-ir2vec/CMakeLists.txt b/llvm/tools/llvm-ir2vec/CMakeLists.txt
index 9d5db8663fb38..d5956cde6892d 100644
--- a/llvm/tools/llvm-ir2vec/CMakeLists.txt
+++ b/llvm/tools/llvm-ir2vec/CMakeLists.txt
@@ -1,26 +1,15 @@
-set(LLVM_LINK_COMPONENTS
- # Core LLVM components for IR processing
- Analysis
- Core
- IRReader
- Support
-
- # Machine IR components (for -mode=mir)
- CodeGen
- MIRParser
-
- # Target initialization (required for MIR parsing)
- AllTargetsAsmParsers
- AllTargetsCodeGens
- AllTargetsDescs
- AllTargetsInfos
- TargetParser
- )
+# Build the library first
+add_subdirectory(lib)
+# Build the llvm-ir2vec executable
add_llvm_tool(llvm-ir2vec
llvm-ir2vec.cpp
- emb-tool.cpp
DEPENDS
intrinsics_gen
- )
+)
+
+# Link the executable against the library
+target_link_libraries(llvm-ir2vec PRIVATE EmbTool)
+
+add_subdirectory(bindings)
diff --git a/llvm/tools/llvm-ir2vec/bindings/CMakeLists.txt b/llvm/tools/llvm-ir2vec/bindings/CMakeLists.txt
new file mode 100644
index 0000000000000..86f4ffb35b3ca
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/bindings/CMakeLists.txt
@@ -0,0 +1,14 @@
+find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module)
+find_package(pybind11 CONFIG QUIET)
+
+if(NOT pybind11_FOUND)
+ message(WARNING "pybind11 not found - skipping py_ir2vec")
+ return()
+endif()
+
+pybind11_add_module(py_ir2vec
+ ir2vec_bindings.cpp
+)
+
+# Link against the EmbTool library
+target_link_libraries(py_ir2vec PRIVATE EmbTool)
diff --git a/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp b/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
new file mode 100644
index 0000000000000..d9d4b9f07e637
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
@@ -0,0 +1,107 @@
+//===- ir2vec_bindings.cpp - Python Bindings for IR2Vec ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../lib/emb-tool.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include <fstream>
+#include <memory>
+#include <string>
+
+namespace py = pybind11;
+using namespace llvm;
+using namespace llvm::ir2vec;
+
+namespace {
+
+bool fileNotValid(const std::string &Filename) {
+ std::ifstream F(Filename, std::ios_base::in | std::ios_base::binary);
+ return !F.good();
+}
+
+std::unique_ptr<Module> getLLVMIR(const std::string &Filename,
+ LLVMContext &Context) {
+ SMDiagnostic Err;
+ auto M = parseIRFile(Filename, Err, Context);
+ if (!M) {
+ Err.print(ToolName, errs()); // diagnostics go to stderr, not stdout
+ throw std::runtime_error("Failed to parse IR file.");
+ }
+ return M;
+}
+
+class PyIR2VecTool {
+private:
+ std::unique_ptr<LLVMContext> Ctx;
+ std::unique_ptr<Module> M;
+ std::unique_ptr<IR2VecTool> Tool;
+
+public:
+ PyIR2VecTool(const std::string &Filename) {
+ if (fileNotValid(Filename))
+ throw std::runtime_error("Invalid file path");
+
+ Ctx = std::make_unique<LLVMContext>();
+ M = getLLVMIR(Filename, *Ctx);
+ Tool = std::make_unique<IR2VecTool>(*M);
+ }
+
+ // Get entity mappings (vocabulary)
+ std::vector<std::string> getEntityMappings() {
+ return IR2VecTool::collectEntityMappings();
+ }
+
+ // Generate triplets for vocabulary training
+ py::dict generateTriplets() {
+ auto result = Tool->generateTriplets();
+ py::list triplets_list;
+
+ for (const auto &t : result.Triplets) {
+ triplets_list.append(py::make_tuple(t.Head, t.Tail, t.Relation));
+ }
+
+ py::dict output;
+ output["max_relation"] = result.MaxRelation;
+ output["triplets"] = triplets_list;
+ return output;
+ }
+};
+
+} // anonymous namespace
+
+PYBIND11_MODULE(py_ir2vec, m) {
+ // The example must name the class exactly as it is registered below
+ // ("IR2VecTool"); ToolName ("llvm-ir2vec") is not a valid Python identifier.
+ m.doc() = std::string("Python bindings for ") + ToolName + "\n\n" + ToolName +
+ " provides distributed representations for LLVM IR.\n\n"
+ "Example:\n"
+ " >>> import py_ir2vec\n"
+ " >>> tool = py_ir2vec.IR2VecTool(\"test.ll\")\n"
+ " >>> entities = tool.getEntityMappings()\n"
+ " >>> triplets = tool.generateTriplets()";
+
+ // Main tool class
+ py::class_<PyIR2VecTool>(m, "IR2VecTool")
+ .def(py::init<const std::string &>(), py::arg("filename"),
+ "Initialize IR2Vec on an LLVM IR file\n"
+ "Args:\n"
+ " filename: Path to LLVM IR file (.bc or .ll)")
+ .def("getEntityMappings", &PyIR2VecTool::getEntityMappings,
+ "Get entity mappings (vocabulary)\n"
+ "Returns: list[str] - list of entity names where index is entity_id")
+ .def("generateTriplets", &PyIR2VecTool::generateTriplets,
+ "Generate triplets for vocabulary training\n"
+ "Returns: dict with 'max_relation' and 'triplets' keys");
+}
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/lib/CMakeLists.txt b/llvm/tools/llvm-ir2vec/lib/CMakeLists.txt
new file mode 100644
index 0000000000000..f15761beb97df
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/lib/CMakeLists.txt
@@ -0,0 +1,28 @@
+set(LLVM_LINK_COMPONENTS # Component list handed to add_llvm_library() below via LINK_COMPONENTS.
+ # Core LLVM components for IR processing
+ Analysis
+ Core
+ IRReader
+ Support
+
+ # Machine IR components (for -mode=mir)
+ CodeGen
+ MIRParser
+
+ # Target initialization (required for MIR parsing)
+ AllTargetsAsmParsers
+ AllTargetsCodeGens
+ AllTargetsDescs
+ AllTargetsInfos
+ TargetParser
+)
+
+add_llvm_library(EmbTool # Library target built from the single source emb-tool.cpp.
+ emb-tool.cpp
+
+ LINK_COMPONENTS
+ ${LLVM_LINK_COMPONENTS}
+
+ DEPENDS
+ intrinsics_gen
+)
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/emb-tool.cpp b/llvm/tools/llvm-ir2vec/lib/emb-tool.cpp
similarity index 100%
rename from llvm/tools/llvm-ir2vec/emb-tool.cpp
rename to llvm/tools/llvm-ir2vec/lib/emb-tool.cpp
diff --git a/llvm/tools/llvm-ir2vec/emb-tool.h b/llvm/tools/llvm-ir2vec/lib/emb-tool.h
similarity index 100%
rename from llvm/tools/llvm-ir2vec/emb-tool.h
rename to llvm/tools/llvm-ir2vec/lib/emb-tool.h
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index a2b2f4e6a7aa8..541585d881638 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -54,7 +54,7 @@
///
//===----------------------------------------------------------------------===//
-#include "emb-tool.h"
+#include "lib/emb-tool.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/IR/BasicBlock.h"
>From 7e0cc13677dd52f6843f6d586fde92704366fc2a Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 24 Dec 2025 00:38:48 +0530
Subject: [PATCH 4/5] Work commit - python bindings reworked to include a
compulsory vocab override input. Added getTriplets function as well
---
llvm/include/llvm/Analysis/IR2Vec.h | 2 +-
llvm/lib/Analysis/IR2Vec.cpp | 27 +++-
.../llvm-ir2vec/bindings/ir2vec_bindings.cpp | 116 ++++++++++++------
3 files changed, 99 insertions(+), 46 deletions(-)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 5957a3743f22e..585dffe715be2 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -615,7 +615,7 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
using VocabMap = std::map<std::string, ir2vec::Embedding>;
std::optional<ir2vec::VocabStorage> Vocab;
- Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab,
+ Error readVocabulary(StringRef EffectivePath, VocabMap &OpcVocab, VocabMap &TypeVocab,
VocabMap &ArgVocab);
void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
VocabMap &ArgVocab);
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 85b5372c961c1..a47ced94ee676 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -62,6 +62,17 @@ cl::opt<IR2VecKind> IR2VecEmbeddingKind(
cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
cl::cat(IR2VecCategory));
+static std::optional<std::string> VocabOverride; // File-local override for the vocabulary path; nullopt means no override is set.
+void setIR2VecVocabPath(StringRef Path) { // Installs Path as the override; an empty Path clears it instead.
+ if (Path.empty()) VocabOverride = std::nullopt;
+ else VocabOverride = Path.str(); // Stores an owned copy, so the caller's buffer need not outlive this call.
+}
+
+StringRef getIR2VecVocabPath() { // Effective vocabulary path: the override when one is set, otherwise VocabFile's value.
+ return VocabOverride ? StringRef(*VocabOverride) // View into the static string; a later setIR2VecVocabPath() call may invalidate it.
+ : StringRef(VocabFile.getValue());
+}
+
} // namespace ir2vec
} // namespace llvm
@@ -482,12 +493,13 @@ VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) {
// FIXME: Make this optional. We can avoid file reads
// by auto-generating a default vocabulary during the build time.
-Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,
+Error IR2VecVocabAnalysis::readVocabulary(StringRef EffectivePath,
+ VocabMap &OpcVocab,
VocabMap &TypeVocab,
VocabMap &ArgVocab) {
- auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
+ auto BufOrError = MemoryBuffer::getFileOrSTDIN(EffectivePath, /*IsText=*/true);
if (!BufOrError)
- return createFileError(VocabFile, BufOrError.getError());
+ return createFileError(EffectivePath.str(), BufOrError.getError());
auto Content = BufOrError.get()->getBuffer();
@@ -615,8 +627,11 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
if (Vocab.has_value())
return Vocabulary(std::move(Vocab.value()));
+ StringRef EffectivePath =
+ VocabOverride ? StringRef(*VocabOverride) : VocabFile.getValue();
+
// Otherwise, try to read from the vocabulary file.
- if (VocabFile.empty()) {
+ if (EffectivePath.empty()) {
// FIXME: Use default vocabulary
Ctx->emitError("IR2Vec vocabulary file path not specified; You may need to "
"set it using --ir2vec-vocab-path");
@@ -624,7 +639,7 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
}
VocabMap OpcVocab, TypeVocab, ArgVocab;
- if (auto Err = readVocabulary(OpcVocab, TypeVocab, ArgVocab)) {
+ if (auto Err = readVocabulary(EffectivePath, OpcVocab, TypeVocab, ArgVocab)) {
emitError(std::move(Err), *Ctx);
return Vocabulary();
}
@@ -694,4 +709,4 @@ PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
Entry.print(OS);
}
return PreservedAnalyses::all();
-}
+}
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp b/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
index d9d4b9f07e637..b4658fc60a3ed 100644
--- a/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
+++ b/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
@@ -15,6 +15,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
+#include <pybind11/numpy.h>
#include <fstream>
#include <memory>
@@ -24,6 +25,13 @@ namespace py = pybind11;
using namespace llvm;
using namespace llvm::ir2vec;
+namespace llvm {
+namespace ir2vec {
+void setIR2VecVocabPath(StringRef Path);
+StringRef getIR2VecVocabPath();
+} // namespace ir2vec
+} // namespace llvm
+
namespace {
bool fileNotValid(const std::string &Filename) {
@@ -32,7 +40,7 @@ bool fileNotValid(const std::string &Filename) {
}
std::unique_ptr<Module> getLLVMIR(const std::string &Filename,
- LLVMContext &Context) {
+ LLVMContext &Context) {
SMDiagnostic Err;
auto M = parseIRFile(Filename, Err, Context);
if (!M) {
@@ -49,59 +57,89 @@ class PyIR2VecTool {
std::unique_ptr<IR2VecTool> Tool;
public:
- PyIR2VecTool(const std::string &Filename) {
- if (fileNotValid(Filename))
- throw std::runtime_error("Invalid file path");
+ PyIR2VecTool(std::string Filename, std::string Mode,
+ std::string VocabOverride) {
+ if (fileNotValid(Filename))
+ throw std::runtime_error("Invalid file path");
- Ctx = std::make_unique<LLVMContext>();
- M = getLLVMIR(Filename, *Ctx);
- Tool = std::make_unique<IR2VecTool>(*M);
- }
+ if (Mode != "sym" && Mode != "fa")
+ throw std::runtime_error("Invalid mode. Use 'sym' or 'fa'");
- // Get entity mappings (vocabulary)
- std::vector<std::string> getEntityMappings() {
- return IR2VecTool::collectEntityMappings();
+ if (VocabOverride.empty())
+ throw std::runtime_error("Error - Empty Vocab Path not allowed");
+
+ setIR2VecVocabPath(VocabOverride);
+
+ Ctx = std::make_unique<LLVMContext>();
+ M = getLLVMIR(Filename, *Ctx);
+ Tool = std::make_unique<IR2VecTool>(*M);
+
+ bool Ok = Tool->initializeVocabulary();
+ if (!Ok)
+ throw std::runtime_error("Failed to initialize IR2Vec vocabulary");
}
- // Generate triplets for vocabulary training
py::dict generateTriplets() {
auto result = Tool->generateTriplets();
py::list triplets_list;
-
- for (const auto &t : result.Triplets) {
+ for (const auto& t : result.Triplets) {
triplets_list.append(py::make_tuple(t.Head, t.Tail, t.Relation));
}
- py::dict output;
- output["max_relation"] = result.MaxRelation;
- output["triplets"] = triplets_list;
- return output;
+ return py::dict(
+ py::arg("max_relation") = result.MaxRelation,
+ py::arg("triplets") = triplets_list
+ );
+ }
+
+ EntityList collectEntityMappings() {
+ return IR2VecTool::collectEntityMappings();
}
};
-} // anonymous namespace
+} // namespace
PYBIND11_MODULE(py_ir2vec, m) {
- m.doc() = std::string("Python bindings for ") + ToolName + "\n\n" + ToolName +
- " provides distributed representations for LLVM IR.\n\n"
- "Example:\n"
- " >>> import py_ir2vec\n"
- " >>> tool = py_ir2vec." +
- ToolName +
- "Tool(\"test.ll\")\n"
- " >>> entities = tool.getEntityMappings()\n"
- " >>> triplets = tool.generateTriplets()";
+ m.doc() = R"pbdoc(
+ Python bindings for LLVM IR2Vec
+
+ IR2Vec provides distributed representations for LLVM IR.
+ Example:
+ >>> import py_ir2vec as M
+ >>> tool = M.initEmbedding(
+ filename="test.ll",
+ mode="sym",
+ vocab_override=vocab_path
+ )
+ )pbdoc"; // Module docstring; the example was fixed to pass a quoted, correctly spelled file name.
// Main tool class
py::class_<PyIR2VecTool>(m, "IR2VecTool")
- .def(py::init<const std::string &>(), py::arg("filename"),
- "Initialize IR2Vec on an LLVM IR file\n"
- "Args:\n"
- " filename: Path to LLVM IR file (.bc or .ll)")
- .def("getEntityMappings", &PyIR2VecTool::getEntityMappings,
- "Get entity mappings (vocabulary)\n"
- "Returns: list[str] - list of entity names where index is entity_id")
- .def("generateTriplets", &PyIR2VecTool::generateTriplets,
- "Generate triplets for vocabulary training\n"
- "Returns: dict with 'max_relation' and 'triplets' keys");
-}
\ No newline at end of file
+ .def(py::init<std::string, std::string, std::string>(),
+ py::arg("filename"), py::arg("mode"), py::arg("vocab_override"),
+ "Initialize IR2Vec on an LLVM IR file\n"
+ "Args:\n"
+ " filename: Path to LLVM IR file (.bc or .ll)\n"
+ " mode: 'sym' for symbolic (default) or 'fa' for flow-aware\n"
+ " vocab_override: Path to vocabulary file")
+ .def("generateTriplets", &PyIR2VecTool::generateTriplets,
+ "Generate triplets for vocabulary training\n"
+ "Returns: TripletResult Dict with max_relation and list of triplets")
+ .def("getEntityMappings", &PyIR2VecTool::collectEntityMappings,
+ "Get entity mappings (vocabulary)\n"
+ "Returns: list[str] - list of entity names where index is entity_id");
+
+ m.def(
+ "initEmbedding",
+ [](std::string filename, std::string mode, std::string vocab_override) {
+ return std::make_unique<PyIR2VecTool>(filename, mode, vocab_override);
+ },
+ py::arg("filename"), py::arg("mode") = "sym",
+ py::arg("vocab_override"),
+ "Initialize IR2Vec on an LLVM IR file and return an IR2VecTool\n"
+ "Args:\n"
+ " filename: Path to LLVM IR file (.bc or .ll)\n"
+ " mode: 'sym' for symbolic (default) or 'fa' for flow-aware\n"
+ " vocab_override: Path to vocabulary file\n"
+ "Returns: IR2VecTool instance");
+}
>From d6cad46dc84d6c85cc42b64c73df8bf90652b56d Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 24 Dec 2025 00:48:15 +0530
Subject: [PATCH 5/5] Nit commit - code formatting fixup
---
llvm/include/llvm/Analysis/IR2Vec.h | 4 +-
llvm/lib/Analysis/IR2Vec.cpp | 13 +--
.../llvm-ir2vec/bindings/ir2vec_bindings.cpp | 90 +++++++++----------
3 files changed, 54 insertions(+), 53 deletions(-)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 585dffe715be2..dcb9783bfa91c 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -615,8 +615,8 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
using VocabMap = std::map<std::string, ir2vec::Embedding>;
std::optional<ir2vec::VocabStorage> Vocab;
- Error readVocabulary(StringRef EffectivePath, VocabMap &OpcVocab, VocabMap &TypeVocab,
- VocabMap &ArgVocab);
+ Error readVocabulary(StringRef EffectivePath, VocabMap &OpcVocab,
+ VocabMap &TypeVocab, VocabMap &ArgVocab);
void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
VocabMap &ArgVocab);
void emitError(Error Err, LLVMContext &Ctx);
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index a47ced94ee676..16822a6828d8c 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -64,13 +64,15 @@ cl::opt<IR2VecKind> IR2VecEmbeddingKind(
static std::optional<std::string> VocabOverride;
void setIR2VecVocabPath(StringRef Path) {
- if (Path.empty()) VocabOverride = std::nullopt;
- else VocabOverride = Path.str();
+ if (Path.empty())
+ VocabOverride = std::nullopt;
+ else
+ VocabOverride = Path.str();
}
StringRef getIR2VecVocabPath() {
return VocabOverride ? StringRef(*VocabOverride)
- : StringRef(VocabFile.getValue());
+ : StringRef(VocabFile.getValue());
}
} // namespace ir2vec
@@ -497,7 +499,8 @@ Error IR2VecVocabAnalysis::readVocabulary(StringRef EffectivePath,
VocabMap &OpcVocab,
VocabMap &TypeVocab,
VocabMap &ArgVocab) {
- auto BufOrError = MemoryBuffer::getFileOrSTDIN(EffectivePath, /*IsText=*/true);
+ auto BufOrError =
+ MemoryBuffer::getFileOrSTDIN(EffectivePath, /*IsText=*/true);
if (!BufOrError)
return createFileError(EffectivePath.str(), BufOrError.getError());
@@ -628,7 +631,7 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
return Vocabulary(std::move(Vocab.value()));
StringRef EffectivePath =
- VocabOverride ? StringRef(*VocabOverride) : VocabFile.getValue();
+ VocabOverride ? StringRef(*VocabOverride) : VocabFile.getValue();
// Otherwise, try to read from the vocabulary file.
if (EffectivePath.empty()) {
diff --git a/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp b/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
index b4658fc60a3ed..784c4f0fe37dd 100644
--- a/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
+++ b/llvm/tools/llvm-ir2vec/bindings/ir2vec_bindings.cpp
@@ -13,9 +13,9 @@
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
+#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
-#include <pybind11/numpy.h>
#include <fstream>
#include <memory>
@@ -40,7 +40,7 @@ bool fileNotValid(const std::string &Filename) {
}
std::unique_ptr<Module> getLLVMIR(const std::string &Filename,
- LLVMContext &Context) {
+ LLVMContext &Context) {
SMDiagnostic Err;
auto M = parseIRFile(Filename, Err, Context);
if (!M) {
@@ -58,38 +58,36 @@ class PyIR2VecTool {
public:
PyIR2VecTool(std::string Filename, std::string Mode,
- std::string VocabOverride) {
- if (fileNotValid(Filename))
- throw std::runtime_error("Invalid file path");
+ std::string VocabOverride) {
+ if (fileNotValid(Filename))
+ throw std::runtime_error("Invalid file path");
- if (Mode != "sym" && Mode != "fa")
- throw std::runtime_error("Invalid mode. Use 'sym' or 'fa'");
+ if (Mode != "sym" && Mode != "fa")
+ throw std::runtime_error("Invalid mode. Use 'sym' or 'fa'");
- if (VocabOverride.empty())
- throw std::runtime_error("Error - Empty Vocab Path not allowed");
+ if (VocabOverride.empty())
+ throw std::runtime_error("Error - Empty Vocab Path not allowed");
- setIR2VecVocabPath(VocabOverride);
+ setIR2VecVocabPath(VocabOverride);
- Ctx = std::make_unique<LLVMContext>();
- M = getLLVMIR(Filename, *Ctx);
- Tool = std::make_unique<IR2VecTool>(*M);
+ Ctx = std::make_unique<LLVMContext>();
+ M = getLLVMIR(Filename, *Ctx);
+ Tool = std::make_unique<IR2VecTool>(*M);
- bool Ok = Tool->initializeVocabulary();
- if (!Ok)
- throw std::runtime_error("Failed to initialize IR2Vec vocabulary");
+ bool Ok = Tool->initializeVocabulary();
+ if (!Ok)
+ throw std::runtime_error("Failed to initialize IR2Vec vocabulary");
}
py::dict generateTriplets() {
auto result = Tool->generateTriplets();
py::list triplets_list;
- for (const auto& t : result.Triplets) {
+ for (const auto &t : result.Triplets) {
triplets_list.append(py::make_tuple(t.Head, t.Tail, t.Relation));
}
- return py::dict(
- py::arg("max_relation") = result.MaxRelation,
- py::arg("triplets") = triplets_list
- );
+ return py::dict(py::arg("max_relation") = result.MaxRelation,
+ py::arg("triplets") = triplets_list);
}
EntityList collectEntityMappings() {
@@ -115,31 +113,31 @@ PYBIND11_MODULE(py_ir2vec, m) {
// Main tool class
py::class_<PyIR2VecTool>(m, "IR2VecTool")
- .def(py::init<std::string, std::string, std::string>(),
- py::arg("filename"), py::arg("mode"), py::arg("vocab_override"),
- "Initialize IR2Vec on an LLVM IR file\n"
- "Args:\n"
- " filename: Path to LLVM IR file (.bc or .ll)\n"
- " mode: 'sym' for symbolic (default) or 'fa' for flow-aware\n"
- " vocab_override: Path to vocabulary file")
- .def("generateTriplets", &PyIR2VecTool::generateTriplets,
- "Generate triplets for vocabulary training\n"
- "Returns: TripletResult Dict with max_relation and list of triplets")
- .def("getEntityMappings", &PyIR2VecTool::collectEntityMappings,
- "Get entity mappings (vocabulary)\n"
- "Returns: list[str] - list of entity names where index is entity_id");
+ .def(py::init<std::string, std::string, std::string>(),
+ py::arg("filename"), py::arg("mode"), py::arg("vocab_override"),
+ "Initialize IR2Vec on an LLVM IR file\n"
+ "Args:\n"
+ " filename: Path to LLVM IR file (.bc or .ll)\n"
+ " mode: 'sym' for symbolic (default) or 'fa' for flow-aware\n"
+ " vocab_override: Path to vocabulary file")
+ .def("generateTriplets", &PyIR2VecTool::generateTriplets,
+ "Generate triplets for vocabulary training\n"
+ "Returns: TripletResult Dict with max_relation and list of triplets")
+ .def(
+ "getEntityMappings", &PyIR2VecTool::collectEntityMappings,
+ "Get entity mappings (vocabulary)\n"
+ "Returns: list[str] - list of entity names where index is entity_id");
m.def(
- "initEmbedding",
- [](std::string filename, std::string mode, std::string vocab_override) {
- return std::make_unique<PyIR2VecTool>(filename, mode, vocab_override);
- },
- py::arg("filename"), py::arg("mode") = "sym",
- py::arg("vocab_override"),
- "Initialize IR2Vec on an LLVM IR file and return an IR2VecTool\n"
- "Args:\n"
- " filename: Path to LLVM IR file (.bc or .ll)\n"
- " mode: 'sym' for symbolic (default) or 'fa' for flow-aware\n"
- " vocab_override: Path to vocabulary file\n"
- "Returns: IR2VecTool instance");
+ "initEmbedding",
+ [](std::string filename, std::string mode, std::string vocab_override) {
+ return std::make_unique<PyIR2VecTool>(filename, mode, vocab_override);
+ },
+ py::arg("filename"), py::arg("mode") = "sym", py::arg("vocab_override"),
+ "Initialize IR2Vec on an LLVM IR file and return an IR2VecTool\n"
+ "Args:\n"
+ " filename: Path to LLVM IR file (.bc or .ll)\n"
+ " mode: 'sym' for symbolic (default) or 'fa' for flow-aware\n"
+ " vocab_override: Path to vocabulary file\n"
+ "Returns: IR2VecTool instance");
}
More information about the llvm-commits
mailing list