[llvm] Refactoring llvm-ir2vec.cpp for better separation of concerns in the Tooling classes (PR #170078)
Nishant Sachdeva via llvm-commits
llvm-commits at lists.llvm.org
Sun Nov 30 23:02:42 PST 2025
https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/170078
>From 6f57fd49eedeb0a2743c29cbc8d9762add560993 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Fri, 28 Nov 2025 19:27:28 +0530
Subject: [PATCH] Refactoring llvm-ir2vec.cpp for better separation of concerns
in the Tooling classes
---
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 433 +++++++++++++++++--------
1 file changed, 290 insertions(+), 143 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 7402782bfd404..ba94205193495 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -64,6 +64,7 @@
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
+#include "llvm/IR/InstIterator.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -151,6 +152,29 @@ static cl::opt<EmbeddingLevel>
cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
cl::cat(CommonCategory));
+/// Entity mappings: entity_id -> entity_name
+using EntityList = std::vector<std::string>;
+
+/// Basic block embeddings: bb_name -> Embedding
+using BBVecList = std::vector<std::pair<std::string, ir2vec::Embedding>>;
+
+/// Instruction embeddings: instruction_string -> Embedding
+using InstVecList = std::vector<std::pair<std::string, ir2vec::Embedding>>;
+
+/// Function embeddings: [Embedding]
+using FuncVecList = std::vector<ir2vec::Embedding>;
+
+struct Triplet {
+ unsigned Head;
+ unsigned Tail;
+ unsigned Relation;
+};
+
+struct TripletResult {
+ unsigned MaxRelation;
+ std::vector<Triplet> Triplets;
+};
+
namespace ir2vec {
/// Relation types for triplet generation
@@ -182,32 +206,14 @@ class IR2VecTool {
return Vocab->isValid();
}
- /// Generate triplets for the module
- /// Output format: MAX_RELATION=N header followed by relationships
- void generateTriplets(raw_ostream &OS) const {
- unsigned MaxRelation = NextRelation; // Track maximum relation ID
- std::string Relationships;
- raw_string_ostream RelOS(Relationships);
+ TripletResult getTriplets(const Function &F) const {
+ TripletResult Result;
+ Result.MaxRelation = 0;
- for (const Function &F : M) {
- unsigned FuncMaxRelation = generateTriplets(F, RelOS);
- MaxRelation = std::max(MaxRelation, FuncMaxRelation);
- }
-
- RelOS.flush();
-
- // Write metadata header followed by relationships
- OS << "MAX_RELATION=" << MaxRelation << '\n';
- OS << Relationships;
- }
-
- /// Generate triplets for a single function
- /// Returns the maximum relation ID used in this function
- unsigned generateTriplets(const Function &F, raw_ostream &OS) const {
if (F.isDeclaration())
- return 0;
+ return Result;
- unsigned MaxRelation = 1;
+ unsigned MaxRelation = NextRelation;
unsigned PrevOpcode = 0;
bool HasPrevOpcode = false;
@@ -216,56 +222,139 @@ class IR2VecTool {
unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
- // Add "Next" relationship with previous instruction
if (HasPrevOpcode) {
- OS << PrevOpcode << '\t' << Opcode << '\t' << NextRelation << '\n';
+ Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
LLVM_DEBUG(dbgs()
- << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
- << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << "Next\n");
+ << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
+ << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+ << "Next\n");
}
- // Add "Type" relationship
- OS << Opcode << '\t' << TypeID << '\t' << TypeRelation << '\n';
+ Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
LLVM_DEBUG(
dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
- << '\t' << "Type\n");
+ << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
+ << '\t' << "Type\n");
- // Add "Arg" relationships
unsigned ArgIndex = 0;
for (const Use &U : I.operands()) {
unsigned OperandID = Vocabulary::getIndex(*U.get());
unsigned RelationID = ArgRelation + ArgIndex;
- OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
+ Result.Triplets.push_back({Opcode, OperandID, RelationID});
LLVM_DEBUG({
StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
Vocabulary::getOperandKind(U.get()));
dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
+ << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
});
++ArgIndex;
}
- // Only update MaxRelation if there were operands
+
if (ArgIndex > 0) {
MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
}
+
PrevOpcode = Opcode;
HasPrevOpcode = true;
}
}
- return MaxRelation;
+ Result.MaxRelation = MaxRelation;
+ return Result;
}
- /// Dump entity ID to string mappings
- static void generateEntityMappings(raw_ostream &OS) {
+ TripletResult getTriplets() const {
+ TripletResult Result;
+ Result.MaxRelation = NextRelation;
+
+ for (const Function &F : M) {
+ TripletResult FuncResult = getTriplets(F);
+ Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+ Result.Triplets.insert(Result.Triplets.end(),
+ FuncResult.Triplets.begin(),
+ FuncResult.Triplets.end());
+ }
+
+ return Result;
+ }
+
+ /// Generate triplets for the module
+ /// Output format: MAX_RELATION=N header followed by relationships
+ void generateTriplets(raw_ostream &OS) const {
+ auto Result = getTriplets();
+ OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+ for (const auto &T : Result.Triplets) {
+ OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+ }
+ }
+
+ static EntityList collectEntityMappings() {
auto EntityLen = Vocabulary::getCanonicalSize();
- OS << EntityLen << "\n";
+ EntityList Result;
+ Result.reserve(EntityLen);
+
for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
- OS << Vocabulary::getStringKey(EntityID) << '\t' << EntityID << '\n';
+ Result.push_back(Vocabulary::getStringKey(EntityID).str());
+
+ return Result;
+ }
+
+ /// Dump entity ID to string mappings
+ static void generateEntityMappings(raw_ostream &OS) {
+ auto Entities = collectEntityMappings();
+ OS << Entities.size() << "\n";
+ for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+ OS << Entities[EntityID] << '\t' << EntityID << '\n';
+ }
+
+ ir2vec::Embedding getFunctionEmbedding(const Function &F) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+ if (F.isDeclaration())
+ return {};
+
+ auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+ return Emb ? Emb->getFunctionVector() : Embedding{};
+ }
+
+ BBVecList getBBEmbeddings(const Function &F) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+ BBVecList Result;
+
+ if (F.isDeclaration())
+ return Result;
+
+ auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+ if (!Emb)
+ return Result;
+
+ for (const BasicBlock &BB : F)
+ Result.push_back({BB.getName().str(), Emb->getBBVector(BB)});
+
+ return Result;
+ }
+
+ InstVecList getInstEmbeddings(const Function &F) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+ InstVecList Result;
+
+ if (F.isDeclaration())
+ return Result;
+
+ auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+ if (!Emb)
+ return Result;
+
+ for (const Instruction &I : instructions(F)) {
+ std::string InstStr;
+ raw_string_ostream(InstStr) << I;
+ Result.push_back({InstStr, Emb->getInstVector(I)});
+ }
+ return Result;
}
/// Generate embeddings for the entire module
@@ -282,44 +371,31 @@ class IR2VecTool {
/// Generate embeddings for a single function
void generateEmbeddings(const Function &F, raw_ostream &OS) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
if (F.isDeclaration()) {
OS << "Function " << F.getName() << " is a declaration, skipping.\n";
return;
}
- // Create embedder for this function
- assert(Vocab->isValid() && "Vocabulary is not valid");
- auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for function " << F.getName() << "\n";
- return;
- }
-
OS << "Function: " << F.getName() << "\n";
- // Generate embeddings based on the specified level
- switch (Level) {
- case FunctionLevel: {
- Emb->getFunctionVector().print(OS);
- break;
- }
- case BasicBlockLevel: {
- for (const BasicBlock &BB : F) {
- OS << BB.getName() << ":";
- Emb->getBBVector(BB).print(OS);
+ auto printListLevel = [&](const auto& list, const char* suffix_str) {
+ for (const auto& [name, embedding] : list) {
+ OS << name << suffix_str;
+ embedding.print(OS);
}
- break;
- }
- case InstructionLevel: {
- for (const BasicBlock &BB : F) {
- for (const Instruction &I : BB) {
- I.print(OS);
- Emb->getInstVector(I).print(OS);
- }
- }
- break;
- }
+ };
+
+ switch (Level) {
+ case EmbeddingLevel::FunctionLevel:
+ getFunctionEmbedding(F).print(OS);
+ break;
+ case EmbeddingLevel::BasicBlockLevel:
+ printListLevel(getBBEmbeddings(F), ":");
+ break;
+ case EmbeddingLevel::InstructionLevel:
+ printListLevel(getInstEmbeddings(F), "");
+ break;
}
}
};
@@ -423,49 +499,22 @@ class MIR2VecTool {
<< "No machine functions found to initialize vocabulary\n";
return false;
}
-
- /// Generate triplets for the module
- /// Output format: MAX_RELATION=N header followed by relationships
- void generateTriplets(const Module &M, raw_ostream &OS) const {
- unsigned MaxRelation = MIRNextRelation; // Track maximum relation ID
- std::string Relationships;
- raw_string_ostream RelOS(Relationships);
-
- for (const Function &F : M) {
- if (F.isDeclaration())
- continue;
-
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF) {
- WithColor::warning(errs(), ToolName)
- << "No MachineFunction for " << F.getName() << "\n";
- continue;
- }
-
- unsigned FuncMaxRelation = generateTriplets(*MF, RelOS);
- MaxRelation = std::max(MaxRelation, FuncMaxRelation);
- }
-
- RelOS.flush();
-
- // Write metadata header followed by relationships
- OS << "MAX_RELATION=" << MaxRelation << '\n';
- OS << Relationships;
- }
-
- /// Generate triplets for a single machine function
- /// Returns the maximum relation ID used in this function
- unsigned generateTriplets(const MachineFunction &MF, raw_ostream &OS) const {
- unsigned MaxRelation = MIRNextRelation;
- unsigned PrevOpcode = 0;
- bool HasPrevOpcode = false;
+
+ /// Get triplets for a single machine function
+ /// Returns TripletResult containing MaxRelation and vector of Triplets
+ TripletResult getTriplets(const MachineFunction &MF) const {
+ TripletResult Result;
+ Result.MaxRelation = MIRNextRelation;
if (!Vocab) {
WithColor::error(errs(), ToolName)
<< "MIR Vocabulary must be initialized for triplet generation.\n";
- return MaxRelation;
+ return Result;
}
+ unsigned PrevOpcode = 0;
+ bool HasPrevOpcode = false;
+
for (const MachineBasicBlock &MBB : MF) {
for (const MachineInstr &MI : MBB) {
// Skip debug instructions
@@ -477,8 +526,7 @@ class MIR2VecTool {
// Add "Next" relationship with previous instruction
if (HasPrevOpcode) {
- OS << PrevOpcode << '\t' << OpcodeID << '\t' << MIRNextRelation
- << '\n';
+ Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
LLVM_DEBUG(dbgs()
<< Vocab->getStringKey(PrevOpcode) << '\t'
<< Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
@@ -489,7 +537,7 @@ class MIR2VecTool {
for (const MachineOperand &MO : MI.operands()) {
auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
unsigned RelationID = MIRArgRelation + ArgIndex;
- OS << OpcodeID << '\t' << OperandID << '\t' << RelationID << '\n';
+ Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
LLVM_DEBUG({
std::string OperandStr = Vocab->getStringKey(OperandID);
dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
@@ -501,28 +549,82 @@ class MIR2VecTool {
// Update MaxRelation if there were operands
if (ArgIndex > 0)
- MaxRelation = std::max(MaxRelation, MIRArgRelation + ArgIndex - 1);
+ Result.MaxRelation = std::max(Result.MaxRelation,
+ MIRArgRelation + ArgIndex - 1);
PrevOpcode = OpcodeID;
HasPrevOpcode = true;
}
}
- return MaxRelation;
+ return Result;
}
+
+ /// Get triplets for the entire module
+ /// Returns TripletResult containing aggregated MaxRelation and all Triplets
+ TripletResult getTriplets(const Module &M) const {
+ TripletResult Result;
+ Result.MaxRelation = MIRNextRelation;
- /// Generate entity mappings with vocabulary
- void generateEntityMappings(raw_ostream &OS) const {
+ for (const Function &F : M) {
+ if (F.isDeclaration())
+ continue;
+
+ MachineFunction *MF = MMI.getMachineFunction(F);
+ if (!MF) {
+ WithColor::warning(errs(), ToolName)
+ << "No MachineFunction for " << F.getName() << "\n";
+ continue;
+ }
+
+ TripletResult FuncResult = getTriplets(*MF);
+ Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+ Result.Triplets.insert(Result.Triplets.end(),
+ FuncResult.Triplets.begin(),
+ FuncResult.Triplets.end());
+ }
+
+ return Result;
+ }
+
+ /// Generate triplets for the module and write to output stream
+ /// Output format: MAX_RELATION=N header followed by relationships
+ void generateTriplets(const Module &M, raw_ostream &OS) const {
+ auto Result = getTriplets(M);
+ OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+ for (const auto &T : Result.Triplets) {
+ OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+ }
+ }
+
+ /// Collect entity mappings
+ /// Returns EntityList containing all entity strings
+ EntityList collectEntityMappings() const {
if (!Vocab) {
WithColor::error(errs(), ToolName)
<< "Vocabulary must be initialized for entity mappings.\n";
- return;
+ return {};
}
const unsigned EntityCount = Vocab->getCanonicalSize();
- OS << EntityCount << "\n";
+ EntityList Result;
+ Result.reserve(EntityCount);
+
for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
- OS << Vocab->getStringKey(EntityID) << '\t' << EntityID << '\n';
+ Result.push_back(Vocab->getStringKey(EntityID));
+
+ return Result;
+ }
+
+ /// Generate entity mappings and write to output stream
+ void generateEntityMappings(raw_ostream &OS) const {
+ auto Entities = collectEntityMappings();
+ if (Entities.empty())
+ return;
+
+ OS << Entities.size() << "\n";
+ for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+ OS << Entities[EntityID] << '\t' << EntityID << '\n';
}
/// Generate embeddings for all machine functions in the module
@@ -547,48 +649,93 @@ class MIR2VecTool {
}
}
- /// Generate embeddings for a specific machine function
- void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const {
+ /// Get machine function embedding
+ /// Returns Embedding for the entire machine function
+ ir2vec::Embedding getMFunctionEmbedding(MachineFunction &MF) const {
if (!Vocab) {
WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return;
+ return {};
}
auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for " << MF.getName() << "\n";
+ return Emb ? Emb->getMFunctionVector() : ir2vec::Embedding{};
+ }
+
+ /// Get machine basic block embeddings
+ /// Returns BBVecList containing (name, embedding) pairs for all MBBs
+ BBVecList getMBBEmbeddings(MachineFunction &MF) const {
+ BBVecList Result;
+
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+ return Result;
+ }
+
+ auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+ if (!Emb)
+ return Result;
+
+ for (const MachineBasicBlock &MBB : MF)
+ Result.push_back({MBB.getName().str(), Emb->getMBBVector(MBB)});
+
+ return Result;
+ }
+
+ /// Get machine instruction embeddings
+ /// Returns InstVecList containing (instruction_string, embedding) pairs
+ InstVecList getMInstEmbeddings(MachineFunction &MF) const {
+ InstVecList Result;
+
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+ return Result;
+ }
+
+ auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+ if (!Emb)
+ return Result;
+
+ for (const MachineBasicBlock &MBB : MF) {
+ for (const MachineInstr &MI : MBB) {
+ std::string InstStr;
+ raw_string_ostream(InstStr) << MI;
+ Result.push_back({InstStr, Emb->getMInstVector(MI)});
+ }
+ }
+
+ return Result;
+ }
+
+ /// Generate embeddings for a specific machine function
+ void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const {
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
return;
}
OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+ auto printListLevel = [&](const auto& list, const char* label, const char* prefix_str, const char* suffix_str) {
+ OS << label << ":\n";
+ for (const auto& [name, embedding] : list) {
+ OS << prefix_str << name << suffix_str;
+ embedding.print(OS);
+ }
+ };
+
// Generate embeddings based on the specified level
switch (Level) {
- case FunctionLevel: {
+ case FunctionLevel:
OS << "Function vector: ";
- Emb->getMFunctionVector().print(OS);
+ getMFunctionEmbedding(MF).print(OS);
break;
- }
- case BasicBlockLevel: {
- OS << "Basic block vectors:\n";
- for (const MachineBasicBlock &MBB : MF) {
- OS << "MBB " << MBB.getName() << ": ";
- Emb->getMBBVector(MBB).print(OS);
- }
+ case BasicBlockLevel:
+ printListLevel(getMBBEmbeddings(MF), "Basic block vectors", "MBB ", ": ");
break;
- }
- case InstructionLevel: {
- OS << "Instruction vectors:\n";
- for (const MachineBasicBlock &MBB : MF) {
- for (const MachineInstr &MI : MBB) {
- OS << MI << " -> ";
- Emb->getMInstVector(MI).print(OS);
- }
- }
+ case InstructionLevel:
+ printListLevel(getMInstEmbeddings(MF), "Instruction vectors", "", " -> ");
break;
}
- }
}
const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
More information about the llvm-commits
mailing list