[llvm] Refactoring llvm-ir2vec.cpp for better separation of concerns in the Tooling classes (PR #170078)
Nishant Sachdeva via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 12 01:09:07 PST 2025
https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/170078
From 834a74d72331e4dd97da84ad703be6b165cb0ae9 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Fri, 28 Nov 2025 19:27:28 +0530
Subject: [PATCH] Refactoring llvm-ir2vec.cpp for better separation of concerns
in the Tooling classes
---
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 230 ++++++++++++++++---------
1 file changed, 146 insertions(+), 84 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 7b8d3f093a3d1..4cb8742c7654d 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -58,6 +58,7 @@
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
@@ -151,6 +152,24 @@ static cl::opt<EmbeddingLevel>
cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
cl::cat(CommonCategory));
+/// A knowledge graph triplet. Note the field order is (Head, Tail, Relation),
+/// which is the order used for brace-initialization and for output columns.
+struct Triplet {
+ unsigned Head = 0; ///< Index of the head entity in the entity list
+ unsigned Tail = 0; ///< Index of the tail entity in the entity list
+ unsigned Relation = 0; ///< Relation ID (e.g. Next, Type, or Arg + N)
+};
+
+/// Triplets generated for a function or module, plus the relation-ID bound.
+struct TripletResult {
+ /// Highest relation index used (ArgRelation + N); printed as MAX_RELATION.
+ unsigned MaxRelation = 0;
+ std::vector<Triplet> Triplets; ///< Collection of all generated triplets
+};
+
+/// Entity names indexed by entity ID: element I is the name of entity I.
+using EntityList = std::vector<std::string>;
+
namespace ir2vec {
/// Relation types for triplet generation
@@ -182,32 +201,18 @@ class IR2VecTool {
return Vocab->isValid();
}
- /// Generate triplets for the module
- /// Output format: MAX_RELATION=N header followed by relationships
- void generateTriplets(raw_ostream &OS) const {
- unsigned MaxRelation = NextRelation; // Track maximum relation ID
- std::string Relationships;
- raw_string_ostream RelOS(Relationships);
-
- for (const Function &F : M) {
- unsigned FuncMaxRelation = generateTriplets(F, RelOS);
- MaxRelation = std::max(MaxRelation, FuncMaxRelation);
- }
-
- RelOS.flush();
-
- // Write metadata header followed by relationships
- OS << "MAX_RELATION=" << MaxRelation << '\n';
- OS << Relationships;
- }
-
/// Generate triplets for a single function
- /// Returns the maximum relation ID used in this function
- unsigned generateTriplets(const Function &F, raw_ostream &OS) const {
+ /// Returns a TripletResult with:
+ /// - Triplets: vector of all (subject, object, relation) tuples
+ /// - MaxRelation: highest Arg relation ID used, or NextRelation if none
+ TripletResult generateTriplets(const Function &F) const {
if (F.isDeclaration())
- return 0;
+ return {};
- unsigned MaxRelation = 1;
+ TripletResult Result;
+ Result.MaxRelation = 0;
+
+ unsigned MaxRelation = NextRelation;
unsigned PrevOpcode = 0;
bool HasPrevOpcode = false;
@@ -218,7 +223,7 @@ class IR2VecTool {
// Add "Next" relationship with previous instruction
if (HasPrevOpcode) {
- OS << PrevOpcode << '\t' << Opcode << '\t' << NextRelation << '\n';
+ Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
LLVM_DEBUG(dbgs()
<< Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
<< Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
@@ -226,7 +231,7 @@ class IR2VecTool {
}
// Add "Type" relationship
- OS << Opcode << '\t' << TypeID << '\t' << TypeRelation << '\n';
+ Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
LLVM_DEBUG(
dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
<< Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
@@ -237,7 +242,7 @@ class IR2VecTool {
for (const Use &U : I.operands()) {
unsigned OperandID = Vocabulary::getIndex(*U.get());
unsigned RelationID = ArgRelation + ArgIndex;
- OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
+ Result.Triplets.push_back({Opcode, OperandID, RelationID});
LLVM_DEBUG({
StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
@@ -249,23 +254,57 @@ class IR2VecTool {
++ArgIndex;
}
// Only update MaxRelation if there were operands
- if (ArgIndex > 0) {
+ if (ArgIndex > 0)
MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
- }
PrevOpcode = Opcode;
HasPrevOpcode = true;
}
}
- return MaxRelation;
+ Result.MaxRelation = MaxRelation;
+ return Result;
}
- /// Dump entity ID to string mappings
- static void generateEntityMappings(raw_ostream &OS) {
+ /// Collect triplets for every function definition (declarations skipped).
+ TripletResult generateTriplets() const {
+ TripletResult Result;
+ Result.MaxRelation = NextRelation;
+
+ for (const Function &F : M.getFunctionDefs()) {
+ TripletResult FuncResult = generateTriplets(F);
+ Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+ Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+ FuncResult.Triplets.end());
+ }
+
+ return Result;
+ }
+
+ /// Write module triplets to OS: a "MAX_RELATION=N" header, then one
+ /// "Head<TAB>Tail<TAB>Relation" line per triplet.
+ void generateTriplets(raw_ostream &OS) const {
+ auto Result = generateTriplets();
+ OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+ for (const auto &T : Result.Triplets)
+ OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+ }
+
+ /// Collect the string key of every canonical vocabulary entry; the
+ /// returned EntityList is indexed by entity ID.
+ static EntityList collectEntityMappings() {
auto EntityLen = Vocabulary::getCanonicalSize();
- OS << EntityLen << "\n";
+ EntityList Result;
for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
- OS << Vocabulary::getStringKey(EntityID) << '\t' << EntityID << '\n';
+ Result.push_back(Vocabulary::getStringKey(EntityID).str());
+ return Result;
+ }
+
+ /// Write the entity count, then one "name<TAB>ID" line per entity, to OS.
+ static void generateEntityMappings(raw_ostream &OS) {
+ auto Entities = collectEntityMappings();
+ OS << Entities.size() << "\n";
+ for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+ OS << Entities[EntityID] << '\t' << EntityID << '\n';
}
/// Generate embeddings for the entire module
@@ -276,19 +315,23 @@ class IR2VecTool {
return;
}
- for (const Function &F : M)
+ for (const Function &F : M.getFunctionDefs())
generateEmbeddings(F, OS);
}
/// Generate embeddings for a single function
void generateEmbeddings(const Function &F, raw_ostream &OS) const {
+ if (!Vocab || !Vocab->isValid()) {
+ WithColor::error(errs(), ToolName)
+ << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+ return;
+ }
if (F.isDeclaration()) {
OS << "Function " << F.getName() << " is a declaration, skipping.\n";
return;
}
// Create embedder for this function
- assert(Vocab->isValid() && "Vocabulary is not valid");
auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
if (!Emb) {
WithColor::error(errs(), ToolName)
@@ -312,11 +355,9 @@ class IR2VecTool {
break;
}
case InstructionLevel: {
- for (const BasicBlock &BB : F) {
- for (const Instruction &I : BB) {
- I.print(OS);
- Emb->getInstVector(I).print(OS);
- }
+ for (const Instruction &I : instructions(F)) {
+ OS << I;
+ Emb->getInstVector(I).print(OS);
}
break;
}
@@ -324,6 +365,7 @@ class IR2VecTool {
}
};
+/// Process the module and generate output based on selected subcommand
Error processModule(Module &M, raw_ostream &OS) {
IR2VecTool Tool(M);
@@ -422,46 +464,20 @@ class MIR2VecTool {
return false;
}
- /// Generate triplets for the module
- /// Output format: MAX_RELATION=N header followed by relationships
- void generateTriplets(const Module &M, raw_ostream &OS) const {
- unsigned MaxRelation = MIRNextRelation; // Track maximum relation ID
- std::string Relationships;
- raw_string_ostream RelOS(Relationships);
-
- for (const Function &F : M.getFunctionDefs()) {
-
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF) {
- WithColor::warning(errs(), ToolName)
- << "No MachineFunction for " << F.getName() << "\n";
- continue;
- }
-
- unsigned FuncMaxRelation = generateTriplets(*MF, RelOS);
- MaxRelation = std::max(MaxRelation, FuncMaxRelation);
- }
-
- RelOS.flush();
-
- // Write metadata header followed by relationships
- OS << "MAX_RELATION=" << MaxRelation << '\n';
- OS << Relationships;
- }
-
- /// Generate triplets for a single machine function
- /// Returns the maximum relation ID used in this function
- unsigned generateTriplets(const MachineFunction &MF, raw_ostream &OS) const {
- unsigned MaxRelation = MIRNextRelation;
- unsigned PrevOpcode = 0;
- bool HasPrevOpcode = false;
+ /// Get triplets for a single machine function
+ /// Returns TripletResult containing MaxRelation and vector of Triplets
+ TripletResult generateTriplets(const MachineFunction &MF) const {
+ TripletResult Result;
+ Result.MaxRelation = MIRNextRelation;
if (!Vocab) {
WithColor::error(errs(), ToolName)
<< "MIR Vocabulary must be initialized for triplet generation.\n";
- return MaxRelation;
+ return Result;
}
+ unsigned PrevOpcode = 0;
+ bool HasPrevOpcode = false;
for (const MachineBasicBlock &MBB : MF) {
for (const MachineInstr &MI : MBB) {
// Skip debug instructions
@@ -473,8 +489,7 @@ class MIR2VecTool {
// Add "Next" relationship with previous instruction
if (HasPrevOpcode) {
- OS << PrevOpcode << '\t' << OpcodeID << '\t' << MIRNextRelation
- << '\n';
+ Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
LLVM_DEBUG(dbgs()
<< Vocab->getStringKey(PrevOpcode) << '\t'
<< Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
@@ -485,7 +500,7 @@ class MIR2VecTool {
for (const MachineOperand &MO : MI.operands()) {
auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
unsigned RelationID = MIRArgRelation + ArgIndex;
- OS << OpcodeID << '\t' << OperandID << '\t' << RelationID << '\n';
+ Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
LLVM_DEBUG({
std::string OperandStr = Vocab->getStringKey(OperandID);
dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
@@ -497,28 +512,74 @@ class MIR2VecTool {
// Update MaxRelation if there were operands
if (ArgIndex > 0)
- MaxRelation = std::max(MaxRelation, MIRArgRelation + ArgIndex - 1);
+ Result.MaxRelation =
+ std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
PrevOpcode = OpcodeID;
HasPrevOpcode = true;
}
}
- return MaxRelation;
+ return Result;
}
- /// Generate entity mappings with vocabulary
- void generateEntityMappings(raw_ostream &OS) const {
+ /// Collect triplets for every function that has a MachineFunction; others
+ /// are skipped with a warning. MaxRelation is at least MIRNextRelation.
+ TripletResult generateTriplets(const Module &M) const {
+ TripletResult Result;
+ Result.MaxRelation = MIRNextRelation;
+
+ for (const Function &F : M.getFunctionDefs()) {
+ MachineFunction *MF = MMI.getMachineFunction(F);
+ if (!MF) {
+ WithColor::warning(errs(), ToolName)
+ << "No MachineFunction for " << F.getName() << "\n";
+ continue;
+ }
+
+ TripletResult FuncResult = generateTriplets(*MF);
+ Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+ Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+ FuncResult.Triplets.end());
+ }
+
+ return Result;
+ }
+
+ /// Write module triplets to OS: a "MAX_RELATION=N" header, then one
+ /// "Head<TAB>Tail<TAB>Relation" line per triplet.
+ void generateTriplets(const Module &M, raw_ostream &OS) const {
+ auto Result = generateTriplets(M);
+ OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+ for (const auto &T : Result.Triplets)
+ OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+ }
+
+ /// Collect vocabulary string keys indexed by entity ID; empty if no Vocab.
+ EntityList collectEntityMappings() const {
if (!Vocab) {
WithColor::error(errs(), ToolName)
<< "Vocabulary must be initialized for entity mappings.\n";
- return;
+ return {};
}
const unsigned EntityCount = Vocab->getCanonicalSize();
- OS << EntityCount << "\n";
+ EntityList Result;
for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
- OS << Vocab->getStringKey(EntityID) << '\t' << EntityID << '\n';
+ Result.push_back(Vocab->getStringKey(EntityID));
+
+ return Result;
+ }
+
+ /// Write the entity count and "name<TAB>ID" lines to OS (nothing if !Vocab).
+ void generateEntityMappings(raw_ostream &OS) const {
+ auto Entities = collectEntityMappings();
+ if (!Vocab) // An empty-but-valid vocabulary must still print its count.
+ return; // collectEntityMappings() has already reported the error.
+
+ OS << Entities.size() << "\n";
+ for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+ OS << Entities[EntityID] << '\t' << EntityID << '\n';
}
/// Generate embeddings for all machine functions in the module
@@ -585,6 +646,7 @@ class MIR2VecTool {
}
}
+ /// Returns the MIR vocabulary, or nullptr if it has not been initialized.
const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
};
More information about the llvm-commits
mailing list