[llvm] Refactoring llvm-ir2vec.cpp for better separation of concerns in the Tooling classes (PR #170078)
Nishant Sachdeva via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 10 01:23:31 PST 2025
https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/170078
>From e918ecdd0d5591802c68e1b78d5a5c34d5fa3d4f Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Fri, 28 Nov 2025 19:27:28 +0530
Subject: [PATCH 1/5] Refactoring llvm-ir2vec.cpp for better separation of
concerns in the Tooling classes
---
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 431 +++++++++++++++++--------
1 file changed, 290 insertions(+), 141 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 7b8d3f093a3d1..a9af7ab005982 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -64,6 +64,7 @@
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
+#include "llvm/IR/InstIterator.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -151,6 +152,29 @@ static cl::opt<EmbeddingLevel>
cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
cl::cat(CommonCategory));
+/// Entity mappings: [entity_name]
+using EntityList = std::vector<std::string>;
+
+/// Basic block embeddings: [{bb_name, Embedding}]
+using BBVecList = std::vector<std::pair<std::string, ir2vec::Embedding>>;
+
+/// Instruction embeddings: [{instruction_string, Embedding}]
+using InstVecList = std::vector<std::pair<std::string, ir2vec::Embedding>>;
+
+/// Function embeddings: [Embedding]
+using FuncVecList = std::vector<ir2vec::Embedding>;
+
+struct Triplet {
+ unsigned Head;
+ unsigned Tail;
+ unsigned Relation;
+};
+
+struct TripletResult {
+ unsigned MaxRelation;
+ std::vector<Triplet> Triplets;
+};
+
namespace ir2vec {
/// Relation types for triplet generation
@@ -182,32 +206,14 @@ class IR2VecTool {
return Vocab->isValid();
}
- /// Generate triplets for the module
- /// Output format: MAX_RELATION=N header followed by relationships
- void generateTriplets(raw_ostream &OS) const {
- unsigned MaxRelation = NextRelation; // Track maximum relation ID
- std::string Relationships;
- raw_string_ostream RelOS(Relationships);
+ TripletResult getTriplets(const Function &F) const {
+ TripletResult Result;
+ Result.MaxRelation = 0;
- for (const Function &F : M) {
- unsigned FuncMaxRelation = generateTriplets(F, RelOS);
- MaxRelation = std::max(MaxRelation, FuncMaxRelation);
- }
-
- RelOS.flush();
-
- // Write metadata header followed by relationships
- OS << "MAX_RELATION=" << MaxRelation << '\n';
- OS << Relationships;
- }
-
- /// Generate triplets for a single function
- /// Returns the maximum relation ID used in this function
- unsigned generateTriplets(const Function &F, raw_ostream &OS) const {
if (F.isDeclaration())
- return 0;
+ return Result;
- unsigned MaxRelation = 1;
+ unsigned MaxRelation = NextRelation;
unsigned PrevOpcode = 0;
bool HasPrevOpcode = false;
@@ -216,56 +222,139 @@ class IR2VecTool {
unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
- // Add "Next" relationship with previous instruction
if (HasPrevOpcode) {
- OS << PrevOpcode << '\t' << Opcode << '\t' << NextRelation << '\n';
+ Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
LLVM_DEBUG(dbgs()
- << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
- << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << "Next\n");
+ << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
+ << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+ << "Next\n");
}
- // Add "Type" relationship
- OS << Opcode << '\t' << TypeID << '\t' << TypeRelation << '\n';
+ Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
LLVM_DEBUG(
dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
- << '\t' << "Type\n");
+ << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
+ << '\t' << "Type\n");
- // Add "Arg" relationships
unsigned ArgIndex = 0;
for (const Use &U : I.operands()) {
unsigned OperandID = Vocabulary::getIndex(*U.get());
unsigned RelationID = ArgRelation + ArgIndex;
- OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
+ Result.Triplets.push_back({Opcode, OperandID, RelationID});
LLVM_DEBUG({
StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
Vocabulary::getOperandKind(U.get()));
dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
+ << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
});
++ArgIndex;
}
- // Only update MaxRelation if there were operands
+
if (ArgIndex > 0) {
MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
}
+
PrevOpcode = Opcode;
HasPrevOpcode = true;
}
}
- return MaxRelation;
+ Result.MaxRelation = MaxRelation;
+ return Result;
}
- /// Dump entity ID to string mappings
- static void generateEntityMappings(raw_ostream &OS) {
+ TripletResult getTriplets() const {
+ TripletResult Result;
+ Result.MaxRelation = NextRelation;
+
+ for (const Function &F : M) {
+ TripletResult FuncResult = getTriplets(F);
+ Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+ Result.Triplets.insert(Result.Triplets.end(),
+ FuncResult.Triplets.begin(),
+ FuncResult.Triplets.end());
+ }
+
+ return Result;
+ }
+
+ /// Generate triplets for the module
+ /// Output format: MAX_RELATION=N header followed by relationships
+ void generateTriplets(raw_ostream &OS) const {
+ auto Result = getTriplets();
+ OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+ for (const auto &T : Result.Triplets) {
+ OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+ }
+ }
+
+ static EntityList collectEntityMappings() {
auto EntityLen = Vocabulary::getCanonicalSize();
- OS << EntityLen << "\n";
+ EntityList Result;
+ Result.reserve(EntityLen);
+
for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
- OS << Vocabulary::getStringKey(EntityID) << '\t' << EntityID << '\n';
+ Result.push_back(Vocabulary::getStringKey(EntityID).str());
+
+ return Result;
+ }
+
+ /// Dump entity ID to string mappings
+ static void generateEntityMappings(raw_ostream &OS) {
+ auto Entities = collectEntityMappings();
+ OS << Entities.size() << "\n";
+ for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+ OS << Entities[EntityID] << '\t' << EntityID << '\n';
+ }
+
+ ir2vec::Embedding getFunctionEmbedding(const Function &F) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+ if (F.isDeclaration())
+ return {};
+
+ auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+ return Emb ? Emb->getFunctionVector() : Embedding{};
+ }
+
+ BBVecList getBBEmbeddings(const Function &F) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+ BBVecList Result;
+
+ if (F.isDeclaration())
+ return Result;
+
+ auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+ if (!Emb)
+ return Result;
+
+ for (const BasicBlock &BB : F)
+ Result.push_back({BB.getName().str(), Emb->getBBVector(BB)});
+
+ return Result;
+ }
+
+ InstVecList getInstEmbeddings(const Function &F) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+ InstVecList Result;
+
+ if (F.isDeclaration())
+ return Result;
+
+ auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+ if (!Emb)
+ return Result;
+
+ for (const Instruction &I : instructions(F)) {
+ std::string InstStr;
+ raw_string_ostream(InstStr) << I;
+ Result.push_back({InstStr, Emb->getInstVector(I)});
+ }
+ return Result;
}
/// Generate embeddings for the entire module
@@ -282,44 +371,31 @@ class IR2VecTool {
/// Generate embeddings for a single function
void generateEmbeddings(const Function &F, raw_ostream &OS) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
if (F.isDeclaration()) {
OS << "Function " << F.getName() << " is a declaration, skipping.\n";
return;
}
- // Create embedder for this function
- assert(Vocab->isValid() && "Vocabulary is not valid");
- auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for function " << F.getName() << "\n";
- return;
- }
-
OS << "Function: " << F.getName() << "\n";
- // Generate embeddings based on the specified level
- switch (Level) {
- case FunctionLevel: {
- Emb->getFunctionVector().print(OS);
- break;
- }
- case BasicBlockLevel: {
- for (const BasicBlock &BB : F) {
- OS << BB.getName() << ":";
- Emb->getBBVector(BB).print(OS);
- }
- break;
- }
- case InstructionLevel: {
- for (const BasicBlock &BB : F) {
- for (const Instruction &I : BB) {
- I.print(OS);
- Emb->getInstVector(I).print(OS);
- }
+ auto printListLevel = [&](const auto& list, const char* suffix_str) {
+ for (const auto& [name, embedding] : list) {
+ OS << name << suffix_str;
+ embedding.print(OS);
}
- break;
- }
+ };
+
+ switch (Level) {
+ case EmbeddingLevel::FunctionLevel:
+ getFunctionEmbedding(F).print(OS);
+ break;
+ case EmbeddingLevel::BasicBlockLevel:
+ printListLevel(getBBEmbeddings(F), ":");
+ break;
+ case EmbeddingLevel::InstructionLevel:
+ printListLevel(getInstEmbeddings(F), "");
+ break;
}
}
};
@@ -421,47 +497,22 @@ class MIR2VecTool {
<< "No machine functions found to initialize vocabulary\n";
return false;
}
-
- /// Generate triplets for the module
- /// Output format: MAX_RELATION=N header followed by relationships
- void generateTriplets(const Module &M, raw_ostream &OS) const {
- unsigned MaxRelation = MIRNextRelation; // Track maximum relation ID
- std::string Relationships;
- raw_string_ostream RelOS(Relationships);
-
- for (const Function &F : M.getFunctionDefs()) {
-
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF) {
- WithColor::warning(errs(), ToolName)
- << "No MachineFunction for " << F.getName() << "\n";
- continue;
- }
-
- unsigned FuncMaxRelation = generateTriplets(*MF, RelOS);
- MaxRelation = std::max(MaxRelation, FuncMaxRelation);
- }
-
- RelOS.flush();
-
- // Write metadata header followed by relationships
- OS << "MAX_RELATION=" << MaxRelation << '\n';
- OS << Relationships;
- }
-
- /// Generate triplets for a single machine function
- /// Returns the maximum relation ID used in this function
- unsigned generateTriplets(const MachineFunction &MF, raw_ostream &OS) const {
- unsigned MaxRelation = MIRNextRelation;
- unsigned PrevOpcode = 0;
- bool HasPrevOpcode = false;
+
+ /// Get triplets for a single machine function
+ /// Returns TripletResult containing MaxRelation and vector of Triplets
+ TripletResult getTriplets(const MachineFunction &MF) const {
+ TripletResult Result;
+ Result.MaxRelation = MIRNextRelation;
if (!Vocab) {
WithColor::error(errs(), ToolName)
<< "MIR Vocabulary must be initialized for triplet generation.\n";
- return MaxRelation;
+ return Result;
}
+ unsigned PrevOpcode = 0;
+ bool HasPrevOpcode = false;
+
for (const MachineBasicBlock &MBB : MF) {
for (const MachineInstr &MI : MBB) {
// Skip debug instructions
@@ -473,8 +524,7 @@ class MIR2VecTool {
// Add "Next" relationship with previous instruction
if (HasPrevOpcode) {
- OS << PrevOpcode << '\t' << OpcodeID << '\t' << MIRNextRelation
- << '\n';
+ Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
LLVM_DEBUG(dbgs()
<< Vocab->getStringKey(PrevOpcode) << '\t'
<< Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
@@ -485,7 +535,7 @@ class MIR2VecTool {
for (const MachineOperand &MO : MI.operands()) {
auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
unsigned RelationID = MIRArgRelation + ArgIndex;
- OS << OpcodeID << '\t' << OperandID << '\t' << RelationID << '\n';
+ Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
LLVM_DEBUG({
std::string OperandStr = Vocab->getStringKey(OperandID);
dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
@@ -497,28 +547,82 @@ class MIR2VecTool {
// Update MaxRelation if there were operands
if (ArgIndex > 0)
- MaxRelation = std::max(MaxRelation, MIRArgRelation + ArgIndex - 1);
+ Result.MaxRelation = std::max(Result.MaxRelation,
+ MIRArgRelation + ArgIndex - 1);
PrevOpcode = OpcodeID;
HasPrevOpcode = true;
}
}
- return MaxRelation;
+ return Result;
}
+
+ /// Get triplets for the entire module
+ /// Returns TripletResult containing aggregated MaxRelation and all Triplets
+ TripletResult getTriplets(const Module &M) const {
+ TripletResult Result;
+ Result.MaxRelation = MIRNextRelation;
- /// Generate entity mappings with vocabulary
- void generateEntityMappings(raw_ostream &OS) const {
+ for (const Function &F : M) {
+ if (F.isDeclaration())
+ continue;
+
+ MachineFunction *MF = MMI.getMachineFunction(F);
+ if (!MF) {
+ WithColor::warning(errs(), ToolName)
+ << "No MachineFunction for " << F.getName() << "\n";
+ continue;
+ }
+
+ TripletResult FuncResult = getTriplets(*MF);
+ Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+ Result.Triplets.insert(Result.Triplets.end(),
+ FuncResult.Triplets.begin(),
+ FuncResult.Triplets.end());
+ }
+
+ return Result;
+ }
+
+ /// Generate triplets for the module and write to output stream
+ /// Output format: MAX_RELATION=N header followed by relationships
+ void generateTriplets(const Module &M, raw_ostream &OS) const {
+ auto Result = getTriplets(M);
+ OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+ for (const auto &T : Result.Triplets) {
+ OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+ }
+ }
+
+ /// Collect entity mappings
+ /// Returns EntityList containing all entity strings
+ EntityList collectEntityMappings() const {
if (!Vocab) {
WithColor::error(errs(), ToolName)
<< "Vocabulary must be initialized for entity mappings.\n";
- return;
+ return {};
}
const unsigned EntityCount = Vocab->getCanonicalSize();
- OS << EntityCount << "\n";
+ EntityList Result;
+ Result.reserve(EntityCount);
+
for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
- OS << Vocab->getStringKey(EntityID) << '\t' << EntityID << '\n';
+ Result.push_back(Vocab->getStringKey(EntityID));
+
+ return Result;
+ }
+
+ /// Generate entity mappings and write to output stream
+ void generateEntityMappings(raw_ostream &OS) const {
+ auto Entities = collectEntityMappings();
+ if (Entities.empty())
+ return;
+
+ OS << Entities.size() << "\n";
+ for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+ OS << Entities[EntityID] << '\t' << EntityID << '\n';
}
/// Generate embeddings for all machine functions in the module
@@ -541,48 +645,93 @@ class MIR2VecTool {
}
}
- /// Generate embeddings for a specific machine function
- void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const {
+ /// Get machine function embedding
+ /// Returns Embedding for the entire machine function
+ ir2vec::Embedding getMFunctionEmbedding(MachineFunction &MF) const {
if (!Vocab) {
WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return;
+ return {};
}
auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for " << MF.getName() << "\n";
+ return Emb ? Emb->getMFunctionVector() : ir2vec::Embedding{};
+ }
+
+ /// Get machine basic block embeddings
+ /// Returns BBVecList containing (name, embedding) pairs for all MBBs
+ BBVecList getMBBEmbeddings(MachineFunction &MF) const {
+ BBVecList Result;
+
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+ return Result;
+ }
+
+ auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+ if (!Emb)
+ return Result;
+
+ for (const MachineBasicBlock &MBB : MF)
+ Result.push_back({MBB.getName().str(), Emb->getMBBVector(MBB)});
+
+ return Result;
+ }
+
+ /// Get machine instruction embeddings
+ /// Returns InstVecList containing (instruction_string, embedding) pairs
+ InstVecList getMInstEmbeddings(MachineFunction &MF) const {
+ InstVecList Result;
+
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+ return Result;
+ }
+
+ auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+ if (!Emb)
+ return Result;
+
+ for (const MachineBasicBlock &MBB : MF) {
+ for (const MachineInstr &MI : MBB) {
+ std::string InstStr;
+ raw_string_ostream(InstStr) << MI;
+ Result.push_back({InstStr, Emb->getMInstVector(MI)});
+ }
+ }
+
+ return Result;
+ }
+
+ /// Generate embeddings for a specific machine function
+ void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const {
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
return;
}
OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+ auto printListLevel = [&](const auto& list, const char* label, const char* prefix_str, const char* suffix_str) {
+ OS << label << ":\n";
+ for (const auto& [name, embedding] : list) {
+ OS << prefix_str << name << suffix_str;
+ embedding.print(OS);
+ }
+ };
+
// Generate embeddings based on the specified level
switch (Level) {
- case FunctionLevel: {
+ case FunctionLevel:
OS << "Function vector: ";
- Emb->getMFunctionVector().print(OS);
+ getMFunctionEmbedding(MF).print(OS);
break;
- }
- case BasicBlockLevel: {
- OS << "Basic block vectors:\n";
- for (const MachineBasicBlock &MBB : MF) {
- OS << "MBB " << MBB.getName() << ": ";
- Emb->getMBBVector(MBB).print(OS);
- }
+ case BasicBlockLevel:
+ printListLevel(getMBBEmbeddings(MF), "Basic block vectors", "MBB ", ": ");
break;
- }
- case InstructionLevel: {
- OS << "Instruction vectors:\n";
- for (const MachineBasicBlock &MBB : MF) {
- for (const MachineInstr &MI : MBB) {
- OS << MI << " -> ";
- Emb->getMInstVector(MI).print(OS);
- }
- }
+ case InstructionLevel:
+ printListLevel(getMInstEmbeddings(MF), "Instruction vectors", "", " -> ");
break;
}
- }
}
const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
>From 51f359ea0158fb4959e55091fd940780e7ac6c66 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Tue, 2 Dec 2025 12:42:28 +0530
Subject: [PATCH 2/5] Debug Commit - Clang Format fixup
---
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 59 +++++++++++++-------------
1 file changed, 29 insertions(+), 30 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index a9af7ab005982..2511983c5d490 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -58,13 +58,13 @@
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
-#include "llvm/IR/InstIterator.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -225,16 +225,16 @@ class IR2VecTool {
if (HasPrevOpcode) {
Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
LLVM_DEBUG(dbgs()
- << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
- << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << "Next\n");
+ << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
+ << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+ << "Next\n");
}
Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
LLVM_DEBUG(
dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
- << '\t' << "Type\n");
+ << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
+ << '\t' << "Type\n");
unsigned ArgIndex = 0;
for (const Use &U : I.operands()) {
@@ -246,7 +246,7 @@ class IR2VecTool {
StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
Vocabulary::getOperandKind(U.get()));
dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
+ << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
});
++ArgIndex;
@@ -272,9 +272,8 @@ class IR2VecTool {
for (const Function &F : M) {
TripletResult FuncResult = getTriplets(F);
Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
- Result.Triplets.insert(Result.Triplets.end(),
- FuncResult.Triplets.begin(),
- FuncResult.Triplets.end());
+ Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+ FuncResult.Triplets.end());
}
return Result;
@@ -379,23 +378,23 @@ class IR2VecTool {
OS << "Function: " << F.getName() << "\n";
- auto printListLevel = [&](const auto& list, const char* suffix_str) {
- for (const auto& [name, embedding] : list) {
+ auto printListLevel = [&](const auto &list, const char *suffix_str) {
+ for (const auto &[name, embedding] : list) {
OS << name << suffix_str;
embedding.print(OS);
}
};
switch (Level) {
- case EmbeddingLevel::FunctionLevel:
- getFunctionEmbedding(F).print(OS);
- break;
- case EmbeddingLevel::BasicBlockLevel:
- printListLevel(getBBEmbeddings(F), ":");
- break;
- case EmbeddingLevel::InstructionLevel:
- printListLevel(getInstEmbeddings(F), "");
- break;
+ case EmbeddingLevel::FunctionLevel:
+ getFunctionEmbedding(F).print(OS);
+ break;
+ case EmbeddingLevel::BasicBlockLevel:
+ printListLevel(getBBEmbeddings(F), ":");
+ break;
+ case EmbeddingLevel::InstructionLevel:
+ printListLevel(getInstEmbeddings(F), "");
+ break;
}
}
};
@@ -497,7 +496,7 @@ class MIR2VecTool {
<< "No machine functions found to initialize vocabulary\n";
return false;
}
-
+
/// Get triplets for a single machine function
/// Returns TripletResult containing MaxRelation and vector of Triplets
TripletResult getTriplets(const MachineFunction &MF) const {
@@ -547,8 +546,8 @@ class MIR2VecTool {
// Update MaxRelation if there were operands
if (ArgIndex > 0)
- Result.MaxRelation = std::max(Result.MaxRelation,
- MIRArgRelation + ArgIndex - 1);
+ Result.MaxRelation =
+ std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
PrevOpcode = OpcodeID;
HasPrevOpcode = true;
@@ -557,7 +556,7 @@ class MIR2VecTool {
return Result;
}
-
+
/// Get triplets for the entire module
/// Returns TripletResult containing aggregated MaxRelation and all Triplets
TripletResult getTriplets(const Module &M) const {
@@ -577,9 +576,8 @@ class MIR2VecTool {
TripletResult FuncResult = getTriplets(*MF);
Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
- Result.Triplets.insert(Result.Triplets.end(),
- FuncResult.Triplets.begin(),
- FuncResult.Triplets.end());
+ Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+ FuncResult.Triplets.end());
}
return Result;
@@ -711,9 +709,10 @@ class MIR2VecTool {
OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
- auto printListLevel = [&](const auto& list, const char* label, const char* prefix_str, const char* suffix_str) {
+ auto printListLevel = [&](const auto &list, const char *label,
+ const char *prefix_str, const char *suffix_str) {
OS << label << ":\n";
- for (const auto& [name, embedding] : list) {
+ for (const auto &[name, embedding] : list) {
OS << prefix_str << name << suffix_str;
embedding.print(OS);
}
>From 2786c97c17deefcbef3ba92f3061874e166211c4 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Tue, 2 Dec 2025 16:44:25 +0530
Subject: [PATCH 3/5] Debug Commit - Resolving comments. To be squashed later
---
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 186 +++++++++++++------------
1 file changed, 98 insertions(+), 88 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 2511983c5d490..d3afbbc6f87e6 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -152,18 +152,6 @@ static cl::opt<EmbeddingLevel>
cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
cl::cat(CommonCategory));
-/// Entity mappings: [entity_name]
-using EntityList = std::vector<std::string>;
-
-/// Basic block embeddings: [{bb_name, Embedding}]
-using BBVecList = std::vector<std::pair<std::string, ir2vec::Embedding>>;
-
-/// Instruction embeddings: [{instruction_string, Embedding}]
-using InstVecList = std::vector<std::pair<std::string, ir2vec::Embedding>>;
-
-/// Function embeddings: [Embedding]
-using FuncVecList = std::vector<ir2vec::Embedding>;
-
struct Triplet {
unsigned Head;
unsigned Tail;
@@ -176,6 +164,18 @@ struct TripletResult {
};
namespace ir2vec {
+/// Entity mappings: [entity_name]
+using EntityList = std::vector<std::string>;
+
+/// Basic block embeddings: [{bb_ptr, Embedding}]
+using BBVecList = std::vector<std::pair<const BasicBlock *, ir2vec::Embedding>>;
+
+/// Instruction embeddings: [{instruction_ptr, Embedding}]
+using InstVecList =
+ std::vector<std::pair<const Instruction *, ir2vec::Embedding>>;
+
+/// Function embeddings: [Embedding]
+using FuncVecList = std::vector<ir2vec::Embedding>;
/// Relation types for triplet generation
enum RelationType {
@@ -206,13 +206,13 @@ class IR2VecTool {
return Vocab->isValid();
}
- TripletResult getTriplets(const Function &F) const {
+ TripletResult generateTriplets(const Function &F) const {
+ if (F.isDeclaration())
+ return {};
+
TripletResult Result;
Result.MaxRelation = 0;
- if (F.isDeclaration())
- return Result;
-
unsigned MaxRelation = NextRelation;
unsigned PrevOpcode = 0;
bool HasPrevOpcode = false;
@@ -222,6 +222,7 @@ class IR2VecTool {
unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
+ // Add "Next" relationship with previous instruction
if (HasPrevOpcode) {
Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
LLVM_DEBUG(dbgs()
@@ -230,12 +231,14 @@ class IR2VecTool {
<< "Next\n");
}
+ // Add "Type" relationship
Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
LLVM_DEBUG(
dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
<< Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
<< '\t' << "Type\n");
+ // Add "Arg" relationships
unsigned ArgIndex = 0;
for (const Use &U : I.operands()) {
unsigned OperandID = Vocabulary::getIndex(*U.get());
@@ -251,10 +254,9 @@ class IR2VecTool {
++ArgIndex;
}
-
- if (ArgIndex > 0) {
+ // Only update MaxRelation if there were operands
+ if (ArgIndex > 0)
MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
- }
PrevOpcode = Opcode;
HasPrevOpcode = true;
@@ -265,12 +267,14 @@ class IR2VecTool {
return Result;
}
- TripletResult getTriplets() const {
+ TripletResult generateTriplets() const {
TripletResult Result;
Result.MaxRelation = NextRelation;
for (const Function &F : M) {
- TripletResult FuncResult = getTriplets(F);
+ if (F.isDeclaration())
+ continue;
+ TripletResult FuncResult = generateTriplets(F);
Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
FuncResult.Triplets.end());
@@ -282,18 +286,15 @@ class IR2VecTool {
/// Generate triplets for the module
/// Output format: MAX_RELATION=N header followed by relationships
void generateTriplets(raw_ostream &OS) const {
- auto Result = getTriplets();
+ auto Result = generateTriplets();
OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
- for (const auto &T : Result.Triplets) {
+ for (const auto &T : Result.Triplets)
OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
- }
}
static EntityList collectEntityMappings() {
auto EntityLen = Vocabulary::getCanonicalSize();
EntityList Result;
- Result.reserve(EntityLen);
-
for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
Result.push_back(Vocabulary::getStringKey(EntityID).str());
@@ -315,23 +316,27 @@ class IR2VecTool {
return {};
auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
- return Emb ? Emb->getFunctionVector() : Embedding{};
+ if (!Emb) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to create embedder for " << F.getName() << "\n";
+ return {};
+ }
+ return Emb->getFunctionVector();
}
BBVecList getBBEmbeddings(const Function &F) const {
assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
- BBVecList Result;
-
if (F.isDeclaration())
- return Result;
+ return {};
auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
if (!Emb)
- return Result;
+ return {};
+ BBVecList Result;
for (const BasicBlock &BB : F)
- Result.push_back({BB.getName().str(), Emb->getBBVector(BB)});
+ Result.push_back({&BB, Emb->getBBVector(BB)});
return Result;
}
@@ -339,20 +344,17 @@ class IR2VecTool {
InstVecList getInstEmbeddings(const Function &F) const {
assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
- InstVecList Result;
-
if (F.isDeclaration())
- return Result;
+ return {};
auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
if (!Emb)
- return Result;
+ return {};
+
+ InstVecList Result;
+ for (const Instruction &I : instructions(F))
+ Result.push_back({&I, Emb->getInstVector(I)});
- for (const Instruction &I : instructions(F)) {
- std::string InstStr;
- raw_string_ostream(InstStr) << I;
- Result.push_back({InstStr, Emb->getInstVector(I)});
- }
return Result;
}
@@ -378,22 +380,21 @@ class IR2VecTool {
OS << "Function: " << F.getName() << "\n";
- auto printListLevel = [&](const auto &list, const char *suffix_str) {
- for (const auto &[name, embedding] : list) {
- OS << name << suffix_str;
- embedding.print(OS);
- }
- };
-
switch (Level) {
case EmbeddingLevel::FunctionLevel:
getFunctionEmbedding(F).print(OS);
break;
case EmbeddingLevel::BasicBlockLevel:
- printListLevel(getBBEmbeddings(F), ":");
+ for (const auto &[BB, embedding] : getBBEmbeddings(F)) {
+ OS << BB->getName() << ":";
+ embedding.print(OS);
+ }
break;
case EmbeddingLevel::InstructionLevel:
- printListLevel(getInstEmbeddings(F), "");
+ for (const auto &[I, embedding] : getInstEmbeddings(F)) {
+ OS << *I;
+ embedding.print(OS);
+ }
break;
}
}
@@ -430,6 +431,19 @@ Error processModule(Module &M, raw_ostream &OS) {
} // namespace ir2vec
namespace mir2vec {
+/// Entity mappings: [entity_name]
+using EntityList = std::vector<std::string>;
+
+/// Machine basic block embeddings: [{mbb_ptr, Embedding}]
+using MBBVecList =
+ std::vector<std::pair<const MachineBasicBlock *, ir2vec::Embedding>>;
+
+/// Machine instruction embeddings: [{minstr_ptr, Embedding}]
+using MInstVecList =
+ std::vector<std::pair<const MachineInstr *, ir2vec::Embedding>>;
+
+/// Function embeddings: [Embedding]
+using FuncVecList = std::vector<ir2vec::Embedding>;
/// Relation types for MIR2Vec triplet generation
enum MIRRelationType {
@@ -499,7 +513,7 @@ class MIR2VecTool {
/// Get triplets for a single machine function
/// Returns TripletResult containing MaxRelation and vector of Triplets
- TripletResult getTriplets(const MachineFunction &MF) const {
+ TripletResult generateTriplets(const MachineFunction &MF) const {
TripletResult Result;
Result.MaxRelation = MIRNextRelation;
@@ -559,7 +573,7 @@ class MIR2VecTool {
/// Get triplets for the entire module
/// Returns TripletResult containing aggregated MaxRelation and all Triplets
- TripletResult getTriplets(const Module &M) const {
+ TripletResult generateTriplets(const Module &M) const {
TripletResult Result;
Result.MaxRelation = MIRNextRelation;
@@ -574,7 +588,7 @@ class MIR2VecTool {
continue;
}
- TripletResult FuncResult = getTriplets(*MF);
+ TripletResult FuncResult = generateTriplets(*MF);
Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
FuncResult.Triplets.end());
@@ -586,11 +600,10 @@ class MIR2VecTool {
/// Generate triplets for the module and write to output stream
/// Output format: MAX_RELATION=N header followed by relationships
void generateTriplets(const Module &M, raw_ostream &OS) const {
- auto Result = getTriplets(M);
+ auto Result = generateTriplets(M);
OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
- for (const auto &T : Result.Triplets) {
+ for (const auto &T : Result.Triplets)
OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
- }
}
/// Collect entity mappings
@@ -604,8 +617,6 @@ class MIR2VecTool {
const unsigned EntityCount = Vocab->getCanonicalSize();
EntityList Result;
- Result.reserve(EntityCount);
-
for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
Result.push_back(Vocab->getStringKey(EntityID));
@@ -652,49 +663,49 @@ class MIR2VecTool {
}
auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
- return Emb ? Emb->getMFunctionVector() : ir2vec::Embedding{};
+ if (!Emb) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to create embedder for " << MF.getName() << "\n";
+ return {};
+ }
+ return Emb->getMFunctionVector();
}
/// Get machine basic block embeddings
- /// Returns BBVecList containing (name, embedding) pairs for all MBBs
- BBVecList getMBBEmbeddings(MachineFunction &MF) const {
- BBVecList Result;
-
+ /// Returns MBBVecList containing (name, embedding) pairs for all MBBs
+ MBBVecList getMBBEmbeddings(MachineFunction &MF) const {
if (!Vocab) {
WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return Result;
+ return {};
}
auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
if (!Emb)
- return Result;
+ return {};
+ MBBVecList Result;
for (const MachineBasicBlock &MBB : MF)
- Result.push_back({MBB.getName().str(), Emb->getMBBVector(MBB)});
+ Result.push_back({&MBB, Emb->getMBBVector(MBB)});
return Result;
}
/// Get machine instruction embeddings
- /// Returns InstVecList containing (instruction_string, embedding) pairs
- InstVecList getMInstEmbeddings(MachineFunction &MF) const {
- InstVecList Result;
-
+ /// Returns MInstVecList containing (instruction_string, embedding) pairs
+ MInstVecList getMInstEmbeddings(MachineFunction &MF) const {
if (!Vocab) {
WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return Result;
+ return {};
}
auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
if (!Emb)
- return Result;
+ return {};
+ MInstVecList Result;
for (const MachineBasicBlock &MBB : MF) {
- for (const MachineInstr &MI : MBB) {
- std::string InstStr;
- raw_string_ostream(InstStr) << MI;
- Result.push_back({InstStr, Emb->getMInstVector(MI)});
- }
+ for (const MachineInstr &MI : MBB)
+ Result.push_back({&MI, Emb->getMInstVector(MI)});
}
return Result;
@@ -709,15 +720,6 @@ class MIR2VecTool {
OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
- auto printListLevel = [&](const auto &list, const char *label,
- const char *prefix_str, const char *suffix_str) {
- OS << label << ":\n";
- for (const auto &[name, embedding] : list) {
- OS << prefix_str << name << suffix_str;
- embedding.print(OS);
- }
- };
-
// Generate embeddings based on the specified level
switch (Level) {
case FunctionLevel:
@@ -725,10 +727,18 @@ class MIR2VecTool {
getMFunctionEmbedding(MF).print(OS);
break;
case BasicBlockLevel:
- printListLevel(getMBBEmbeddings(MF), "Basic block vectors", "MBB ", ": ");
+ OS << "Basic block vectors:\n";
+ for (const auto &[MBB, embedding] : getMBBEmbeddings(MF)) {
+ OS << "MBB " << MBB->getName() << ": ";
+ embedding.print(OS);
+ }
break;
case InstructionLevel:
- printListLevel(getMInstEmbeddings(MF), "Instruction vectors", "", " -> ");
+ OS << "Instruction vectors:\n";
+ for (const auto &[MI, embedding] : getMInstEmbeddings(MF)) {
+ OS << *MI << " -> ";
+ embedding.print(OS);
+ }
break;
}
}
>From 555f68de511ab324fb4e2618e3626e49ce0b5f43 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Mon, 8 Dec 2025 16:37:58 +0530
Subject: [PATCH 4/5] Debug Commit - Resolving comments. To be squashed later
---
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 252 +++++++++++--------------
1 file changed, 109 insertions(+), 143 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index d3afbbc6f87e6..5b62142652990 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -152,31 +152,25 @@ static cl::opt<EmbeddingLevel>
cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
cl::cat(CommonCategory));
+/// Represents a single knowledge graph triplet (Head, Tail, Relation);
+/// Head and Tail are indices of entities in an EntityList
struct Triplet {
- unsigned Head;
- unsigned Tail;
- unsigned Relation;
+ unsigned Head = 0; ///< Index of the head entity in the entity list
+ unsigned Tail = 0; ///< Index of the tail entity in the entity list
+ unsigned Relation = 0; ///< Relation type (see RelationType / MIRRelationType enums)
};
+/// Result structure containing all generated triplets and metadata
struct TripletResult {
- unsigned MaxRelation;
- std::vector<Triplet> Triplets;
+ unsigned MaxRelation =
+ 0; ///< Highest relation index used (for ArgRelation + N)
+ std::vector<Triplet> Triplets; ///< Collection of all generated triplets
};
-namespace ir2vec {
/// Entity mappings: [entity_name]
using EntityList = std::vector<std::string>;
-/// Basic block embeddings: [{bb_ptr, Embedding}]
-using BBVecList = std::vector<std::pair<const BasicBlock *, ir2vec::Embedding>>;
-
-/// Instruction embeddings: [{instruction_ptr, Embedding}]
-using InstVecList =
- std::vector<std::pair<const Instruction *, ir2vec::Embedding>>;
-
-/// Function embeddings: [Embedding]
-using FuncVecList = std::vector<ir2vec::Embedding>;
-
+namespace ir2vec {
/// Relation types for triplet generation
enum RelationType {
TypeRelation = 0, ///< Instruction to type relationship
@@ -206,6 +200,8 @@ class IR2VecTool {
return Vocab->isValid();
}
+ /// Generate triplets for a single function
+ /// Returns the maximum relation ID used in this function
TripletResult generateTriplets(const Function &F) const {
if (F.isDeclaration())
return {};
@@ -267,13 +263,12 @@ class IR2VecTool {
return Result;
}
+ /// Get triplets for the entire module
TripletResult generateTriplets() const {
TripletResult Result;
Result.MaxRelation = NextRelation;
- for (const Function &F : M) {
- if (F.isDeclaration())
- continue;
+ for (const Function &F : M.getFunctionDefs()) {
TripletResult FuncResult = generateTriplets(F);
Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
@@ -283,7 +278,7 @@ class IR2VecTool {
return Result;
}
- /// Generate triplets for the module
+ /// Collect triplets for the module and dump output to stream
/// Output format: MAX_RELATION=N header followed by relationships
void generateTriplets(raw_ostream &OS) const {
auto Result = generateTriplets();
@@ -292,6 +287,8 @@ class IR2VecTool {
OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
}
+ /// Generate entity mappings for the entire vocabulary
+ /// Returns EntityList containing all entity strings
static EntityList collectEntityMappings() {
auto EntityLen = Vocabulary::getCanonicalSize();
EntityList Result;
@@ -309,52 +306,26 @@ class IR2VecTool {
OS << Entities[EntityID] << '\t' << EntityID << '\n';
}
- ir2vec::Embedding getFunctionEmbedding(const Function &F) const {
- assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
-
- if (F.isDeclaration())
- return {};
-
- auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for " << F.getName() << "\n";
- return {};
- }
- return Emb->getFunctionVector();
+ /// Get function-level embedding using provided embedder
+ ir2vec::Embedding getFunctionEmbedding(const Embedder &Emb) const {
+ return Emb.getFunctionVector();
}
- BBVecList getBBEmbeddings(const Function &F) const {
- assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
-
- if (F.isDeclaration())
- return {};
-
- auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
- if (!Emb)
- return {};
-
- BBVecList Result;
+ /// Get basic block embeddings for the given function
+ BBEmbeddingsMap getBBEmbeddings(const Function &F,
+ const Embedder &Emb) const {
+ BBEmbeddingsMap Result;
for (const BasicBlock &BB : F)
- Result.push_back({&BB, Emb->getBBVector(BB)});
-
+ Result.try_emplace(&BB, Emb.getBBVector(BB));
return Result;
}
- InstVecList getInstEmbeddings(const Function &F) const {
- assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
-
- if (F.isDeclaration())
- return {};
-
- auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
- if (!Emb)
- return {};
-
- InstVecList Result;
+ /// Get instruction embeddings using provided embedder
+ InstEmbeddingsMap getInstEmbeddings(const Function &F,
+ const Embedder &Emb) const {
+ InstEmbeddingsMap Result;
for (const Instruction &I : instructions(F))
- Result.push_back({&I, Emb->getInstVector(I)});
-
+ Result.try_emplace(&I, Emb.getInstVector(I));
return Result;
}
@@ -366,40 +337,62 @@ class IR2VecTool {
return;
}
- for (const Function &F : M)
+ for (const Function &F : M.getFunctionDefs())
generateEmbeddings(F, OS);
}
/// Generate embeddings for a single function
void generateEmbeddings(const Function &F, raw_ostream &OS) const {
- assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+ if (!Vocab || !Vocab->isValid()) {
+ WithColor::error(errs(), ToolName)
+ << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+ return;
+ }
+
if (F.isDeclaration()) {
OS << "Function " << F.getName() << " is a declaration, skipping.\n";
return;
}
+ // Create embedder once for the function
+ auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+ if (!Emb) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to create embedder for " << F.getName() << "\n";
+ return;
+ }
+
OS << "Function: " << F.getName() << "\n";
switch (Level) {
case EmbeddingLevel::FunctionLevel:
- getFunctionEmbedding(F).print(OS);
+ getFunctionEmbedding(*Emb).print(OS);
break;
- case EmbeddingLevel::BasicBlockLevel:
- for (const auto &[BB, embedding] : getBBEmbeddings(F)) {
- OS << BB->getName() << ":";
- embedding.print(OS);
+ case EmbeddingLevel::BasicBlockLevel: {
+ const auto EmbMap = getBBEmbeddings(F, *Emb);
+ for (const BasicBlock &BB : F) {
+ if (auto It = EmbMap.find(&BB); It != EmbMap.end()) {
+ OS << BB.getName() << ":";
+ It->second.print(OS);
+ }
}
break;
- case EmbeddingLevel::InstructionLevel:
- for (const auto &[I, embedding] : getInstEmbeddings(F)) {
- OS << *I;
- embedding.print(OS);
+ }
+ case EmbeddingLevel::InstructionLevel: {
+ const auto EmbMap = getInstEmbeddings(F, *Emb);
+ for (const Instruction &I : instructions(F)) {
+ if (auto It = EmbMap.find(&I); It != EmbMap.end()) {
+ OS << I;
+ It->second.print(OS);
+ }
}
break;
}
+ }
}
};
+/// Process the module and generate output based on selected subcommand
Error processModule(Module &M, raw_ostream &OS) {
IR2VecTool Tool(M);
@@ -431,20 +424,6 @@ Error processModule(Module &M, raw_ostream &OS) {
} // namespace ir2vec
namespace mir2vec {
-/// Entity mappings: [entity_name]
-using EntityList = std::vector<std::string>;
-
-/// Machine basic block embeddings: [{mbb_ptr, Embedding}]
-using MBBVecList =
- std::vector<std::pair<const MachineBasicBlock *, ir2vec::Embedding>>;
-
-/// Machine instruction embeddings: [{minstr_ptr, Embedding}]
-using MInstVecList =
- std::vector<std::pair<const MachineInstr *, ir2vec::Embedding>>;
-
-/// Function embeddings: [Embedding]
-using FuncVecList = std::vector<ir2vec::Embedding>;
-
/// Relation types for MIR2Vec triplet generation
enum MIRRelationType {
MIRNextRelation = 0, ///< Sequential instruction relationship
@@ -577,10 +556,7 @@ class MIR2VecTool {
TripletResult Result;
Result.MaxRelation = MIRNextRelation;
- for (const Function &F : M) {
- if (F.isDeclaration())
- continue;
-
+ for (const Function &F : M.getFunctionDefs()) {
MachineFunction *MF = MMI.getMachineFunction(F);
if (!MF) {
WithColor::warning(errs(), ToolName)
@@ -606,8 +582,7 @@ class MIR2VecTool {
OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
}
- /// Collect entity mappings
- /// Returns EntityList containing all entity strings
+ /// Generate entity mappings for the entire vocabulary
EntityList collectEntityMappings() const {
if (!Vocab) {
WithColor::error(errs(), ToolName)
@@ -642,7 +617,6 @@ class MIR2VecTool {
}
for (const Function &F : M.getFunctionDefs()) {
-
MachineFunction *MF = MMI.getMachineFunction(F);
if (!MF) {
WithColor::warning(errs(), ToolName)
@@ -656,58 +630,31 @@ class MIR2VecTool {
/// Get machine function embedding
/// Returns Embedding for the entire machine function
- ir2vec::Embedding getMFunctionEmbedding(MachineFunction &MF) const {
- if (!Vocab) {
- WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return {};
- }
-
- auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for " << MF.getName() << "\n";
- return {};
- }
- return Emb->getMFunctionVector();
+ ir2vec::Embedding getMFunctionEmbedding(const MIREmbedder &Emb) const {
+ return Emb.getMFunctionVector();
}
- /// Get machine basic block embeddings
- /// Returns MBBVecList containing (name, embedding) pairs for all MBBs
- MBBVecList getMBBEmbeddings(MachineFunction &MF) const {
- if (!Vocab) {
- WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return {};
- }
-
- auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
- if (!Emb)
- return {};
-
- MBBVecList Result;
+ /// Get machine basic block embeddings using provided embedder
+ /// Returns MachineBlockEmbeddingsMap containing MBB pointer to embedding
+ /// mappings
+ MachineBlockEmbeddingsMap getMBBEmbeddings(MachineFunction &MF,
+ const MIREmbedder &Emb) const {
+ MachineBlockEmbeddingsMap Result;
for (const MachineBasicBlock &MBB : MF)
- Result.push_back({&MBB, Emb->getMBBVector(MBB)});
-
+ Result.try_emplace(&MBB, Emb.getMBBVector(MBB));
return Result;
}
- /// Get machine instruction embeddings
- /// Returns MInstVecList containing (instruction_string, embedding) pairs
- MInstVecList getMInstEmbeddings(MachineFunction &MF) const {
- if (!Vocab) {
- WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return {};
- }
-
- auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
- if (!Emb)
- return {};
-
- MInstVecList Result;
+ /// Get machine instruction embeddings using provided embedder
+ /// Returns MachineInstEmbeddingsMap containing MI pointer to embedding
+ /// mappings
+ MachineInstEmbeddingsMap getMInstEmbeddings(MachineFunction &MF,
+ const MIREmbedder &Emb) const {
+ MachineInstEmbeddingsMap Result;
for (const MachineBasicBlock &MBB : MF) {
for (const MachineInstr &MI : MBB)
- Result.push_back({&MI, Emb->getMInstVector(MI)});
+ Result.try_emplace(&MI, Emb.getMInstVector(MI));
}
-
return Result;
}
@@ -718,31 +665,50 @@ class MIR2VecTool {
return;
}
+ // Create embedder once for the machine function
+ auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+ if (!Emb) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to create embedder for " << MF.getName() << "\n";
+ return;
+ }
+
OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
// Generate embeddings based on the specified level
switch (Level) {
case FunctionLevel:
OS << "Function vector: ";
- getMFunctionEmbedding(MF).print(OS);
+ getMFunctionEmbedding(*Emb).print(OS);
break;
- case BasicBlockLevel:
+ case BasicBlockLevel: {
OS << "Basic block vectors:\n";
- for (const auto &[MBB, embedding] : getMBBEmbeddings(MF)) {
- OS << "MBB " << MBB->getName() << ": ";
- embedding.print(OS);
+ const auto MBBEmbMap = getMBBEmbeddings(MF, *Emb);
+ for (const MachineBasicBlock &MBB : MF) {
+ if (auto It = MBBEmbMap.find(&MBB); It != MBBEmbMap.end()) {
+ OS << "MBB " << MBB.getName() << ": ";
+ It->second.print(OS);
+ }
}
break;
- case InstructionLevel:
+ }
+ case InstructionLevel: {
OS << "Instruction vectors:\n";
- for (const auto &[MI, embedding] : getMInstEmbeddings(MF)) {
- OS << *MI << " -> ";
- embedding.print(OS);
+ const auto MInstEmbMap = getMInstEmbeddings(MF, *Emb);
+ for (const MachineBasicBlock &MBB : MF) {
+ for (const MachineInstr &MI : MBB) {
+ if (auto It = MInstEmbMap.find(&MI); It != MInstEmbMap.end()) {
+ OS << MI << " -> ";
+ It->second.print(OS);
+ }
+ }
}
break;
}
+ }
}
+ /// Get the MIR vocabulary instance
const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
};
>From 087a84b52bf14d47b27201c6617413b37844870b Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 10 Dec 2025 14:51:35 +0530
Subject: [PATCH 5/5] Resolving nit comment. To be squashed later
---
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 5b62142652990..f106ee7693a1e 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -201,7 +201,9 @@ class IR2VecTool {
}
/// Generate triplets for a single function
- /// Returns the maximum relation ID used in this function
+ /// Returns a TripletResult with:
+ /// - Triplets: vector of all (Head, Tail, Relation) triplets
+ /// - MaxRelation: highest Arg relation ID used, or NextRelation if none
TripletResult generateTriplets(const Function &F) const {
if (F.isDeclaration())
return {};
More information about the llvm-commits
mailing list