[llvm] [NFC][llvm-ir2vec] llvm-ir2vec.cpp breakup to extract a reusable header for the IR2VecTool and MIR2VecTool classes (PR #172304)
Nishant Sachdeva via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 18 05:51:24 PST 2025
https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/172304
>From 7da07971b66566f0be70b5b214ee0d373b6dc11a Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Mon, 15 Dec 2025 14:58:51 +0530
Subject: [PATCH 1/5] Breaking up llvm-ir2vec.cpp, extracting out tool classes
to prepare a common importable module
---
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 679 +++++--------------------
llvm/tools/llvm-ir2vec/llvm-ir2vec.h | 534 +++++++++++++++++++
2 files changed, 664 insertions(+), 549 deletions(-)
create mode 100644 llvm/tools/llvm-ir2vec/llvm-ir2vec.h
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 9ab12e36718cd..8b52c385ff524 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -54,6 +54,7 @@
///
//===----------------------------------------------------------------------===//
+#include "llvm-ir2vec.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/IR/BasicBlock.h"
@@ -90,8 +91,6 @@
namespace llvm {
-static const char *ToolName = "llvm-ir2vec";
-
// Common option category for options shared between IR2Vec and MIR2Vec
static cl::OptionCategory CommonCategory("Common Options",
"Options applicable to both IR2Vec "
@@ -135,12 +134,6 @@ static cl::opt<std::string>
cl::value_desc("name"), cl::Optional, cl::init(""),
cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
-enum EmbeddingLevel {
- InstructionLevel, // Generate instruction-level embeddings
- BasicBlockLevel, // Generate basic block-level embeddings
- FunctionLevel // Generate function-level embeddings
-};
-
static cl::opt<EmbeddingLevel>
Level("level", cl::desc("Embedding generation level:"),
cl::values(clEnumValN(InstructionLevel, "inst",
@@ -152,219 +145,8 @@ static cl::opt<EmbeddingLevel>
cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
cl::cat(CommonCategory));
-/// Represents a single knowledge graph triplet (Head, Relation, Tail)
-/// where indices reference entities in an EntityList
-struct Triplet {
- unsigned Head = 0; ///< Index of the head entity in the entity list
- unsigned Tail = 0; ///< Index of the tail entity in the entity list
- unsigned Relation = 0; ///< Relation type (see RelationType enum)
-};
-
-/// Result structure containing all generated triplets and metadata
-struct TripletResult {
- unsigned MaxRelation =
- 0; ///< Highest relation index used (for ArgRelation + N)
- std::vector<Triplet> Triplets; ///< Collection of all generated triplets
-};
-
-/// Entity mappings: [entity_name]
-using EntityList = std::vector<std::string>;
-
namespace ir2vec {
-/// Relation types for triplet generation
-enum RelationType {
- TypeRelation = 0, ///< Instruction to type relationship
- NextRelation = 1, ///< Sequential instruction relationship
- ArgRelation = 2 ///< Instruction to operand relationship (ArgRelation + N)
-};
-
-/// Helper class for collecting IR triplets and generating embeddings
-class IR2VecTool {
-private:
- Module &M;
- ModuleAnalysisManager MAM;
- const Vocabulary *Vocab = nullptr;
-
-public:
- explicit IR2VecTool(Module &M) : M(M) {}
-
- /// Initialize the IR2Vec vocabulary analysis
- bool initializeVocabulary() {
- // Register and run the IR2Vec vocabulary analysis
- // The vocabulary file path is specified via --ir2vec-vocab-path global
- // option
- MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
- MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
- // This will throw an error if vocab is not found or invalid
- Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
- return Vocab->isValid();
- }
-
- /// Generate triplets for a single function
- /// Returns a TripletResult with:
- /// - Triplets: vector of all (subject, object, relation) tuples
- /// - MaxRelation: highest Arg relation ID used, or NextRelation if none
- TripletResult generateTriplets(const Function &F) const {
- if (F.isDeclaration())
- return {};
-
- TripletResult Result;
- Result.MaxRelation = 0;
-
- unsigned MaxRelation = NextRelation;
- unsigned PrevOpcode = 0;
- bool HasPrevOpcode = false;
-
- for (const BasicBlock &BB : F) {
- for (const auto &I : BB.instructionsWithoutDebug()) {
- unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
- unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
-
- // Add "Next" relationship with previous instruction
- if (HasPrevOpcode) {
- Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
- LLVM_DEBUG(dbgs()
- << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
- << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << "Next\n");
- }
-
- // Add "Type" relationship
- Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
- LLVM_DEBUG(
- dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
- << '\t' << "Type\n");
-
- // Add "Arg" relationships
- unsigned ArgIndex = 0;
- for (const Use &U : I.operands()) {
- unsigned OperandID = Vocabulary::getIndex(*U.get());
- unsigned RelationID = ArgRelation + ArgIndex;
- Result.Triplets.push_back({Opcode, OperandID, RelationID});
-
- LLVM_DEBUG({
- StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
- Vocabulary::getOperandKind(U.get()));
- dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
- });
-
- ++ArgIndex;
- }
- // Only update MaxRelation if there were operands
- if (ArgIndex > 0)
- MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
- PrevOpcode = Opcode;
- HasPrevOpcode = true;
- }
- }
-
- Result.MaxRelation = MaxRelation;
- return Result;
- }
-
- /// Get triplets for the entire module
- TripletResult generateTriplets() const {
- TripletResult Result;
- Result.MaxRelation = NextRelation;
-
- for (const Function &F : M.getFunctionDefs()) {
- TripletResult FuncResult = generateTriplets(F);
- Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
- Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
- FuncResult.Triplets.end());
- }
-
- return Result;
- }
-
- /// Collect triplets for the module and dump output to stream
- /// Output format: MAX_RELATION=N header followed by relationships
- void writeTripletsToStream(raw_ostream &OS) const {
- auto Result = generateTriplets();
- OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
- for (const auto &T : Result.Triplets)
- OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
- }
-
- /// Generate entity mappings for the entire vocabulary
- /// Returns EntityList containing all entity strings
- static EntityList collectEntityMappings() {
- auto EntityLen = Vocabulary::getCanonicalSize();
- EntityList Result;
- for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
- Result.push_back(Vocabulary::getStringKey(EntityID).str());
- return Result;
- }
-
- /// Dump entity ID to string mappings
- static void writeEntitiesToStream(raw_ostream &OS) {
- auto Entities = collectEntityMappings();
- OS << Entities.size() << "\n";
- for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
- OS << Entities[EntityID] << '\t' << EntityID << '\n';
- }
-
- /// Generate embeddings for the entire module
- void writeEmbeddingsToStream(raw_ostream &OS) const {
- if (!Vocab->isValid()) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary is not valid. IR2VecTool not initialized.\n";
- return;
- }
-
- for (const Function &F : M.getFunctionDefs())
- writeEmbeddingsToStream(F, OS);
- }
-
- /// Generate embeddings for a single function
- void writeEmbeddingsToStream(const Function &F, raw_ostream &OS) const {
- if (!Vocab || !Vocab->isValid()) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary is not valid. IR2VecTool not initialized.\n";
- return;
- }
- if (F.isDeclaration()) {
- OS << "Function " << F.getName() << " is a declaration, skipping.\n";
- return;
- }
-
- // Create embedder for this function
- auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for function " << F.getName() << "\n";
- return;
- }
-
- OS << "Function: " << F.getName() << "\n";
-
- // Generate embeddings based on the specified level
- switch (Level) {
- case FunctionLevel: {
- Emb->getFunctionVector().print(OS);
- break;
- }
- case BasicBlockLevel: {
- for (const BasicBlock &BB : F) {
- OS << BB.getName() << ":";
- Emb->getBBVector(BB).print(OS);
- }
- break;
- }
- case InstructionLevel: {
- for (const Instruction &I : instructions(F)) {
- OS << I;
- Emb->getInstVector(I).print(OS);
- }
- break;
- }
- }
- }
-};
-
/// Process the module and generate output based on selected subcommand
Error processModule(Module &M, raw_ostream &OS) {
IR2VecTool Tool(M);
@@ -379,14 +161,14 @@ Error processModule(Module &M, raw_ostream &OS) {
if (!FunctionName.empty()) {
// Process single function
if (const Function *F = M.getFunction(FunctionName))
- Tool.writeEmbeddingsToStream(*F, OS);
+ Tool.writeEmbeddingsToStream(*F, OS, Level);
else
return createStringError(errc::invalid_argument,
"Function '%s' not found",
FunctionName.c_str());
} else {
// Process all functions
- Tool.writeEmbeddingsToStream(OS);
+ Tool.writeEmbeddingsToStream(OS, Level);
}
} else {
// Both triplets and entities use triplet generation
@@ -398,257 +180,151 @@ Error processModule(Module &M, raw_ostream &OS) {
namespace mir2vec {
-/// Relation types for MIR2Vec triplet generation
-enum MIRRelationType {
- MIRNextRelation = 0, ///< Sequential instruction relationship
- MIRArgRelation = 1 ///< Instruction to operand relationship (ArgRelation + N)
-};
+/// Setup MIR context from input file
+Error setupMIRContext(const std::string &InputFile, MIRContext &Ctx) {
+ SMDiagnostic Err;
-/// Helper class for MIR2Vec embedding generation
-class MIR2VecTool {
-private:
- MachineModuleInfo &MMI;
- std::unique_ptr<MIRVocabulary> Vocab;
-
-public:
- explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
-
- /// Initialize MIR2Vec vocabulary from file (for embeddings generation)
- bool initializeVocabulary(const Module &M) {
- MIR2VecVocabProvider Provider(MMI);
- auto VocabOrErr = Provider.getVocabulary(M);
- if (!VocabOrErr) {
- WithColor::error(errs(), ToolName)
- << "Failed to load MIR2Vec vocabulary - "
- << toString(VocabOrErr.takeError()) << "\n";
- return false;
- }
- Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
- return true;
+ auto MIR = createMIRParserFromFile(InputFile, Err, Ctx.Context);
+ if (!MIR) {
+ Err.print(ToolName, errs());
+ return createStringError(errc::invalid_argument,
+ "Failed to parse MIR file");
}
- /// Initialize vocabulary with layout information only.
- /// This creates a minimal vocabulary with correct layout but no actual
- /// embeddings. Sufficient for generating training data and entity mappings.
- ///
- /// Note: Requires target-specific information from the first machine function
- /// to determine the vocabulary layout (number of opcodes, register classes).
- ///
- /// FIXME: Use --target option to get target info directly, avoiding the need
- /// to parse machine functions for pre-training operations.
- bool initializeVocabularyForLayout(const Module &M) {
- for (const Function &F : M.getFunctionDefs()) {
-
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF)
- continue;
-
- const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
- const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
- const MachineRegisterInfo &MRI = MF->getRegInfo();
-
- auto VocabOrErr =
- MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
- if (!VocabOrErr) {
- WithColor::error(errs(), ToolName)
- << "Failed to create dummy vocabulary - "
- << toString(VocabOrErr.takeError()) << "\n";
- return false;
- }
- Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
- return true;
- }
-
- WithColor::error(errs(), ToolName)
- << "No machine functions found to initialize vocabulary\n";
- return false;
- }
-
- /// Get triplets for a single machine function
- /// Returns TripletResult containing MaxRelation and vector of Triplets
- TripletResult generateTriplets(const MachineFunction &MF) const {
- TripletResult Result;
- Result.MaxRelation = MIRNextRelation;
-
- if (!Vocab) {
- WithColor::error(errs(), ToolName)
- << "MIR Vocabulary must be initialized for triplet generation.\n";
- return Result;
- }
-
- unsigned PrevOpcode = 0;
- bool HasPrevOpcode = false;
- for (const MachineBasicBlock &MBB : MF) {
- for (const MachineInstr &MI : MBB) {
- // Skip debug instructions
- if (MI.isDebugInstr())
- continue;
-
- // Get opcode entity ID
- unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
-
- // Add "Next" relationship with previous instruction
- if (HasPrevOpcode) {
- Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
- LLVM_DEBUG(dbgs()
- << Vocab->getStringKey(PrevOpcode) << '\t'
- << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
- }
-
- // Add "Arg" relationships for operands
- unsigned ArgIndex = 0;
- for (const MachineOperand &MO : MI.operands()) {
- auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
- unsigned RelationID = MIRArgRelation + ArgIndex;
- Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
- LLVM_DEBUG({
- std::string OperandStr = Vocab->getStringKey(OperandID);
- dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
- << '\t' << "Arg" << ArgIndex << '\n';
- });
-
- ++ArgIndex;
- }
+ auto SetDataLayout = [&](StringRef DataLayoutTargetTriple,
+ StringRef OldDLStr) -> std::optional<std::string> {
+ std::string IRTargetTriple = DataLayoutTargetTriple.str();
+ Triple TheTriple = Triple(IRTargetTriple);
+ if (TheTriple.getTriple().empty())
+ TheTriple.setTriple(sys::getDefaultTargetTriple());
- // Update MaxRelation if there were operands
- if (ArgIndex > 0)
- Result.MaxRelation =
- std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
-
- PrevOpcode = OpcodeID;
- HasPrevOpcode = true;
- }
- }
-
- return Result;
- }
-
- /// Get triplets for the entire module
- /// Returns TripletResult containing aggregated MaxRelation and all Triplets
- TripletResult generateTriplets(const Module &M) const {
- TripletResult Result;
- Result.MaxRelation = MIRNextRelation;
-
- for (const Function &F : M.getFunctionDefs()) {
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF) {
- WithColor::warning(errs(), ToolName)
- << "No MachineFunction for " << F.getName() << "\n";
- continue;
- }
-
- TripletResult FuncResult = generateTriplets(*MF);
- Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
- Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
- FuncResult.Triplets.end());
+ auto TMOrErr = codegen::createTargetMachineForTriple(TheTriple.str());
+ if (!TMOrErr) {
+ Err.print(ToolName, errs());
+ exit(1); // Match original behavior
}
+ Ctx.TM = std::move(*TMOrErr);
+ return Ctx.TM->createDataLayout().getStringRepresentation();
+ };
- return Result;
+ Ctx.M = MIR->parseIRModule(SetDataLayout);
+ if (!Ctx.M) {
+ Err.print(ToolName, errs());
+ return createStringError(errc::invalid_argument,
+ "Failed to parse IR module");
}
- /// Collect triplets for the module and write to output stream
- /// Output format: MAX_RELATION=N header followed by relationships
- void writeTripletsToStream(const Module &M, raw_ostream &OS) const {
- auto Result = generateTriplets(M);
- OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
- for (const auto &T : Result.Triplets)
- OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+ Ctx.MMI = std::make_unique<MachineModuleInfo>(Ctx.TM.get());
+ if (!Ctx.MMI || MIR->parseMachineFunctions(*Ctx.M, *Ctx.MMI)) {
+ Err.print(ToolName, errs());
+ return createStringError(errc::invalid_argument,
+ "Failed to parse machine functions");
}
- /// Generate entity mappings for the entire vocabulary
- EntityList collectEntityMappings() const {
- if (!Vocab) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary must be initialized for entity mappings.\n";
- return {};
- }
-
- const unsigned EntityCount = Vocab->getCanonicalSize();
- EntityList Result;
- for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
- Result.push_back(Vocab->getStringKey(EntityID));
+ return Error::success();
+}
- return Result;
- }
+/// Generic vocabulary initialization and processing
+template <typename ProcessFunc>
+Error processWithVocabulary(MIRContext &Ctx, raw_ostream &OS,
+ bool useLayoutVocab, ProcessFunc processFn) {
+ MIR2VecTool Tool(*Ctx.MMI);
- /// Generate entity mappings and write to output stream
- void writeEntitiesToStream(raw_ostream &OS) const {
- auto Entities = collectEntityMappings();
- if (Entities.empty())
- return;
+ // Initialize appropriate vocabulary type
+ bool success = useLayoutVocab ? Tool.initializeVocabularyForLayout(*Ctx.M)
+ : Tool.initializeVocabulary(*Ctx.M);
- OS << Entities.size() << "\n";
- for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
- OS << Entities[EntityID] << '\t' << EntityID << '\n';
+ if (!success) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to initialize MIR2Vec vocabulary"
+ << (useLayoutVocab ? " for layout" : "") << ".\n";
+ return createStringError(errc::invalid_argument,
+ "Vocabulary initialization failed");
}
- /// Generate embeddings for all machine functions in the module
- void writeEmbeddingsToStream(const Module &M, raw_ostream &OS) const {
- if (!Vocab) {
- WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return;
- }
-
- for (const Function &F : M.getFunctionDefs()) {
-
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF) {
- WithColor::warning(errs(), ToolName)
- << "No MachineFunction for " << F.getName() << "\n";
- continue;
- }
+ assert(Tool.getVocabulary() &&
+ "MIR2Vec vocabulary should be initialized at this point");
- writeEmbeddingsToStream(*MF, OS);
- }
- }
+ LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"
+ << "Vocabulary dimension: "
+ << Tool.getVocabulary()->getDimension() << "\n"
+ << "Vocabulary size: "
+ << Tool.getVocabulary()->getCanonicalSize() << "\n");
- /// Generate embeddings for a specific machine function
- void writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS) const {
- if (!Vocab) {
- WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return;
- }
+ // Execute the specific processing logic
+ return processFn(Tool);
+}
- auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for " << MF.getName() << "\n";
- return;
- }
+/// Process module for triplet generation
+Error processModuleForTriplets(MIRContext &Ctx, raw_ostream &OS) {
+ return processWithVocabulary(Ctx, OS, /*useLayoutVocab=*/true,
+ [&](MIR2VecTool &Tool) -> Error {
+ Tool.writeTripletsToStream(*Ctx.M, OS);
+ return Error::success();
+ });
+}
- OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+/// Process module for entity generation
+Error processModuleForEntities(MIRContext &Ctx, raw_ostream &OS) {
+ return processWithVocabulary(Ctx, OS, /*useLayoutVocab=*/true,
+ [&](MIR2VecTool &Tool) -> Error {
+ Tool.writeEntitiesToStream(OS);
+ return Error::success();
+ });
+}
- // Generate embeddings based on the specified level
- switch (Level) {
- case FunctionLevel: {
- OS << "Function vector: ";
- Emb->getMFunctionVector().print(OS);
- break;
- }
- case BasicBlockLevel: {
- OS << "Basic block vectors:\n";
- for (const MachineBasicBlock &MBB : MF) {
- OS << "MBB " << MBB.getName() << ": ";
- Emb->getMBBVector(MBB).print(OS);
- }
- break;
- }
- case InstructionLevel: {
- OS << "Instruction vectors:\n";
- for (const MachineBasicBlock &MBB : MF) {
- for (const MachineInstr &MI : MBB) {
- OS << MI << " -> ";
- Emb->getMInstVector(MI).print(OS);
+/// Process module for embedding generation
+Error processModuleForEmbeddings(MIRContext &Ctx, raw_ostream &OS) {
+ return processWithVocabulary(
+ Ctx, OS, /*useLayoutVocab=*/false, [&](MIR2VecTool &Tool) -> Error {
+ if (!FunctionName.empty()) {
+ // Process single function
+ Function *F = Ctx.M->getFunction(FunctionName);
+ if (!F) {
+ WithColor::error(errs(), ToolName)
+ << "Function '" << FunctionName << "' not found\n";
+ return createStringError(errc::invalid_argument,
+ "Function not found");
+ }
+
+ MachineFunction *MF = Ctx.MMI->getMachineFunction(*F);
+ if (!MF) {
+ WithColor::error(errs(), ToolName)
+ << "No MachineFunction for " << FunctionName << "\n";
+ return createStringError(errc::invalid_argument,
+ "No MachineFunction");
+ }
+
+ Tool.writeEmbeddingsToStream(*MF, OS, Level);
+ } else {
+ // Process all functions
+ Tool.writeEmbeddingsToStream(*Ctx.M, OS, Level);
}
- }
- break;
- }
- }
- }
+ return Error::success();
+ });
+}
- /// Get the MIR vocabulary instance
- const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
-};
+/// Main entry point for MIR processing
+Error processModule(const std::string &InputFile, raw_ostream &OS) {
+ MIRContext Ctx;
+
+ // Setup MIR context (parse file, setup target machine, etc.)
+ if (auto Err = setupMIRContext(InputFile, Ctx))
+ return Err;
+
+ // Process based on subcommand
+ if (TripletsSubCmd)
+ return processModuleForTriplets(Ctx, OS);
+ else if (EntitiesSubCmd)
+ return processModuleForEntities(Ctx, OS);
+ else if (EmbeddingsSubCmd)
+ return processModuleForEmbeddings(Ctx, OS);
+ else {
+ WithColor::error(errs(), ToolName)
+ << "Please specify a subcommand: triplets, entities, or embeddings\n";
+ return createStringError(errc::invalid_argument, "No subcommand specified");
+ }
+}
} // namespace mir2vec
@@ -712,105 +388,10 @@ int main(int argc, char **argv) {
InitializeAllAsmPrinters();
static codegen::RegisterCodeGenFlags CGF;
- // Parse MIR input file
- SMDiagnostic Err;
- LLVMContext Context;
- std::unique_ptr<TargetMachine> TM;
-
- auto MIR = createMIRParserFromFile(InputFilename, Err, Context);
- if (!MIR) {
- Err.print(ToolName, errs());
- return 1;
- }
-
- auto SetDataLayout = [&](StringRef DataLayoutTargetTriple,
- StringRef OldDLStr) -> std::optional<std::string> {
- std::string IRTargetTriple = DataLayoutTargetTriple.str();
- Triple TheTriple = Triple(IRTargetTriple);
- if (TheTriple.getTriple().empty())
- TheTriple.setTriple(sys::getDefaultTargetTriple());
- auto TMOrErr = codegen::createTargetMachineForTriple(TheTriple.str());
- if (!TMOrErr) {
- Err.print(ToolName, errs());
- exit(1);
- }
- TM = std::move(*TMOrErr);
- return TM->createDataLayout().getStringRepresentation();
- };
-
- std::unique_ptr<Module> M = MIR->parseIRModule(SetDataLayout);
- if (!M) {
- Err.print(ToolName, errs());
- return 1;
- }
-
- // Parse machine functions
- auto MMI = std::make_unique<MachineModuleInfo>(TM.get());
- if (!MMI || MIR->parseMachineFunctions(*M, *MMI)) {
- Err.print(ToolName, errs());
- return 1;
- }
-
- // Create MIR2Vec tool
- MIR2VecTool Tool(*MMI);
-
- // Initialize vocabulary. For triplet/entity generation, only layout is
- // needed For embedding generation, the full vocabulary is needed.
- //
- // Note: Unlike IR2Vec, MIR2Vec vocabulary initialization requires
- // target-specific information for generating the vocabulary layout. So, we
- // always initialize the vocabulary in this case.
- if (TripletsSubCmd || EntitiesSubCmd) {
- if (!Tool.initializeVocabularyForLayout(*M)) {
- WithColor::error(errs(), ToolName)
- << "Failed to initialize MIR2Vec vocabulary for layout.\n";
- return 1;
- }
- } else {
- if (!Tool.initializeVocabulary(*M)) {
- WithColor::error(errs(), ToolName)
- << "Failed to initialize MIR2Vec vocabulary.\n";
- return 1;
- }
- }
- assert(Tool.getVocabulary() &&
- "MIR2Vec vocabulary should be initialized at this point");
- LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"
- << "Vocabulary dimension: "
- << Tool.getVocabulary()->getDimension() << "\n"
- << "Vocabulary size: "
- << Tool.getVocabulary()->getCanonicalSize() << "\n");
-
- // Handle subcommands
- if (TripletsSubCmd) {
- Tool.writeTripletsToStream(*M, OS);
- } else if (EntitiesSubCmd) {
- Tool.writeEntitiesToStream(OS);
- } else if (EmbeddingsSubCmd) {
- if (!FunctionName.empty()) {
- // Process single function
- Function *F = M->getFunction(FunctionName);
- if (!F) {
- WithColor::error(errs(), ToolName)
- << "Function '" << FunctionName << "' not found\n";
- return 1;
- }
-
- MachineFunction *MF = MMI->getMachineFunction(*F);
- if (!MF) {
- WithColor::error(errs(), ToolName)
- << "No MachineFunction for " << FunctionName << "\n";
- return 1;
- }
-
- Tool.writeEmbeddingsToStream(*MF, OS);
- } else {
- // Process all functions
- Tool.writeEmbeddingsToStream(*M, OS);
- }
- } else {
- WithColor::error(errs(), ToolName)
- << "Please specify a subcommand: triplets, entities, or embeddings\n";
+ if (Error Err = mir2vec::processModule(InputFilename, OS)) {
+ handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) {
+ WithColor::error(errs(), ToolName) << EIB.message() << "\n";
+ });
return 1;
}
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
new file mode 100644
index 0000000000000..56fd834d380a8
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
@@ -0,0 +1,534 @@
+//===- llvm-ir2vec.h - IR2Vec/MIR2Vec Tool Classes ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains the IR2VecTool and MIR2VecTool class definitions and
+/// implementations for the llvm-ir2vec embedding generation tool.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TOOLS_LLVM_IR2VEC_LLVM_MIR2VEC_H
+#define LLVM_TOOLS_LLVM_IR2VEC_LLVM_MIR2VEC_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Analysis/IR2Vec.h"
+#include "llvm/CodeGen/MIR2Vec.h"
+#include "llvm/CodeGen/MIRParser/MIRParser.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/WithColor.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Target/TargetMachine.h"
+#include <memory>
+#include <string>
+#include <vector>
+
+#define DEBUG_TYPE "ir2vec"
+
+namespace llvm {
+
+/// Tool name used as the prefix of WithColor diagnostics.
+/// Declared `inline constexpr` so that every translation unit including this
+/// header shares one immutable definition, instead of each receiving its own
+/// mutable `static` copy.
+inline constexpr const char *ToolName = "llvm-ir2vec";
+
+/// Specifies the granularity at which embeddings are generated.
+/// The writeEmbeddingsToStream() overloads of both tools take this value and
+/// emit one vector per function, per basic block, or per instruction.
+enum EmbeddingLevel {
+ InstructionLevel, ///< Generate instruction-level embeddings
+ BasicBlockLevel, ///< Generate basic block-level embeddings
+ FunctionLevel ///< Generate function-level embeddings
+};
+
+/// Represents a single knowledge graph triplet (Head, Relation, Tail)
+/// where indices reference entities in an EntityList.
+///
+/// Note that the members are declared in the order Head, Tail, Relation, so
+/// brace-initializers such as `{Head, Tail, Relation}` must follow that order
+/// rather than the (Head, Relation, Tail) reading order used above.
+struct Triplet {
+ unsigned Head = 0; ///< Index of the head entity in the entity list
+ unsigned Tail = 0; ///< Index of the tail entity in the entity list
+ unsigned Relation = 0; ///< Relation ID (ir2vec::RelationType or
+ ///< mir2vec::MIRRelationType, plus the operand
+ ///< index for Arg relations)
+};
+
+/// Result structure containing all generated triplets and metadata
+struct TripletResult {
+ /// Highest relation ID recorded by the generator. Operand N of an
+ /// instruction uses the Arg relation base + N, so this grows with the
+ /// largest operand count seen. Written out as the MAX_RELATION header.
+ unsigned MaxRelation =
+ 0;
+ std::vector<Triplet> Triplets; ///< Collection of all generated triplets
+};
+
+/// Entity names indexed by entity ID: the string at position N is the name
+/// of entity N (this is the order in which writeEntitiesToStream prints them).
+using EntityList = std::vector<std::string>;
+
+namespace ir2vec {
+
+/// Relation types for triplet generation.
+/// ArgRelation is a base value: operand N of an instruction is recorded with
+/// relation ID ArgRelation + N, so relation IDs above ArgRelation are valid.
+enum RelationType {
+ TypeRelation = 0, ///< Instruction to type relationship
+ NextRelation = 1, ///< Sequential instruction relationship
+ ArgRelation = 2 ///< Instruction to operand relationship (ArgRelation + N)
+};
+
+/// Helper class for collecting IR triplets and generating embeddings
+class IR2VecTool {
+private:
+ Module &M;
+ ModuleAnalysisManager MAM;
+ const Vocabulary *Vocab = nullptr;
+
+public:
+ explicit IR2VecTool(Module &M) : M(M) {}
+
+ /// Initialize the IR2Vec vocabulary analysis
+ bool initializeVocabulary() {
+ // Register and run the IR2Vec vocabulary analysis
+ // The vocabulary file path is specified via --ir2vec-vocab-path global
+ // option
+ MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+ MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
+ // This will throw an error if vocab is not found or invalid
+ Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
+ return Vocab->isValid();
+ }
+
+ /// Generate triplets for a single function
+ /// Returns a TripletResult with:
+ /// - Triplets: vector of all (subject, object, relation) tuples
+ /// - MaxRelation: highest Arg relation ID used, or NextRelation if none
+ TripletResult generateTriplets(const Function &F) const {
+ if (F.isDeclaration())
+ return {};
+
+ TripletResult Result;
+ Result.MaxRelation = 0;
+
+ unsigned MaxRelation = NextRelation;
+ unsigned PrevOpcode = 0;
+ bool HasPrevOpcode = false;
+
+ for (const BasicBlock &BB : F) {
+ for (const auto &I : BB.instructionsWithoutDebug()) {
+ unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
+ unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
+
+ // Add "Next" relationship with previous instruction
+ if (HasPrevOpcode) {
+ Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
+ LLVM_DEBUG(dbgs()
+ << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
+ << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+ << "Next\n");
+ }
+
+ // Add "Type" relationship
+ Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
+ LLVM_DEBUG(
+ dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+ << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
+ << '\t' << "Type\n");
+
+ // Add "Arg" relationships
+ unsigned ArgIndex = 0;
+ for (const Use &U : I.operands()) {
+ unsigned OperandID = Vocabulary::getIndex(*U.get());
+ unsigned RelationID = ArgRelation + ArgIndex;
+ Result.Triplets.push_back({Opcode, OperandID, RelationID});
+
+ LLVM_DEBUG({
+ StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
+ Vocabulary::getOperandKind(U.get()));
+ dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+ << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
+ });
+
+ ++ArgIndex;
+ }
+ // Only update MaxRelation if there were operands
+ if (ArgIndex > 0)
+ MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
+ PrevOpcode = Opcode;
+ HasPrevOpcode = true;
+ }
+ }
+
+ Result.MaxRelation = MaxRelation;
+ return Result;
+ }
+
+ /// Get triplets for the entire module
+ TripletResult generateTriplets() const {
+ TripletResult Result;
+ Result.MaxRelation = NextRelation;
+
+ for (const Function &F : M.getFunctionDefs()) {
+ TripletResult FuncResult = generateTriplets(F);
+ Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+ Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+ FuncResult.Triplets.end());
+ }
+
+ return Result;
+ }
+
+ /// Collect triplets for the module and dump output to stream
+ /// Output format: MAX_RELATION=N header followed by relationships
+ void writeTripletsToStream(raw_ostream &OS) const {
+ auto Result = generateTriplets();
+ OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+ for (const auto &T : Result.Triplets)
+ OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+ }
+
+ /// Generate entity mappings for the entire vocabulary
+ /// Returns EntityList containing all entity strings
+ static EntityList collectEntityMappings() {
+ auto EntityLen = Vocabulary::getCanonicalSize();
+ EntityList Result;
+ for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
+ Result.push_back(Vocabulary::getStringKey(EntityID).str());
+ return Result;
+ }
+
+ /// Dump entity ID to string mappings
+ static void writeEntitiesToStream(raw_ostream &OS) {
+ auto Entities = collectEntityMappings();
+ OS << Entities.size() << "\n";
+ for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+ OS << Entities[EntityID] << '\t' << EntityID << '\n';
+ }
+
+ /// Generate embeddings for the entire module
+ void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const {
+ if (!Vocab->isValid()) {
+ WithColor::error(errs(), ToolName)
+ << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+ return;
+ }
+
+ for (const Function &F : M.getFunctionDefs())
+ writeEmbeddingsToStream(F, OS, Level);
+ }
+
+ /// Generate embeddings for a single function
+ void writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
+ EmbeddingLevel Level) const {
+ if (!Vocab || !Vocab->isValid()) {
+ WithColor::error(errs(), ToolName)
+ << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+ return;
+ }
+ if (F.isDeclaration()) {
+ OS << "Function " << F.getName() << " is a declaration, skipping.\n";
+ return;
+ }
+
+ // Create embedder for this function
+ auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+ if (!Emb) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to create embedder for function " << F.getName() << "\n";
+ return;
+ }
+
+ OS << "Function: " << F.getName() << "\n";
+
+ // Generate embeddings based on the specified level
+ switch (Level) {
+ case FunctionLevel:
+ Emb->getFunctionVector().print(OS);
+ break;
+ case BasicBlockLevel:
+ for (const BasicBlock &BB : F) {
+ OS << BB.getName() << ":";
+ Emb->getBBVector(BB).print(OS);
+ }
+ break;
+ case InstructionLevel:
+ for (const Instruction &I : instructions(F)) {
+ OS << I;
+ Emb->getInstVector(I).print(OS);
+ }
+ break;
+ }
+ }
+};
+
+} // namespace ir2vec
+
+namespace mir2vec {
+
+/// Relation types for MIR2Vec triplet generation.
+/// Unlike the IR variant there is no type relation here. MIRArgRelation is a
+/// base value: operand N is recorded with relation ID MIRArgRelation + N.
+enum MIRRelationType {
+ MIRNextRelation = 0, ///< Sequential instruction relationship
+ MIRArgRelation = 1 ///< Instruction to operand relationship
+ ///< (MIRArgRelation + N)
+};
+
+/// Helper class for MIR2Vec embedding generation
+class MIR2VecTool {
+private:
+ MachineModuleInfo &MMI;
+ std::unique_ptr<MIRVocabulary> Vocab;
+
+public:
+ explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
+
+ /// Initialize MIR2Vec vocabulary from file (for embeddings generation)
+ bool initializeVocabulary(const Module &M) {
+ MIR2VecVocabProvider Provider(MMI);
+ auto VocabOrErr = Provider.getVocabulary(M);
+ if (!VocabOrErr) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to load MIR2Vec vocabulary - "
+ << toString(VocabOrErr.takeError()) << "\n";
+ return false;
+ }
+ Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
+ return true;
+ }
+
+ /// Initialize vocabulary with layout information only.
+ /// This creates a minimal vocabulary with correct layout but no actual
+ /// embeddings. Sufficient for generating training data and entity mappings.
+ ///
+ /// Note: Requires target-specific information from the first machine function
+ /// to determine the vocabulary layout (number of opcodes, register classes).
+ ///
+ /// FIXME: Use --target option to get target info directly, avoiding the need
+ /// to parse machine functions for pre-training operations.
+ bool initializeVocabularyForLayout(const Module &M) {
+ for (const Function &F : M.getFunctionDefs()) {
+ MachineFunction *MF = MMI.getMachineFunction(F);
+ if (!MF)
+ continue;
+
+ const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
+ const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
+ const MachineRegisterInfo &MRI = MF->getRegInfo();
+
+ auto VocabOrErr =
+ MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
+ if (!VocabOrErr) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to create dummy vocabulary - "
+ << toString(VocabOrErr.takeError()) << "\n";
+ return false;
+ }
+ Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
+ return true;
+ }
+
+ WithColor::error(errs(), ToolName)
+ << "No machine functions found to initialize vocabulary\n";
+ return false;
+ }
+
+ /// Get triplets for a single machine function
+ /// Returns TripletResult containing MaxRelation and vector of Triplets
+ TripletResult generateTriplets(const MachineFunction &MF) const {
+ TripletResult Result;
+ Result.MaxRelation = MIRNextRelation;
+
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName)
+ << "MIR Vocabulary must be initialized for triplet generation.\n";
+ return Result;
+ }
+
+ unsigned PrevOpcode = 0;
+ bool HasPrevOpcode = false;
+ for (const MachineBasicBlock &MBB : MF) {
+ for (const MachineInstr &MI : MBB) {
+ // Skip debug instructions
+ if (MI.isDebugInstr())
+ continue;
+
+ // Get opcode entity ID
+ unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
+
+ // Add "Next" relationship with previous instruction
+ if (HasPrevOpcode) {
+ Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
+ LLVM_DEBUG(dbgs()
+ << Vocab->getStringKey(PrevOpcode) << '\t'
+ << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
+ }
+
+ // Add "Arg" relationships for operands
+ unsigned ArgIndex = 0;
+ for (const MachineOperand &MO : MI.operands()) {
+ auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
+ unsigned RelationID = MIRArgRelation + ArgIndex;
+ Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
+ LLVM_DEBUG({
+ std::string OperandStr = Vocab->getStringKey(OperandID);
+ dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
+ << '\t' << "Arg" << ArgIndex << '\n';
+ });
+
+ ++ArgIndex;
+ }
+
+ // Update MaxRelation if there were operands
+ if (ArgIndex > 0)
+ Result.MaxRelation =
+ std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
+
+ PrevOpcode = OpcodeID;
+ HasPrevOpcode = true;
+ }
+ }
+
+ return Result;
+ }
+
+ /// Get triplets for the entire module
+ /// Returns TripletResult containing aggregated MaxRelation and all Triplets
+ TripletResult generateTriplets(const Module &M) const {
+ TripletResult Result;
+ Result.MaxRelation = MIRNextRelation;
+
+ for (const Function &F : M.getFunctionDefs()) {
+ MachineFunction *MF = MMI.getMachineFunction(F);
+ if (!MF) {
+ WithColor::warning(errs(), ToolName)
+ << "No MachineFunction for " << F.getName() << "\n";
+ continue;
+ }
+
+ TripletResult FuncResult = generateTriplets(*MF);
+ Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+ Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+ FuncResult.Triplets.end());
+ }
+
+ return Result;
+ }
+
+ /// Collect triplets for the module and write to output stream
+ /// Output format: MAX_RELATION=N header followed by relationships
+ void writeTripletsToStream(const Module &M, raw_ostream &OS) const {
+ auto Result = generateTriplets(M);
+ OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+ for (const auto &T : Result.Triplets)
+ OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+ }
+
+ /// Generate entity mappings for the entire vocabulary
+ EntityList collectEntityMappings() const {
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName)
+ << "Vocabulary must be initialized for entity mappings.\n";
+ return {};
+ }
+
+ const unsigned EntityCount = Vocab->getCanonicalSize();
+ EntityList Result;
+ for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
+ Result.push_back(Vocab->getStringKey(EntityID));
+
+ return Result;
+ }
+
+ /// Generate entity mappings and write to output stream
+ void writeEntitiesToStream(raw_ostream &OS) const {
+ auto Entities = collectEntityMappings();
+ if (Entities.empty())
+ return;
+
+ OS << Entities.size() << "\n";
+ for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+ OS << Entities[EntityID] << '\t' << EntityID << '\n';
+ }
+
+ /// Generate embeddings for all machine functions in the module
+ void writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
+ EmbeddingLevel Level) const {
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+ return;
+ }
+
+ for (const Function &F : M.getFunctionDefs()) {
+ MachineFunction *MF = MMI.getMachineFunction(F);
+ if (!MF) {
+ WithColor::warning(errs(), ToolName)
+ << "No MachineFunction for " << F.getName() << "\n";
+ continue;
+ }
+
+ writeEmbeddingsToStream(*MF, OS, Level);
+ }
+ }
+
+ /// Generate embeddings for a specific machine function
+ void writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
+ EmbeddingLevel Level) const {
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+ return;
+ }
+
+ auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+ if (!Emb) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to create embedder for " << MF.getName() << "\n";
+ return;
+ }
+
+ OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+
+ // Generate embeddings based on the specified level
+ switch (Level) {
+ case FunctionLevel:
+ OS << "Function vector: ";
+ Emb->getMFunctionVector().print(OS);
+ break;
+ case BasicBlockLevel:
+ OS << "Basic block vectors:\n";
+ for (const MachineBasicBlock &MBB : MF) {
+ OS << "MBB " << MBB.getName() << ": ";
+ Emb->getMBBVector(MBB).print(OS);
+ }
+ break;
+ case InstructionLevel:
+ OS << "Instruction vectors:\n";
+ for (const MachineBasicBlock &MBB : MF) {
+ for (const MachineInstr &MI : MBB) {
+ OS << MI << " -> ";
+ Emb->getMInstVector(MI).print(OS);
+ }
+ }
+ break;
+ }
+ }
+
+ /// Get the MIR vocabulary instance
+ const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
+};
+
+/// Helper structure to hold MIR context
+struct MIRContext {
+ // Non-static members are destroyed in reverse declaration order, so
+ // declaring Context first makes it the last member torn down, i.e. after
+ // M, MMI and TM have been released.
+ LLVMContext Context; // CRITICAL: Must be first for proper destruction order
+ std::unique_ptr<Module> M;
+ std::unique_ptr<MachineModuleInfo> MMI;
+ // NOTE(review): TM is declared after MMI and is therefore destroyed before
+ // it. Confirm MMI does not touch the TargetMachine during its destruction.
+ std::unique_ptr<TargetMachine> TM;
+};
+
+} // namespace mir2vec
+
+} // namespace llvm
+
+#endif // LLVM_TOOLS_LLVM_IR2VEC_LLVM_MIR2VEC_H
\ No newline at end of file
>From 41e4e13eafc0621a92ed053371aa2bb7112396d0 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Thu, 18 Dec 2025 12:47:49 +0530
Subject: [PATCH 2/5] Nit commit - header guard naming convention compliance
---
llvm/tools/llvm-ir2vec/llvm-ir2vec.h | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
index 56fd834d380a8..cb4572996b1f4 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
@@ -12,8 +12,8 @@
///
//===----------------------------------------------------------------------===//
-#ifndef LLVM_TOOLS_LLVM_IR2VEC_LLVM_MIR2VEC_H
-#define LLVM_TOOLS_LLVM_IR2VEC_LLVM_MIR2VEC_H
+#ifndef LLVM_TOOLS_LLVM_IR2VEC_LLVM_IR2VEC_H
+#define LLVM_TOOLS_LLVM_IR2VEC_LLVM_IR2VEC_H
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/IR2Vec.h"
@@ -531,4 +531,4 @@ struct MIRContext {
} // namespace llvm
-#endif // LLVM_TOOLS_LLVM_IR2VEC_LLVM_MIR2VEC_H
\ No newline at end of file
+#endif // LLVM_TOOLS_LLVM_IR2VEC_LLVM_IR2VEC_H
\ No newline at end of file
>From 0978b4c957717df4171135601b61392af3784880 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Thu, 18 Dec 2025 16:01:18 +0530
Subject: [PATCH 3/5] Work Commit - Moved function definitions out of header
file
---
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 378 +++++++++++++++++++++++++
llvm/tools/llvm-ir2vec/llvm-ir2vec.h | 377 ++----------------------
2 files changed, 400 insertions(+), 355 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 8b52c385ff524..a515797c0c505 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -147,6 +147,169 @@ static cl::opt<EmbeddingLevel>
namespace ir2vec {
+// Binds the tool to module M. Only the module reference is set here; the
+// vocabulary pointer is assigned later by initializeVocabulary(), which must
+// run before embeddings are requested.
+IR2VecTool::IR2VecTool(Module &M) : M(M) {}
+
+bool IR2VecTool::initializeVocabulary() {
+  // Both analyses are registered on the tool's own analysis manager before
+  // the vocabulary is requested. The vocabulary file path itself comes from
+  // the global --ir2vec-vocab-path option.
+  MAM.registerPass([] { return PassInstrumentationAnalysis(); });
+  MAM.registerPass([] { return IR2VecVocabAnalysis(); });
+
+  // Per the original author's note, the analysis reports an error when the
+  // vocabulary is missing or invalid. The cached result stays owned by MAM;
+  // only a pointer to it is kept here.
+  const Vocabulary &Loaded = MAM.getResult<IR2VecVocabAnalysis>(M);
+  Vocab = &Loaded;
+  return Loaded.isValid();
+}
+
+TripletResult IR2VecTool::generateTriplets(const Function &F) const {
+  // MaxRelation is documented to be NextRelation when no operand relation is
+  // emitted, so seed it before the early exit. Declarations previously
+  // returned a default-constructed result whose MaxRelation was 0.
+  TripletResult Result;
+  Result.MaxRelation = NextRelation;
+  if (F.isDeclaration())
+    return Result;
+
+  // Vocabulary index of the previously visited instruction's opcode; only
+  // meaningful once HasPrevOpcode is set. It is not reset between basic
+  // blocks, so the "Next" chain spans the whole function.
+  unsigned PrevOpcode = 0;
+  bool HasPrevOpcode = false;
+
+  for (const BasicBlock &BB : F) {
+    for (const auto &I : BB.instructionsWithoutDebug()) {
+      unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
+      unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
+
+      // Add "Next" relationship with previous instruction
+      if (HasPrevOpcode) {
+        Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
+        LLVM_DEBUG(dbgs()
+                   << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
+                   << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                   << "Next\n");
+      }
+
+      // Add "Type" relationship
+      Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
+      LLVM_DEBUG(
+          dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                 << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
+                 << '\t' << "Type\n");
+
+      // Add "Arg" relationships: operand N uses relation ArgRelation + N
+      unsigned ArgIndex = 0;
+      for (const Use &U : I.operands()) {
+        unsigned OperandID = Vocabulary::getIndex(*U.get());
+        unsigned RelationID = ArgRelation + ArgIndex;
+        Result.Triplets.push_back({Opcode, OperandID, RelationID});
+
+        LLVM_DEBUG({
+          StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
+              Vocabulary::getOperandKind(U.get()));
+          dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                 << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
+        });
+
+        ++ArgIndex;
+      }
+
+      // Only instructions that have operands can raise MaxRelation; the last
+      // relation used for this instruction is ArgRelation + ArgIndex - 1.
+      if (ArgIndex > 0)
+        Result.MaxRelation =
+            std::max(Result.MaxRelation, ArgRelation + ArgIndex - 1);
+
+      PrevOpcode = Opcode;
+      HasPrevOpcode = true;
+    }
+  }
+
+  return Result;
+}
+
+TripletResult IR2VecTool::generateTriplets() const {
+  // Concatenate the per-function triplets in module order and keep the
+  // largest relation ID seen; NextRelation is the floor for an empty module.
+  TripletResult Merged;
+  Merged.MaxRelation = NextRelation;
+
+  for (const Function &Fn : M.getFunctionDefs()) {
+    const TripletResult PerFunction = generateTriplets(Fn);
+    if (PerFunction.MaxRelation > Merged.MaxRelation)
+      Merged.MaxRelation = PerFunction.MaxRelation;
+    Merged.Triplets.insert(Merged.Triplets.end(), PerFunction.Triplets.begin(),
+                           PerFunction.Triplets.end());
+  }
+
+  return Merged;
+}
+
+void IR2VecTool::writeTripletsToStream(raw_ostream &OS) const {
+  // A MAX_RELATION=N header line comes first, followed by one tab-separated
+  // "head tail relation" row per triplet.
+  const TripletResult Collected = generateTriplets();
+  OS << "MAX_RELATION=" << Collected.MaxRelation << '\n';
+  for (const auto &Row : Collected.Triplets)
+    OS << Row.Head << '\t' << Row.Tail << '\t' << Row.Relation << '\n';
+}
+
+EntityList IR2VecTool::collectEntityMappings() {
+  // Walk entity IDs 0 .. canonical size - 1 and store each ID's string key at
+  // the matching index of the returned list.
+  EntityList Names;
+  const auto NumEntities = Vocabulary::getCanonicalSize();
+  for (unsigned ID = 0; ID < NumEntities; ++ID) {
+    StringRef Key = Vocabulary::getStringKey(ID);
+    Names.push_back(Key.str());
+  }
+  return Names;
+}
+
+void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
+  // First line is the entity count; each following line is "name<TAB>id".
+  const EntityList Names = collectEntityMappings();
+  OS << Names.size() << "\n";
+
+  unsigned ID = 0;
+  for (const auto &Name : Names) {
+    OS << Name << '\t' << ID << '\n';
+    ++ID;
+  }
+}
+
+void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
+                                         EmbeddingLevel Level) const {
+  // Vocab is null until initializeVocabulary() has run, so test the pointer
+  // before dereferencing it. The per-function overload performs the same
+  // check; without it, calling this overload first is a null dereference.
+  if (!Vocab || !Vocab->isValid()) {
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+    return;
+  }
+
+  // Emit embeddings for every function definition, in module order.
+  for (const Function &F : M.getFunctionDefs())
+    writeEmbeddingsToStream(F, OS, Level);
+}
+
+void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
+                                         EmbeddingLevel Level) const {
+  // Embeddings need a loaded, valid vocabulary.
+  const bool HaveVocab = Vocab && Vocab->isValid();
+  if (!HaveVocab) {
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+    return;
+  }
+
+  // Declarations have no body to embed; note that on the output stream.
+  if (F.isDeclaration()) {
+    OS << "Function " << F.getName() << " is a declaration, skipping.\n";
+    return;
+  }
+
+  auto FnEmbedder = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+  if (!FnEmbedder) {
+    WithColor::error(errs(), ToolName)
+        << "Failed to create embedder for function " << F.getName() << "\n";
+    return;
+  }
+
+  OS << "Function: " << F.getName() << "\n";
+
+  // Print the vector(s) matching the requested granularity.
+  if (Level == FunctionLevel) {
+    FnEmbedder->getFunctionVector().print(OS);
+    return;
+  }
+
+  if (Level == BasicBlockLevel) {
+    for (const BasicBlock &Block : F) {
+      OS << Block.getName() << ":";
+      FnEmbedder->getBBVector(Block).print(OS);
+    }
+    return;
+  }
+
+  if (Level == InstructionLevel) {
+    for (const Instruction &Inst : instructions(F)) {
+      OS << Inst;
+      FnEmbedder->getInstVector(Inst).print(OS);
+    }
+  }
+}
+
/// Process the module and generate output based on selected subcommand
Error processModule(Module &M, raw_ostream &OS) {
IR2VecTool Tool(M);
@@ -180,6 +343,221 @@ Error processModule(Module &M, raw_ostream &OS) {
namespace mir2vec {
+// Binds the tool to MMI, which is later queried for each function's
+// MachineFunction. No vocabulary is created here; one of the
+// initializeVocabulary*() members has to be called first.
+MIR2VecTool::MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
+
+bool MIR2VecTool::initializeVocabulary(const Module &M) {
+  // Ask the provider for the vocabulary; on success the tool takes ownership
+  // of a heap copy, on failure the error text is printed and consumed.
+  MIR2VecVocabProvider Provider(MMI);
+  auto LoadedOrErr = Provider.getVocabulary(M);
+  if (LoadedOrErr) {
+    Vocab = std::make_unique<MIRVocabulary>(std::move(*LoadedOrErr));
+    return true;
+  }
+
+  WithColor::error(errs(), ToolName)
+      << "Failed to load MIR2Vec vocabulary - "
+      << toString(LoadedOrErr.takeError()) << "\n";
+  return false;
+}
+
+bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
+  // The first function definition that has a MachineFunction supplies the
+  // target information; later functions are never consulted.
+  MachineFunction *FirstMF = nullptr;
+  for (const Function &F : M.getFunctionDefs()) {
+    FirstMF = MMI.getMachineFunction(F);
+    if (FirstMF)
+      break;
+  }
+
+  if (!FirstMF) {
+    WithColor::error(errs(), ToolName)
+        << "No machine functions found to initialize vocabulary\n";
+    return false;
+  }
+
+  const TargetInstrInfo &TII = *FirstMF->getSubtarget().getInstrInfo();
+  const TargetRegisterInfo &TRI = *FirstMF->getSubtarget().getRegisterInfo();
+  const MachineRegisterInfo &MRI = FirstMF->getRegInfo();
+
+  // Build the dummy vocabulary from the target info gathered above.
+  auto DummyOrErr = MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
+  if (DummyOrErr) {
+    Vocab = std::make_unique<MIRVocabulary>(std::move(*DummyOrErr));
+    return true;
+  }
+
+  WithColor::error(errs(), ToolName)
+      << "Failed to create dummy vocabulary - "
+      << toString(DummyOrErr.takeError()) << "\n";
+  return false;
+}
+
+TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
+  TripletResult Out;
+  Out.MaxRelation = MIRNextRelation;
+
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName)
+        << "MIR Vocabulary must be initialized for triplet generation.\n";
+    return Out;
+  }
+
+  // Entity ID of the most recent non-debug instruction; meaningful only once
+  // SeenInstr is set. It is carried across basic block boundaries.
+  unsigned LastOpcodeID = 0;
+  bool SeenInstr = false;
+
+  for (const MachineBasicBlock &MBB : MF) {
+    for (const MachineInstr &MI : MBB) {
+      // Debug instructions contribute no triplets.
+      if (MI.isDebugInstr())
+        continue;
+
+      const unsigned CurOpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
+
+      // Link the previous instruction to this one with a "Next" relation.
+      if (SeenInstr) {
+        Out.Triplets.push_back({LastOpcodeID, CurOpcodeID, MIRNextRelation});
+        LLVM_DEBUG(dbgs()
+                   << Vocab->getStringKey(LastOpcodeID) << '\t'
+                   << Vocab->getStringKey(CurOpcodeID) << '\t' << "Next\n");
+      }
+
+      // One "Arg" triplet per operand; operand N uses MIRArgRelation + N.
+      unsigned NumOperands = 0;
+      for (const MachineOperand &MO : MI.operands()) {
+        auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
+        unsigned RelationID = MIRArgRelation + NumOperands;
+        Out.Triplets.push_back({CurOpcodeID, OperandID, RelationID});
+        LLVM_DEBUG({
+          std::string OperandStr = Vocab->getStringKey(OperandID);
+          dbgs() << Vocab->getStringKey(CurOpcodeID) << '\t' << OperandStr
+                 << '\t' << "Arg" << NumOperands << '\n';
+        });
+
+        ++NumOperands;
+      }
+
+      // The largest relation used above is MIRArgRelation + NumOperands - 1.
+      if (NumOperands > 0)
+        Out.MaxRelation =
+            std::max(Out.MaxRelation, MIRArgRelation + NumOperands - 1);
+
+      LastOpcodeID = CurOpcodeID;
+      SeenInstr = true;
+    }
+  }
+
+  return Out;
+}
+
+TripletResult MIR2VecTool::generateTriplets(const Module &M) const {
+  // Merge the triplets of every function that has a MachineFunction;
+  // functions without one are reported with a warning and skipped.
+  TripletResult Merged;
+  Merged.MaxRelation = MIRNextRelation;
+
+  for (const Function &F : M.getFunctionDefs()) {
+    if (MachineFunction *MF = MMI.getMachineFunction(F)) {
+      const TripletResult PerFunction = generateTriplets(*MF);
+      if (PerFunction.MaxRelation > Merged.MaxRelation)
+        Merged.MaxRelation = PerFunction.MaxRelation;
+      Merged.Triplets.insert(Merged.Triplets.end(),
+                             PerFunction.Triplets.begin(),
+                             PerFunction.Triplets.end());
+    } else {
+      WithColor::warning(errs(), ToolName)
+          << "No MachineFunction for " << F.getName() << "\n";
+    }
+  }
+
+  return Merged;
+}
+
+void MIR2VecTool::writeTripletsToStream(const Module &M,
+                                        raw_ostream &OS) const {
+  // A MAX_RELATION=N header line comes first, followed by one tab-separated
+  // "head tail relation" row per triplet.
+  const TripletResult Collected = generateTriplets(M);
+  OS << "MAX_RELATION=" << Collected.MaxRelation << '\n';
+  for (const auto &Row : Collected.Triplets)
+    OS << Row.Head << '\t' << Row.Tail << '\t' << Row.Relation << '\n';
+}
+
+EntityList MIR2VecTool::collectEntityMappings() const {
+  // Without a vocabulary there is nothing to enumerate; report the problem
+  // and hand back an empty list.
+  EntityList Names;
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary must be initialized for entity mappings.\n";
+    return Names;
+  }
+
+  // The string key of entity N is stored at index N.
+  for (unsigned ID = 0, End = Vocab->getCanonicalSize(); ID < End; ++ID)
+    Names.push_back(Vocab->getStringKey(ID));
+
+  return Names;
+}
+
+void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
+  // An empty mapping produces no output at all, not even the count line.
+  const EntityList Names = collectEntityMappings();
+  if (!Names.empty()) {
+    OS << Names.size() << "\n";
+
+    unsigned ID = 0;
+    for (const auto &Name : Names) {
+      OS << Name << '\t' << ID << '\n';
+      ++ID;
+    }
+  }
+}
+
+void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
+                                          EmbeddingLevel Level) const {
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+    return;
+  }
+
+  // Emit embeddings for each function that has a MachineFunction; the others
+  // are reported with a warning and skipped.
+  for (const Function &F : M.getFunctionDefs()) {
+    if (MachineFunction *MF = MMI.getMachineFunction(F)) {
+      writeEmbeddingsToStream(*MF, OS, Level);
+    } else {
+      WithColor::warning(errs(), ToolName)
+          << "No MachineFunction for " << F.getName() << "\n";
+    }
+  }
+}
+
+void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
+                                          EmbeddingLevel Level) const {
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+    return;
+  }
+
+  auto MFEmbedder = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+  if (!MFEmbedder) {
+    WithColor::error(errs(), ToolName)
+        << "Failed to create embedder for " << MF.getName() << "\n";
+    return;
+  }
+
+  OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+
+  // Print the vector(s) matching the requested granularity.
+  if (Level == FunctionLevel) {
+    OS << "Function vector: ";
+    MFEmbedder->getMFunctionVector().print(OS);
+    return;
+  }
+
+  if (Level == BasicBlockLevel) {
+    OS << "Basic block vectors:\n";
+    for (const MachineBasicBlock &Block : MF) {
+      OS << "MBB " << Block.getName() << ": ";
+      MFEmbedder->getMBBVector(Block).print(OS);
+    }
+    return;
+  }
+
+  if (Level == InstructionLevel) {
+    OS << "Instruction vectors:\n";
+    for (const MachineBasicBlock &Block : MF) {
+      for (const MachineInstr &Inst : Block) {
+        OS << Inst << " -> ";
+        MFEmbedder->getMInstVector(Inst).print(OS);
+      }
+    }
+  }
+}
+
+// Returns the vocabulary owned by this tool without transferring ownership;
+// the result is null until one of the initializeVocabulary*() calls succeeds.
+const MIRVocabulary *MIR2VecTool::getVocabulary() const { return Vocab.get(); }
+
/// Setup MIR context from input file
Error setupMIRContext(const std::string &InputFile, MIRContext &Ctx) {
SMDiagnostic Err;
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
index cb4572996b1f4..9bcce9741c917 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
///
/// \file
-/// This file contains the IR2VecTool and MIR2VecTool class definitions and
-/// implementations for the llvm-ir2vec embedding generation tool.
+/// This file contains the IR2VecTool and MIR2VecTool class definitions for
+/// the llvm-ir2vec embedding generation tool.
///
//===----------------------------------------------------------------------===//
@@ -90,180 +90,37 @@ class IR2VecTool {
const Vocabulary *Vocab = nullptr;
public:
- explicit IR2VecTool(Module &M) : M(M) {}
+ explicit IR2VecTool(Module &M);
/// Initialize the IR2Vec vocabulary analysis
- bool initializeVocabulary() {
- // Register and run the IR2Vec vocabulary analysis
- // The vocabulary file path is specified via --ir2vec-vocab-path global
- // option
- MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
- MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
- // This will throw an error if vocab is not found or invalid
- Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
- return Vocab->isValid();
- }
+ bool initializeVocabulary();
/// Generate triplets for a single function
/// Returns a TripletResult with:
/// - Triplets: vector of all (subject, object, relation) tuples
/// - MaxRelation: highest Arg relation ID used, or NextRelation if none
- TripletResult generateTriplets(const Function &F) const {
- if (F.isDeclaration())
- return {};
-
- TripletResult Result;
- Result.MaxRelation = 0;
-
- unsigned MaxRelation = NextRelation;
- unsigned PrevOpcode = 0;
- bool HasPrevOpcode = false;
-
- for (const BasicBlock &BB : F) {
- for (const auto &I : BB.instructionsWithoutDebug()) {
- unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
- unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
-
- // Add "Next" relationship with previous instruction
- if (HasPrevOpcode) {
- Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
- LLVM_DEBUG(dbgs()
- << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
- << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << "Next\n");
- }
-
- // Add "Type" relationship
- Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
- LLVM_DEBUG(
- dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
- << '\t' << "Type\n");
-
- // Add "Arg" relationships
- unsigned ArgIndex = 0;
- for (const Use &U : I.operands()) {
- unsigned OperandID = Vocabulary::getIndex(*U.get());
- unsigned RelationID = ArgRelation + ArgIndex;
- Result.Triplets.push_back({Opcode, OperandID, RelationID});
-
- LLVM_DEBUG({
- StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
- Vocabulary::getOperandKind(U.get()));
- dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
- });
-
- ++ArgIndex;
- }
- // Only update MaxRelation if there were operands
- if (ArgIndex > 0)
- MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
- PrevOpcode = Opcode;
- HasPrevOpcode = true;
- }
- }
-
- Result.MaxRelation = MaxRelation;
- return Result;
- }
+ TripletResult generateTriplets(const Function &F) const;
/// Get triplets for the entire module
- TripletResult generateTriplets() const {
- TripletResult Result;
- Result.MaxRelation = NextRelation;
-
- for (const Function &F : M.getFunctionDefs()) {
- TripletResult FuncResult = generateTriplets(F);
- Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
- Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
- FuncResult.Triplets.end());
- }
-
- return Result;
- }
+ TripletResult generateTriplets() const;
/// Collect triplets for the module and dump output to stream
/// Output format: MAX_RELATION=N header followed by relationships
- void writeTripletsToStream(raw_ostream &OS) const {
- auto Result = generateTriplets();
- OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
- for (const auto &T : Result.Triplets)
- OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
- }
+ void writeTripletsToStream(raw_ostream &OS) const;
/// Generate entity mappings for the entire vocabulary
/// Returns EntityList containing all entity strings
- static EntityList collectEntityMappings() {
- auto EntityLen = Vocabulary::getCanonicalSize();
- EntityList Result;
- for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
- Result.push_back(Vocabulary::getStringKey(EntityID).str());
- return Result;
- }
+ static EntityList collectEntityMappings();
/// Dump entity ID to string mappings
- static void writeEntitiesToStream(raw_ostream &OS) {
- auto Entities = collectEntityMappings();
- OS << Entities.size() << "\n";
- for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
- OS << Entities[EntityID] << '\t' << EntityID << '\n';
- }
+ static void writeEntitiesToStream(raw_ostream &OS);
/// Generate embeddings for the entire module
- void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const {
- if (!Vocab->isValid()) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary is not valid. IR2VecTool not initialized.\n";
- return;
- }
-
- for (const Function &F : M.getFunctionDefs())
- writeEmbeddingsToStream(F, OS, Level);
- }
+ void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;
/// Generate embeddings for a single function
void writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
- EmbeddingLevel Level) const {
- if (!Vocab || !Vocab->isValid()) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary is not valid. IR2VecTool not initialized.\n";
- return;
- }
- if (F.isDeclaration()) {
- OS << "Function " << F.getName() << " is a declaration, skipping.\n";
- return;
- }
-
- // Create embedder for this function
- auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for function " << F.getName() << "\n";
- return;
- }
-
- OS << "Function: " << F.getName() << "\n";
-
- // Generate embeddings based on the specified level
- switch (Level) {
- case FunctionLevel:
- Emb->getFunctionVector().print(OS);
- break;
- case BasicBlockLevel:
- for (const BasicBlock &BB : F) {
- OS << BB.getName() << ":";
- Emb->getBBVector(BB).print(OS);
- }
- break;
- case InstructionLevel:
- for (const Instruction &I : instructions(F)) {
- OS << I;
- Emb->getInstVector(I).print(OS);
- }
- break;
- }
- }
+ EmbeddingLevel Level) const;
};
} // namespace ir2vec
@@ -283,21 +140,10 @@ class MIR2VecTool {
std::unique_ptr<MIRVocabulary> Vocab;
public:
- explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
+ explicit MIR2VecTool(MachineModuleInfo &MMI);
/// Initialize MIR2Vec vocabulary from file (for embeddings generation)
- bool initializeVocabulary(const Module &M) {
- MIR2VecVocabProvider Provider(MMI);
- auto VocabOrErr = Provider.getVocabulary(M);
- if (!VocabOrErr) {
- WithColor::error(errs(), ToolName)
- << "Failed to load MIR2Vec vocabulary - "
- << toString(VocabOrErr.takeError()) << "\n";
- return false;
- }
- Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
- return true;
- }
+ bool initializeVocabulary(const Module &M);
/// Initialize vocabulary with layout information only.
/// This creates a minimal vocabulary with correct layout but no actual
@@ -308,215 +154,36 @@ class MIR2VecTool {
///
/// FIXME: Use --target option to get target info directly, avoiding the need
/// to parse machine functions for pre-training operations.
- bool initializeVocabularyForLayout(const Module &M) {
- for (const Function &F : M.getFunctionDefs()) {
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF)
- continue;
-
- const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
- const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
- const MachineRegisterInfo &MRI = MF->getRegInfo();
-
- auto VocabOrErr =
- MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
- if (!VocabOrErr) {
- WithColor::error(errs(), ToolName)
- << "Failed to create dummy vocabulary - "
- << toString(VocabOrErr.takeError()) << "\n";
- return false;
- }
- Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
- return true;
- }
-
- WithColor::error(errs(), ToolName)
- << "No machine functions found to initialize vocabulary\n";
- return false;
- }
+ bool initializeVocabularyForLayout(const Module &M);
/// Get triplets for a single machine function
/// Returns TripletResult containing MaxRelation and vector of Triplets
- TripletResult generateTriplets(const MachineFunction &MF) const {
- TripletResult Result;
- Result.MaxRelation = MIRNextRelation;
-
- if (!Vocab) {
- WithColor::error(errs(), ToolName)
- << "MIR Vocabulary must be initialized for triplet generation.\n";
- return Result;
- }
-
- unsigned PrevOpcode = 0;
- bool HasPrevOpcode = false;
- for (const MachineBasicBlock &MBB : MF) {
- for (const MachineInstr &MI : MBB) {
- // Skip debug instructions
- if (MI.isDebugInstr())
- continue;
-
- // Get opcode entity ID
- unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
-
- // Add "Next" relationship with previous instruction
- if (HasPrevOpcode) {
- Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
- LLVM_DEBUG(dbgs()
- << Vocab->getStringKey(PrevOpcode) << '\t'
- << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
- }
-
- // Add "Arg" relationships for operands
- unsigned ArgIndex = 0;
- for (const MachineOperand &MO : MI.operands()) {
- auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
- unsigned RelationID = MIRArgRelation + ArgIndex;
- Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
- LLVM_DEBUG({
- std::string OperandStr = Vocab->getStringKey(OperandID);
- dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
- << '\t' << "Arg" << ArgIndex << '\n';
- });
-
- ++ArgIndex;
- }
-
- // Update MaxRelation if there were operands
- if (ArgIndex > 0)
- Result.MaxRelation =
- std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
-
- PrevOpcode = OpcodeID;
- HasPrevOpcode = true;
- }
- }
-
- return Result;
- }
+ TripletResult generateTriplets(const MachineFunction &MF) const;
/// Get triplets for the entire module
/// Returns TripletResult containing aggregated MaxRelation and all Triplets
- TripletResult generateTriplets(const Module &M) const {
- TripletResult Result;
- Result.MaxRelation = MIRNextRelation;
-
- for (const Function &F : M.getFunctionDefs()) {
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF) {
- WithColor::warning(errs(), ToolName)
- << "No MachineFunction for " << F.getName() << "\n";
- continue;
- }
-
- TripletResult FuncResult = generateTriplets(*MF);
- Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
- Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
- FuncResult.Triplets.end());
- }
-
- return Result;
- }
+ TripletResult generateTriplets(const Module &M) const;
/// Collect triplets for the module and write to output stream
/// Output format: MAX_RELATION=N header followed by relationships
- void writeTripletsToStream(const Module &M, raw_ostream &OS) const {
- auto Result = generateTriplets(M);
- OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
- for (const auto &T : Result.Triplets)
- OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
- }
+ void writeTripletsToStream(const Module &M, raw_ostream &OS) const;
/// Generate entity mappings for the entire vocabulary
- EntityList collectEntityMappings() const {
- if (!Vocab) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary must be initialized for entity mappings.\n";
- return {};
- }
-
- const unsigned EntityCount = Vocab->getCanonicalSize();
- EntityList Result;
- for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
- Result.push_back(Vocab->getStringKey(EntityID));
-
- return Result;
- }
+ EntityList collectEntityMappings() const;
/// Generate entity mappings and write to output stream
- void writeEntitiesToStream(raw_ostream &OS) const {
- auto Entities = collectEntityMappings();
- if (Entities.empty())
- return;
-
- OS << Entities.size() << "\n";
- for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
- OS << Entities[EntityID] << '\t' << EntityID << '\n';
- }
+ void writeEntitiesToStream(raw_ostream &OS) const;
/// Generate embeddings for all machine functions in the module
void writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
- EmbeddingLevel Level) const {
- if (!Vocab) {
- WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return;
- }
-
- for (const Function &F : M.getFunctionDefs()) {
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF) {
- WithColor::warning(errs(), ToolName)
- << "No MachineFunction for " << F.getName() << "\n";
- continue;
- }
-
- writeEmbeddingsToStream(*MF, OS, Level);
- }
- }
+ EmbeddingLevel Level) const;
/// Generate embeddings for a specific machine function
void writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
- EmbeddingLevel Level) const {
- if (!Vocab) {
- WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return;
- }
-
- auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for " << MF.getName() << "\n";
- return;
- }
-
- OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
-
- // Generate embeddings based on the specified level
- switch (Level) {
- case FunctionLevel:
- OS << "Function vector: ";
- Emb->getMFunctionVector().print(OS);
- break;
- case BasicBlockLevel:
- OS << "Basic block vectors:\n";
- for (const MachineBasicBlock &MBB : MF) {
- OS << "MBB " << MBB.getName() << ": ";
- Emb->getMBBVector(MBB).print(OS);
- }
- break;
- case InstructionLevel:
- OS << "Instruction vectors:\n";
- for (const MachineBasicBlock &MBB : MF) {
- for (const MachineInstr &MI : MBB) {
- OS << MI << " -> ";
- Emb->getMInstVector(MI).print(OS);
- }
- }
- break;
- }
- }
+ EmbeddingLevel Level) const;
/// Get the MIR vocabulary instance
- const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
+ const MIRVocabulary *getVocabulary() const;
};
/// Helper structure to hold MIR context
>From 68109b999603d2165427e7ace22830a65e0596b8 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Thu, 18 Dec 2025 16:07:39 +0530
Subject: [PATCH 4/5] nit commit - formatting fixup
---
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 20 +++++++++-----------
1 file changed, 9 insertions(+), 11 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index a515797c0c505..1dfcbaedbd16a 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -179,10 +179,10 @@ TripletResult IR2VecTool::generateTriplets(const Function &F) const {
// Add "Next" relationship with previous instruction
if (HasPrevOpcode) {
Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
- LLVM_DEBUG(dbgs()
- << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
- << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << "Next\n");
+ LLVM_DEBUG(dbgs() << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1)
+ << '\t'
+ << Vocabulary::getVocabKeyForOpcode(Opcode + 1)
+ << '\t' << "Next\n");
}
// Add "Type" relationship
@@ -368,8 +368,7 @@ bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
const MachineRegisterInfo &MRI = MF->getRegInfo();
- auto VocabOrErr =
- MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
+ auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
if (!VocabOrErr) {
WithColor::error(errs(), ToolName)
<< "Failed to create dummy vocabulary - "
@@ -409,9 +408,8 @@ TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
// Add "Next" relationship with previous instruction
if (HasPrevOpcode) {
Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
- LLVM_DEBUG(dbgs()
- << Vocab->getStringKey(PrevOpcode) << '\t'
- << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
+ LLVM_DEBUG(dbgs() << Vocab->getStringKey(PrevOpcode) << '\t'
+ << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
}
// Add "Arg" relationships for operands
@@ -422,8 +420,8 @@ TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
LLVM_DEBUG({
std::string OperandStr = Vocab->getStringKey(OperandID);
- dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
- << '\t' << "Arg" << ArgIndex << '\n';
+ dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr << '\t'
+ << "Arg" << ArgIndex << '\n';
});
++ArgIndex;
>From c82d055d7d8097181eab853495ecbe7e7663d9ef Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Thu, 18 Dec 2025 19:20:22 +0530
Subject: [PATCH 5/5] Work Commit - moving constructors and getVocabulary
function to header file:
---
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 6 ------
llvm/tools/llvm-ir2vec/llvm-ir2vec.h | 6 +++---
2 files changed, 3 insertions(+), 9 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 1dfcbaedbd16a..6b70e09518fa7 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -147,8 +147,6 @@ static cl::opt<EmbeddingLevel>
namespace ir2vec {
-IR2VecTool::IR2VecTool(Module &M) : M(M) {}
-
bool IR2VecTool::initializeVocabulary() {
// Register and run the IR2Vec vocabulary analysis
// The vocabulary file path is specified via --ir2vec-vocab-path global
@@ -343,8 +341,6 @@ Error processModule(Module &M, raw_ostream &OS) {
namespace mir2vec {
-MIR2VecTool::MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
-
bool MIR2VecTool::initializeVocabulary(const Module &M) {
MIR2VecVocabProvider Provider(MMI);
auto VocabOrErr = Provider.getVocabulary(M);
@@ -554,8 +550,6 @@ void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
}
}
-const MIRVocabulary *MIR2VecTool::getVocabulary() const { return Vocab.get(); }
-
/// Setup MIR context from input file
Error setupMIRContext(const std::string &InputFile, MIRContext &Ctx) {
SMDiagnostic Err;
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
index 9bcce9741c917..566c362edbd22 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
@@ -90,7 +90,7 @@ class IR2VecTool {
const Vocabulary *Vocab = nullptr;
public:
- explicit IR2VecTool(Module &M);
+ explicit IR2VecTool(Module &M) : M(M) {}
/// Initialize the IR2Vec vocabulary analysis
bool initializeVocabulary();
@@ -140,7 +140,7 @@ class MIR2VecTool {
std::unique_ptr<MIRVocabulary> Vocab;
public:
- explicit MIR2VecTool(MachineModuleInfo &MMI);
+ explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
/// Initialize MIR2Vec vocabulary from file (for embeddings generation)
bool initializeVocabulary(const Module &M);
@@ -183,7 +183,7 @@ class MIR2VecTool {
EmbeddingLevel Level) const;
/// Get the MIR vocabulary instance
- const MIRVocabulary *getVocabulary() const;
+ const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
};
/// Helper structure to hold MIR context
More information about the llvm-commits
mailing list