[llvm] Refactoring llvm-ir2vec.cpp for better separation of concerns in the Tooling classes (PR #170078)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Nov 30 23:01:27 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlgo
Author: Nishant Sachdeva (nishant-sachdeva)
<details>
<summary>Changes</summary>
This patch refactors the IR2VecTool and MIR2VecTool classes in llvm-ir2vec.cpp to separate data-structure generation from output streaming. It is the first step toward creating Python bindings for the IR2Vec and MIR2Vec functionality, as requested in [llvm/llvm-project#141839](https://github.com/llvm/llvm-project/issues/141839).
The current implementation tightly couples data generation with output formatting. Functions like `generateTriplets()` and `generateEmbeddings()` both compute results and write them directly to output streams. This design makes it difficult to:
- Reuse the generated data structures in other contexts (e.g., Python bindings)
- Test data generation independently from output formatting
- Support alternative output formats without duplicating logic
The upcoming Python bindings need access to the underlying data structures (triplets, embeddings, entity mappings) without first serializing them to text. This refactoring enables that use case and also improves code modularity.
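All existing tests pass without modification, which indicates that the tool's textual output is unchanged by this refactoring. One behavioral difference remains: when an embedder cannot be created, `generateEmbeddings()` no longer reports a "Failed to create embedder" error. The new getter functions return an empty result instead.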
---
Full diff: https://github.com/llvm/llvm-project/pull/170078.diff
1 File Affected:
- (modified) llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp (+290-143)
``````````diff
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 7402782bfd404..75cac65c34f90 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -64,6 +64,7 @@
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
+#include "llvm/IR/InstIterator.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -151,6 +152,29 @@ static cl::opt<EmbeddingLevel>
cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
cl::cat(CommonCategory));
+/// Entity mappings: the vector index is the entity ID, the value its string key
+using EntityList = std::vector<std::string>;
+
+/// Basic block embeddings: bb_name -> Embedding
+using BBVecList = std::vector<std::pair<std::string, ir2vec::Embedding>>;
+
+/// Instruction embeddings: instruction_string -> Embedding
+using InstVecList = std::vector<std::pair<std::string, ir2vec::Embedding>>;
+
+/// Function embeddings: a plain list of Embeddings (no function names attached)
+using FuncVecList = std::vector<ir2vec::Embedding>;
+
+struct Triplet {
+ unsigned Head;     // Source entity ID: the instruction's opcode index
+ unsigned Tail;     // Target entity ID: next opcode, type ID, or operand ID
+ unsigned Relation; // Relation ID: Type, Next, or Arg base + operand index
+};
+
+struct TripletResult {
+ unsigned MaxRelation;          // Upper bound on Relation IDs in Triplets
+ std::vector<Triplet> Triplets; // Triplets in the order they were generated
+};
+
namespace ir2vec {
/// Relation types for triplet generation
@@ -182,32 +206,14 @@ class IR2VecTool {
return Vocab->isValid();
}
- /// Generate triplets for the module
- /// Output format: MAX_RELATION=N header followed by relationships
- void generateTriplets(raw_ostream &OS) const {
- unsigned MaxRelation = NextRelation; // Track maximum relation ID
- std::string Relationships;
- raw_string_ostream RelOS(Relationships);
+ TripletResult getTriplets(const Function &F) const {
+ TripletResult Result;
+ Result.MaxRelation = 0;
- for (const Function &F : M) {
- unsigned FuncMaxRelation = generateTriplets(F, RelOS);
- MaxRelation = std::max(MaxRelation, FuncMaxRelation);
- }
-
- RelOS.flush();
-
- // Write metadata header followed by relationships
- OS << "MAX_RELATION=" << MaxRelation << '\n';
- OS << Relationships;
- }
-
- /// Generate triplets for a single function
- /// Returns the maximum relation ID used in this function
- unsigned generateTriplets(const Function &F, raw_ostream &OS) const {
if (F.isDeclaration())
- return 0;
+ return Result;
- unsigned MaxRelation = 1;
+ unsigned MaxRelation = NextRelation;
unsigned PrevOpcode = 0;
bool HasPrevOpcode = false;
@@ -216,56 +222,139 @@ class IR2VecTool {
unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
- // Add "Next" relationship with previous instruction
if (HasPrevOpcode) {
- OS << PrevOpcode << '\t' << Opcode << '\t' << NextRelation << '\n';
+ Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
LLVM_DEBUG(dbgs()
- << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
- << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << "Next\n");
+ << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
+ << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+ << "Next\n");
}
- // Add "Type" relationship
- OS << Opcode << '\t' << TypeID << '\t' << TypeRelation << '\n';
+ Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
LLVM_DEBUG(
dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
- << '\t' << "Type\n");
+ << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
+ << '\t' << "Type\n");
- // Add "Arg" relationships
unsigned ArgIndex = 0;
for (const Use &U : I.operands()) {
unsigned OperandID = Vocabulary::getIndex(*U.get());
unsigned RelationID = ArgRelation + ArgIndex;
- OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
+ Result.Triplets.push_back({Opcode, OperandID, RelationID});
LLVM_DEBUG({
StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
Vocabulary::getOperandKind(U.get()));
dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
+ << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
});
++ArgIndex;
}
- // Only update MaxRelation if there were operands
+
if (ArgIndex > 0) {
MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
}
+
PrevOpcode = Opcode;
HasPrevOpcode = true;
}
}
- return MaxRelation;
+ Result.MaxRelation = MaxRelation;
+ return Result;
}
- /// Dump entity ID to string mappings
- static void generateEntityMappings(raw_ostream &OS) {
+ TripletResult getTriplets() const {
+ TripletResult Result;
+ Result.MaxRelation = NextRelation;
+
+ for (const Function &F : M) {
+ TripletResult FuncResult = getTriplets(F);
+ Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+ Result.Triplets.insert(Result.Triplets.end(),
+ FuncResult.Triplets.begin(),
+ FuncResult.Triplets.end());
+ }
+
+ return Result;
+ }
+
+ /// Generate triplets for the module
+ /// Output format: MAX_RELATION=N header followed by relationships
+ void generateTriplets(raw_ostream &OS) const {
+ auto Result = getTriplets();
+ OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+ for (const auto &T : Result.Triplets) {
+ OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+ }
+ }
+
+ static EntityList collectEntityMappings() {
auto EntityLen = Vocabulary::getCanonicalSize();
- OS << EntityLen << "\n";
+ EntityList Result;
+ Result.reserve(EntityLen);
+
for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
- OS << Vocabulary::getStringKey(EntityID) << '\t' << EntityID << '\n';
+ Result.push_back(Vocabulary::getStringKey(EntityID).str());
+
+ return Result;
+ }
+
+ /// Dump entity ID to string mappings
+ static void generateEntityMappings(raw_ostream &OS) {
+ auto Entities = collectEntityMappings();
+ OS << Entities.size() << "\n";
+ for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+ OS << Entities[EntityID] << '\t' << EntityID << '\n';
+ }
+
+ ir2vec::Embedding getFunctionEmbedding(const Function &F) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+ if (F.isDeclaration())
+ return {};
+
+ auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+ return Emb ? Emb->getFunctionVector() : Embedding{};
+ }
+
+ BBVecList getBBEmbeddings(const Function &F) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+ BBVecList Result;
+
+ if (F.isDeclaration())
+ return Result;
+
+ auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+ if (!Emb)
+ return Result;
+
+ for (const BasicBlock &BB : F)
+ Result.push_back({BB.getName().str(), Emb->getBBVector(BB)});
+
+ return Result;
+ }
+
+ InstVecList getInstEmbeddings(const Function &F) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+ InstVecList Result;
+
+ if (F.isDeclaration())
+ return Result;
+
+ auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+ if (!Emb)
+ return Result;
+
+ for (const Instruction &I : instructions(F)) {
+ std::string InstStr;
+ raw_string_ostream(InstStr) << I;
+ Result.push_back({InstStr, Emb->getInstVector(I)});
+ }
+ return Result;
}
/// Generate embeddings for the entire module
@@ -282,44 +371,31 @@ class IR2VecTool {
/// Generate embeddings for a single function
void generateEmbeddings(const Function &F, raw_ostream &OS) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
if (F.isDeclaration()) {
OS << "Function " << F.getName() << " is a declaration, skipping.\n";
return;
}
- // Create embedder for this function
- assert(Vocab->isValid() && "Vocabulary is not valid");
- auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for function " << F.getName() << "\n";
- return;
- }
-
OS << "Function: " << F.getName() << "\n";
- // Generate embeddings based on the specified level
- switch (Level) {
- case FunctionLevel: {
- Emb->getFunctionVector().print(OS);
- break;
- }
- case BasicBlockLevel: {
- for (const BasicBlock &BB : F) {
- OS << BB.getName() << ":";
- Emb->getBBVector(BB).print(OS);
+ auto printListLevel = [&](const auto& list, const char* suffix_str) {
+ for (const auto& [name, embedding] : list) {
+ OS << name << suffix_str;
+ embedding.print(OS);
}
- break;
- }
- case InstructionLevel: {
- for (const BasicBlock &BB : F) {
- for (const Instruction &I : BB) {
- I.print(OS);
- Emb->getInstVector(I).print(OS);
- }
- }
- break;
- }
+ };
+
+ switch (Level) {
+ case EmbeddingLevel::FunctionLevel:
+ getFunctionEmbedding(F).print(OS);
+ break;
+ case EmbeddingLevel::BasicBlockLevel:
+ printListLevel(getBBEmbeddings(F), ":");
+ break;
+ case EmbeddingLevel::InstructionLevel:
+ printListLevel(getInstEmbeddings(F), "");
+ break;
}
}
};
@@ -423,49 +499,22 @@ class MIR2VecTool {
<< "No machine functions found to initialize vocabulary\n";
return false;
}
-
- /// Generate triplets for the module
- /// Output format: MAX_RELATION=N header followed by relationships
- void generateTriplets(const Module &M, raw_ostream &OS) const {
- unsigned MaxRelation = MIRNextRelation; // Track maximum relation ID
- std::string Relationships;
- raw_string_ostream RelOS(Relationships);
-
- for (const Function &F : M) {
- if (F.isDeclaration())
- continue;
-
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF) {
- WithColor::warning(errs(), ToolName)
- << "No MachineFunction for " << F.getName() << "\n";
- continue;
- }
-
- unsigned FuncMaxRelation = generateTriplets(*MF, RelOS);
- MaxRelation = std::max(MaxRelation, FuncMaxRelation);
- }
-
- RelOS.flush();
-
- // Write metadata header followed by relationships
- OS << "MAX_RELATION=" << MaxRelation << '\n';
- OS << Relationships;
- }
-
- /// Generate triplets for a single machine function
- /// Returns the maximum relation ID used in this function
- unsigned generateTriplets(const MachineFunction &MF, raw_ostream &OS) const {
- unsigned MaxRelation = MIRNextRelation;
- unsigned PrevOpcode = 0;
- bool HasPrevOpcode = false;
+
+ /// Get triplets for a single machine function
+ /// Returns TripletResult containing MaxRelation and vector of Triplets
+ TripletResult getTriplets(const MachineFunction &MF) const {
+ TripletResult Result;
+ Result.MaxRelation = MIRNextRelation;
if (!Vocab) {
WithColor::error(errs(), ToolName)
<< "MIR Vocabulary must be initialized for triplet generation.\n";
- return MaxRelation;
+ return Result;
}
+ unsigned PrevOpcode = 0;
+ bool HasPrevOpcode = false;
+
for (const MachineBasicBlock &MBB : MF) {
for (const MachineInstr &MI : MBB) {
// Skip debug instructions
@@ -477,8 +526,7 @@ class MIR2VecTool {
// Add "Next" relationship with previous instruction
if (HasPrevOpcode) {
- OS << PrevOpcode << '\t' << OpcodeID << '\t' << MIRNextRelation
- << '\n';
+ Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
LLVM_DEBUG(dbgs()
<< Vocab->getStringKey(PrevOpcode) << '\t'
<< Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
@@ -489,7 +537,7 @@ class MIR2VecTool {
for (const MachineOperand &MO : MI.operands()) {
auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
unsigned RelationID = MIRArgRelation + ArgIndex;
- OS << OpcodeID << '\t' << OperandID << '\t' << RelationID << '\n';
+ Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
LLVM_DEBUG({
std::string OperandStr = Vocab->getStringKey(OperandID);
dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
@@ -501,28 +549,82 @@ class MIR2VecTool {
// Update MaxRelation if there were operands
if (ArgIndex > 0)
- MaxRelation = std::max(MaxRelation, MIRArgRelation + ArgIndex - 1);
+ Result.MaxRelation = std::max(Result.MaxRelation,
+ MIRArgRelation + ArgIndex - 1);
PrevOpcode = OpcodeID;
HasPrevOpcode = true;
}
}
- return MaxRelation;
+ return Result;
}
+
+ /// Get triplets for the entire module
+ /// Returns TripletResult containing aggregated MaxRelation and all Triplets
+ TripletResult getTriplets(const Module &M) const {
+ TripletResult Result;
+ Result.MaxRelation = MIRNextRelation;
- /// Generate entity mappings with vocabulary
- void generateEntityMappings(raw_ostream &OS) const {
+ for (const Function &F : M) {
+ if (F.isDeclaration())
+ continue;
+
+ MachineFunction *MF = MMI.getMachineFunction(F);
+ if (!MF) {
+ WithColor::warning(errs(), ToolName)
+ << "No MachineFunction for " << F.getName() << "\n";
+ continue;
+ }
+
+ TripletResult FuncResult = getTriplets(*MF);
+ Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+ Result.Triplets.insert(Result.Triplets.end(),
+ FuncResult.Triplets.begin(),
+ FuncResult.Triplets.end());
+ }
+
+ return Result;
+ }
+
+ /// Generate triplets for the module and write to output stream
+ /// Output format: MAX_RELATION=N header followed by relationships
+ void generateTriplets(const Module &M, raw_ostream &OS) const {
+ auto Result = getTriplets(M);
+ OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+ for (const auto &T : Result.Triplets) {
+ OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+ }
+ }
+
+ /// Collect entity mappings
+ /// Returns EntityList containing all entity strings
+ EntityList collectEntityMappings() const {
if (!Vocab) {
WithColor::error(errs(), ToolName)
<< "Vocabulary must be initialized for entity mappings.\n";
- return;
+ return {};
}
const unsigned EntityCount = Vocab->getCanonicalSize();
- OS << EntityCount << "\n";
+ EntityList Result;
+ Result.reserve(EntityCount);
+
for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
- OS << Vocab->getStringKey(EntityID) << '\t' << EntityID << '\n';
+ Result.push_back(Vocab->getStringKey(EntityID));
+
+ return Result;
+ }
+
+ /// Generate entity mappings and write to output stream
+ void generateEntityMappings(raw_ostream &OS) const {
+ auto Entities = collectEntityMappings();
+ if (Entities.empty())
+ return;
+
+ OS << Entities.size() << "\n";
+ for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+ OS << Entities[EntityID] << '\t' << EntityID << '\n';
}
/// Generate embeddings for all machine functions in the module
@@ -547,48 +649,93 @@ class MIR2VecTool {
}
}
- /// Generate embeddings for a specific machine function
- void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const {
+ /// Get machine function embedding
+ /// Returns Embedding for the entire machine function
+ ir2vec::Embedding getMFunctionEmbedding(MachineFunction &MF) const {
if (!Vocab) {
WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return;
+ return {};
}
auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for " << MF.getName() << "\n";
+ return Emb ? Emb->getMFunctionVector() : ir2vec::Embedding{};
+ }
+
+ /// Get machine basic block embeddings
+ /// Returns BBVecList containing (name, embedding) pairs for all MBBs
+ BBVecList getMBBEmbeddings(MachineFunction &MF) const {
+ BBVecList Result;
+
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+ return Result;
+ }
+
+ auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+ if (!Emb)
+ return Result;
+
+ for (const MachineBasicBlock &MBB : MF)
+ Result.push_back({MBB.getName().str(), Emb->getMBBVector(MBB)});
+
+ return Result;
+ }
+
+ /// Get machine instruction embeddings
+ /// Returns InstVecList containing (instruction_string, embedding) pairs
+ InstVecList getMInstEmbeddings(MachineFunction &MF) const {
+ InstVecList Result;
+
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+ return Result;
+ }
+
+ auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+ if (!Emb)
+ return Result;
+
+ for (const MachineBasicBlock &MBB : MF) {
+ for (const MachineInstr &MI : MBB) {
+ std::string InstStr;
+ raw_string_ostream(InstStr) << MI;
+ Result.push_back({InstStr, Emb->getMInstVector(MI)});
+ }
+ }
+
+ return Result;
+ }
+
+ /// Generate embeddings for a specific machine function
+ void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const {
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
return;
}
OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+ auto printListLevel = [&](const auto& list, const char* label, const char* prefix_str, const char* suffix_str) {
+ OS << label << ":\n";
+ for (const auto& [name, embedding] : list) {
+ OS << prefix_str << name << suffix_str;
+ embedding.print(OS);
+ }
+ };
+
// Generate embeddings based on the specified level
switch (Level) {
- case FunctionLevel: {
+ case FunctionLevel:
OS << "Function vector: ";
- Emb->getMFunctionVector().print(OS);
+ getMFunctionEmbedding(MF).print(OS);
break;
- }
- case BasicBlockLevel: {
- OS << "Basic block vectors:\n";
- for (const MachineBasicBlock &MBB : MF) {
- OS << "MBB " << MBB.getName() << ": ";
- Emb->getMBBVector(MBB).print(OS);
- }
+ case BasicBlockLevel:
+ printListLevel(getMBBEmbeddings(MF), "Basic block vectors", "MBB ", ": ");
break;
- }
- case InstructionLevel: {
- OS << "Instruction vectors:\n";
- for (const MachineBasicBlock &MBB : MF) {
- for (const MachineInstr &MI : MBB) {
- OS << MI << " -> ";
- Emb->getMInstVector(MI).print(OS);
- }
- }
+ case InstructionLevel:
+ printListLevel(getMInstEmbeddings(MF), "Instruction vectors", "", " -> ");
break;
}
- }
}
const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
``````````
</details>
https://github.com/llvm/llvm-project/pull/170078
More information about the llvm-commits mailing list