[llvm] [NFC][llvm-ir2vec] llvm-ir2vec.cpp breakup to extract a reusable header for IR2VecTool and MIR2VecTool classes (PR #172304)

Nishant Sachdeva via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 18 05:51:24 PST 2025


https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/172304

>From 7da07971b66566f0be70b5b214ee0d373b6dc11a Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Mon, 15 Dec 2025 14:58:51 +0530
Subject: [PATCH 1/5] Breaking up llvm-ir2vec.cpp, extracting out tool classes
 to prepare a common importable module

---
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 679 +++++--------------------
 llvm/tools/llvm-ir2vec/llvm-ir2vec.h   | 534 +++++++++++++++++++
 2 files changed, 664 insertions(+), 549 deletions(-)
 create mode 100644 llvm/tools/llvm-ir2vec/llvm-ir2vec.h

diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 9ab12e36718cd..8b52c385ff524 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -54,6 +54,7 @@
 ///
 //===----------------------------------------------------------------------===//
 
+#include "llvm-ir2vec.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Analysis/IR2Vec.h"
 #include "llvm/IR/BasicBlock.h"
@@ -90,8 +91,6 @@
 
 namespace llvm {
 
-static const char *ToolName = "llvm-ir2vec";
-
 // Common option category for options shared between IR2Vec and MIR2Vec
 static cl::OptionCategory CommonCategory("Common Options",
                                          "Options applicable to both IR2Vec "
@@ -135,12 +134,6 @@ static cl::opt<std::string>
                  cl::value_desc("name"), cl::Optional, cl::init(""),
                  cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
 
-enum EmbeddingLevel {
-  InstructionLevel, // Generate instruction-level embeddings
-  BasicBlockLevel,  // Generate basic block-level embeddings
-  FunctionLevel     // Generate function-level embeddings
-};
-
 static cl::opt<EmbeddingLevel>
     Level("level", cl::desc("Embedding generation level:"),
           cl::values(clEnumValN(InstructionLevel, "inst",
@@ -152,219 +145,8 @@ static cl::opt<EmbeddingLevel>
           cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
           cl::cat(CommonCategory));
 
-/// Represents a single knowledge graph triplet (Head, Relation, Tail)
-/// where indices reference entities in an EntityList
-struct Triplet {
-  unsigned Head = 0;     ///< Index of the head entity in the entity list
-  unsigned Tail = 0;     ///< Index of the tail entity in the entity list
-  unsigned Relation = 0; ///< Relation type (see RelationType enum)
-};
-
-/// Result structure containing all generated triplets and metadata
-struct TripletResult {
-  unsigned MaxRelation =
-      0; ///< Highest relation index used (for ArgRelation + N)
-  std::vector<Triplet> Triplets; ///< Collection of all generated triplets
-};
-
-/// Entity mappings: [entity_name]
-using EntityList = std::vector<std::string>;
-
 namespace ir2vec {
 
-/// Relation types for triplet generation
-enum RelationType {
-  TypeRelation = 0, ///< Instruction to type relationship
-  NextRelation = 1, ///< Sequential instruction relationship
-  ArgRelation = 2   ///< Instruction to operand relationship (ArgRelation + N)
-};
-
-/// Helper class for collecting IR triplets and generating embeddings
-class IR2VecTool {
-private:
-  Module &M;
-  ModuleAnalysisManager MAM;
-  const Vocabulary *Vocab = nullptr;
-
-public:
-  explicit IR2VecTool(Module &M) : M(M) {}
-
-  /// Initialize the IR2Vec vocabulary analysis
-  bool initializeVocabulary() {
-    // Register and run the IR2Vec vocabulary analysis
-    // The vocabulary file path is specified via --ir2vec-vocab-path global
-    // option
-    MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
-    MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
-    // This will throw an error if vocab is not found or invalid
-    Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
-    return Vocab->isValid();
-  }
-
-  /// Generate triplets for a single function
-  /// Returns a TripletResult with:
-  ///   - Triplets: vector of all (subject, object, relation) tuples
-  ///   - MaxRelation: highest Arg relation ID used, or NextRelation if none
-  TripletResult generateTriplets(const Function &F) const {
-    if (F.isDeclaration())
-      return {};
-
-    TripletResult Result;
-    Result.MaxRelation = 0;
-
-    unsigned MaxRelation = NextRelation;
-    unsigned PrevOpcode = 0;
-    bool HasPrevOpcode = false;
-
-    for (const BasicBlock &BB : F) {
-      for (const auto &I : BB.instructionsWithoutDebug()) {
-        unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
-        unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
-
-        // Add "Next" relationship with previous instruction
-        if (HasPrevOpcode) {
-          Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
-          LLVM_DEBUG(dbgs()
-                     << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
-                     << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                     << "Next\n");
-        }
-
-        // Add "Type" relationship
-        Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
-        LLVM_DEBUG(
-            dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                   << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
-                   << '\t' << "Type\n");
-
-        // Add "Arg" relationships
-        unsigned ArgIndex = 0;
-        for (const Use &U : I.operands()) {
-          unsigned OperandID = Vocabulary::getIndex(*U.get());
-          unsigned RelationID = ArgRelation + ArgIndex;
-          Result.Triplets.push_back({Opcode, OperandID, RelationID});
-
-          LLVM_DEBUG({
-            StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
-                Vocabulary::getOperandKind(U.get()));
-            dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                   << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
-          });
-
-          ++ArgIndex;
-        }
-        // Only update MaxRelation if there were operands
-        if (ArgIndex > 0)
-          MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
-        PrevOpcode = Opcode;
-        HasPrevOpcode = true;
-      }
-    }
-
-    Result.MaxRelation = MaxRelation;
-    return Result;
-  }
-
-  /// Get triplets for the entire module
-  TripletResult generateTriplets() const {
-    TripletResult Result;
-    Result.MaxRelation = NextRelation;
-
-    for (const Function &F : M.getFunctionDefs()) {
-      TripletResult FuncResult = generateTriplets(F);
-      Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
-      Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
-                             FuncResult.Triplets.end());
-    }
-
-    return Result;
-  }
-
-  /// Collect triplets for the module and dump output to stream
-  /// Output format: MAX_RELATION=N header followed by relationships
-  void writeTripletsToStream(raw_ostream &OS) const {
-    auto Result = generateTriplets();
-    OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
-    for (const auto &T : Result.Triplets)
-      OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
-  }
-
-  /// Generate entity mappings for the entire vocabulary
-  /// Returns EntityList containing all entity strings
-  static EntityList collectEntityMappings() {
-    auto EntityLen = Vocabulary::getCanonicalSize();
-    EntityList Result;
-    for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
-      Result.push_back(Vocabulary::getStringKey(EntityID).str());
-    return Result;
-  }
-
-  /// Dump entity ID to string mappings
-  static void writeEntitiesToStream(raw_ostream &OS) {
-    auto Entities = collectEntityMappings();
-    OS << Entities.size() << "\n";
-    for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
-      OS << Entities[EntityID] << '\t' << EntityID << '\n';
-  }
-
-  /// Generate embeddings for the entire module
-  void writeEmbeddingsToStream(raw_ostream &OS) const {
-    if (!Vocab->isValid()) {
-      WithColor::error(errs(), ToolName)
-          << "Vocabulary is not valid. IR2VecTool not initialized.\n";
-      return;
-    }
-
-    for (const Function &F : M.getFunctionDefs())
-      writeEmbeddingsToStream(F, OS);
-  }
-
-  /// Generate embeddings for a single function
-  void writeEmbeddingsToStream(const Function &F, raw_ostream &OS) const {
-    if (!Vocab || !Vocab->isValid()) {
-      WithColor::error(errs(), ToolName)
-          << "Vocabulary is not valid. IR2VecTool not initialized.\n";
-      return;
-    }
-    if (F.isDeclaration()) {
-      OS << "Function " << F.getName() << " is a declaration, skipping.\n";
-      return;
-    }
-
-    // Create embedder for this function
-    auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
-    if (!Emb) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to create embedder for function " << F.getName() << "\n";
-      return;
-    }
-
-    OS << "Function: " << F.getName() << "\n";
-
-    // Generate embeddings based on the specified level
-    switch (Level) {
-    case FunctionLevel: {
-      Emb->getFunctionVector().print(OS);
-      break;
-    }
-    case BasicBlockLevel: {
-      for (const BasicBlock &BB : F) {
-        OS << BB.getName() << ":";
-        Emb->getBBVector(BB).print(OS);
-      }
-      break;
-    }
-    case InstructionLevel: {
-      for (const Instruction &I : instructions(F)) {
-        OS << I;
-        Emb->getInstVector(I).print(OS);
-      }
-      break;
-    }
-    }
-  }
-};
-
 /// Process the module and generate output based on selected subcommand
 Error processModule(Module &M, raw_ostream &OS) {
   IR2VecTool Tool(M);
@@ -379,14 +161,14 @@ Error processModule(Module &M, raw_ostream &OS) {
     if (!FunctionName.empty()) {
       // Process single function
       if (const Function *F = M.getFunction(FunctionName))
-        Tool.writeEmbeddingsToStream(*F, OS);
+        Tool.writeEmbeddingsToStream(*F, OS, Level);
       else
         return createStringError(errc::invalid_argument,
                                  "Function '%s' not found",
                                  FunctionName.c_str());
     } else {
       // Process all functions
-      Tool.writeEmbeddingsToStream(OS);
+      Tool.writeEmbeddingsToStream(OS, Level);
     }
   } else {
     // Both triplets and entities use triplet generation
@@ -398,257 +180,151 @@ Error processModule(Module &M, raw_ostream &OS) {
 
 namespace mir2vec {
 
-/// Relation types for MIR2Vec triplet generation
-enum MIRRelationType {
-  MIRNextRelation = 0, ///< Sequential instruction relationship
-  MIRArgRelation = 1 ///< Instruction to operand relationship (ArgRelation + N)
-};
+/// Set up the MIR context from the input file
+Error setupMIRContext(const std::string &InputFile, MIRContext &Ctx) {
+  SMDiagnostic Err;
 
-/// Helper class for MIR2Vec embedding generation
-class MIR2VecTool {
-private:
-  MachineModuleInfo &MMI;
-  std::unique_ptr<MIRVocabulary> Vocab;
-
-public:
-  explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
-
-  /// Initialize MIR2Vec vocabulary from file (for embeddings generation)
-  bool initializeVocabulary(const Module &M) {
-    MIR2VecVocabProvider Provider(MMI);
-    auto VocabOrErr = Provider.getVocabulary(M);
-    if (!VocabOrErr) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to load MIR2Vec vocabulary - "
-          << toString(VocabOrErr.takeError()) << "\n";
-      return false;
-    }
-    Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
-    return true;
+  auto MIR = createMIRParserFromFile(InputFile, Err, Ctx.Context);
+  if (!MIR) {
+    Err.print(ToolName, errs());
+    return createStringError(errc::invalid_argument,
+                             "Failed to parse MIR file");
   }
 
-  /// Initialize vocabulary with layout information only.
-  /// This creates a minimal vocabulary with correct layout but no actual
-  /// embeddings. Sufficient for generating training data and entity mappings.
-  ///
-  /// Note: Requires target-specific information from the first machine function
-  /// to determine the vocabulary layout (number of opcodes, register classes).
-  ///
-  /// FIXME: Use --target option to get target info directly, avoiding the need
-  /// to parse machine functions for pre-training operations.
-  bool initializeVocabularyForLayout(const Module &M) {
-    for (const Function &F : M.getFunctionDefs()) {
-
-      MachineFunction *MF = MMI.getMachineFunction(F);
-      if (!MF)
-        continue;
-
-      const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
-      const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
-      const MachineRegisterInfo &MRI = MF->getRegInfo();
-
-      auto VocabOrErr =
-          MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
-      if (!VocabOrErr) {
-        WithColor::error(errs(), ToolName)
-            << "Failed to create dummy vocabulary - "
-            << toString(VocabOrErr.takeError()) << "\n";
-        return false;
-      }
-      Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
-      return true;
-    }
-
-    WithColor::error(errs(), ToolName)
-        << "No machine functions found to initialize vocabulary\n";
-    return false;
-  }
-
-  /// Get triplets for a single machine function
-  /// Returns TripletResult containing MaxRelation and vector of Triplets
-  TripletResult generateTriplets(const MachineFunction &MF) const {
-    TripletResult Result;
-    Result.MaxRelation = MIRNextRelation;
-
-    if (!Vocab) {
-      WithColor::error(errs(), ToolName)
-          << "MIR Vocabulary must be initialized for triplet generation.\n";
-      return Result;
-    }
-
-    unsigned PrevOpcode = 0;
-    bool HasPrevOpcode = false;
-    for (const MachineBasicBlock &MBB : MF) {
-      for (const MachineInstr &MI : MBB) {
-        // Skip debug instructions
-        if (MI.isDebugInstr())
-          continue;
-
-        // Get opcode entity ID
-        unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
-
-        // Add "Next" relationship with previous instruction
-        if (HasPrevOpcode) {
-          Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
-          LLVM_DEBUG(dbgs()
-                     << Vocab->getStringKey(PrevOpcode) << '\t'
-                     << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
-        }
-
-        // Add "Arg" relationships for operands
-        unsigned ArgIndex = 0;
-        for (const MachineOperand &MO : MI.operands()) {
-          auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
-          unsigned RelationID = MIRArgRelation + ArgIndex;
-          Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
-          LLVM_DEBUG({
-            std::string OperandStr = Vocab->getStringKey(OperandID);
-            dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
-                   << '\t' << "Arg" << ArgIndex << '\n';
-          });
-
-          ++ArgIndex;
-        }
+  auto SetDataLayout = [&](StringRef DataLayoutTargetTriple,
+                           StringRef OldDLStr) -> std::optional<std::string> {
+    std::string IRTargetTriple = DataLayoutTargetTriple.str();
+    Triple TheTriple = Triple(IRTargetTriple);
+    if (TheTriple.getTriple().empty())
+      TheTriple.setTriple(sys::getDefaultTargetTriple());
 
-        // Update MaxRelation if there were operands
-        if (ArgIndex > 0)
-          Result.MaxRelation =
-              std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
-
-        PrevOpcode = OpcodeID;
-        HasPrevOpcode = true;
-      }
-    }
-
-    return Result;
-  }
-
-  /// Get triplets for the entire module
-  /// Returns TripletResult containing aggregated MaxRelation and all Triplets
-  TripletResult generateTriplets(const Module &M) const {
-    TripletResult Result;
-    Result.MaxRelation = MIRNextRelation;
-
-    for (const Function &F : M.getFunctionDefs()) {
-      MachineFunction *MF = MMI.getMachineFunction(F);
-      if (!MF) {
-        WithColor::warning(errs(), ToolName)
-            << "No MachineFunction for " << F.getName() << "\n";
-        continue;
-      }
-
-      TripletResult FuncResult = generateTriplets(*MF);
-      Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
-      Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
-                             FuncResult.Triplets.end());
+    auto TMOrErr = codegen::createTargetMachineForTriple(TheTriple.str());
+    if (!TMOrErr) {
+      Err.print(ToolName, errs());
+      exit(1); // Match original behavior
     }
+    Ctx.TM = std::move(*TMOrErr);
+    return Ctx.TM->createDataLayout().getStringRepresentation();
+  };
 
-    return Result;
+  Ctx.M = MIR->parseIRModule(SetDataLayout);
+  if (!Ctx.M) {
+    Err.print(ToolName, errs());
+    return createStringError(errc::invalid_argument,
+                             "Failed to parse IR module");
   }
 
-  /// Collect triplets for the module and write to output stream
-  /// Output format: MAX_RELATION=N header followed by relationships
-  void writeTripletsToStream(const Module &M, raw_ostream &OS) const {
-    auto Result = generateTriplets(M);
-    OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
-    for (const auto &T : Result.Triplets)
-      OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+  Ctx.MMI = std::make_unique<MachineModuleInfo>(Ctx.TM.get());
+  if (!Ctx.MMI || MIR->parseMachineFunctions(*Ctx.M, *Ctx.MMI)) {
+    Err.print(ToolName, errs());
+    return createStringError(errc::invalid_argument,
+                             "Failed to parse machine functions");
   }
 
-  /// Generate entity mappings for the entire vocabulary
-  EntityList collectEntityMappings() const {
-    if (!Vocab) {
-      WithColor::error(errs(), ToolName)
-          << "Vocabulary must be initialized for entity mappings.\n";
-      return {};
-    }
-
-    const unsigned EntityCount = Vocab->getCanonicalSize();
-    EntityList Result;
-    for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
-      Result.push_back(Vocab->getStringKey(EntityID));
+  return Error::success();
+}
 
-    return Result;
-  }
+/// Generic vocabulary initialization and processing
+template <typename ProcessFunc>
+Error processWithVocabulary(MIRContext &Ctx, raw_ostream &OS,
+                            bool useLayoutVocab, ProcessFunc processFn) {
+  MIR2VecTool Tool(*Ctx.MMI);
 
-  /// Generate entity mappings and write to output stream
-  void writeEntitiesToStream(raw_ostream &OS) const {
-    auto Entities = collectEntityMappings();
-    if (Entities.empty())
-      return;
+  // Initialize appropriate vocabulary type
+  bool success = useLayoutVocab ? Tool.initializeVocabularyForLayout(*Ctx.M)
+                                : Tool.initializeVocabulary(*Ctx.M);
 
-    OS << Entities.size() << "\n";
-    for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
-      OS << Entities[EntityID] << '\t' << EntityID << '\n';
+  if (!success) {
+    WithColor::error(errs(), ToolName)
+        << "Failed to initialize MIR2Vec vocabulary"
+        << (useLayoutVocab ? " for layout" : "") << ".\n";
+    return createStringError(errc::invalid_argument,
+                             "Vocabulary initialization failed");
   }
 
-  /// Generate embeddings for all machine functions in the module
-  void writeEmbeddingsToStream(const Module &M, raw_ostream &OS) const {
-    if (!Vocab) {
-      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
-      return;
-    }
-
-    for (const Function &F : M.getFunctionDefs()) {
-
-      MachineFunction *MF = MMI.getMachineFunction(F);
-      if (!MF) {
-        WithColor::warning(errs(), ToolName)
-            << "No MachineFunction for " << F.getName() << "\n";
-        continue;
-      }
+  assert(Tool.getVocabulary() &&
+         "MIR2Vec vocabulary should be initialized at this point");
 
-      writeEmbeddingsToStream(*MF, OS);
-    }
-  }
+  LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"
+                    << "Vocabulary dimension: "
+                    << Tool.getVocabulary()->getDimension() << "\n"
+                    << "Vocabulary size: "
+                    << Tool.getVocabulary()->getCanonicalSize() << "\n");
 
-  /// Generate embeddings for a specific machine function
-  void writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS) const {
-    if (!Vocab) {
-      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
-      return;
-    }
+  // Execute the specific processing logic
+  return processFn(Tool);
+}
 
-    auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
-    if (!Emb) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to create embedder for " << MF.getName() << "\n";
-      return;
-    }
+/// Process module for triplet generation
+Error processModuleForTriplets(MIRContext &Ctx, raw_ostream &OS) {
+  return processWithVocabulary(Ctx, OS, /*useLayoutVocab=*/true,
+                               [&](MIR2VecTool &Tool) -> Error {
+                                 Tool.writeTripletsToStream(*Ctx.M, OS);
+                                 return Error::success();
+                               });
+}
 
-    OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+/// Process module for entity generation
+Error processModuleForEntities(MIRContext &Ctx, raw_ostream &OS) {
+  return processWithVocabulary(Ctx, OS, /*useLayoutVocab=*/true,
+                               [&](MIR2VecTool &Tool) -> Error {
+                                 Tool.writeEntitiesToStream(OS);
+                                 return Error::success();
+                               });
+}
 
-    // Generate embeddings based on the specified level
-    switch (Level) {
-    case FunctionLevel: {
-      OS << "Function vector: ";
-      Emb->getMFunctionVector().print(OS);
-      break;
-    }
-    case BasicBlockLevel: {
-      OS << "Basic block vectors:\n";
-      for (const MachineBasicBlock &MBB : MF) {
-        OS << "MBB " << MBB.getName() << ": ";
-        Emb->getMBBVector(MBB).print(OS);
-      }
-      break;
-    }
-    case InstructionLevel: {
-      OS << "Instruction vectors:\n";
-      for (const MachineBasicBlock &MBB : MF) {
-        for (const MachineInstr &MI : MBB) {
-          OS << MI << " -> ";
-          Emb->getMInstVector(MI).print(OS);
+/// Process module for embedding generation
+Error processModuleForEmbeddings(MIRContext &Ctx, raw_ostream &OS) {
+  return processWithVocabulary(
+      Ctx, OS, /*useLayoutVocab=*/false, [&](MIR2VecTool &Tool) -> Error {
+        if (!FunctionName.empty()) {
+          // Process single function
+          Function *F = Ctx.M->getFunction(FunctionName);
+          if (!F) {
+            WithColor::error(errs(), ToolName)
+                << "Function '" << FunctionName << "' not found\n";
+            return createStringError(errc::invalid_argument,
+                                     "Function not found");
+          }
+
+          MachineFunction *MF = Ctx.MMI->getMachineFunction(*F);
+          if (!MF) {
+            WithColor::error(errs(), ToolName)
+                << "No MachineFunction for " << FunctionName << "\n";
+            return createStringError(errc::invalid_argument,
+                                     "No MachineFunction");
+          }
+
+          Tool.writeEmbeddingsToStream(*MF, OS, Level);
+        } else {
+          // Process all functions
+          Tool.writeEmbeddingsToStream(*Ctx.M, OS, Level);
         }
-      }
-      break;
-    }
-    }
-  }
+        return Error::success();
+      });
+}
 
-  /// Get the MIR vocabulary instance
-  const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
-};
+/// Main entry point for MIR processing
+Error processModule(const std::string &InputFile, raw_ostream &OS) {
+  MIRContext Ctx;
+
+  // Set up MIR context (parse file, set up target machine, etc.)
+  if (auto Err = setupMIRContext(InputFile, Ctx))
+    return Err;
+
+  // Process based on subcommand
+  if (TripletsSubCmd)
+    return processModuleForTriplets(Ctx, OS);
+  else if (EntitiesSubCmd)
+    return processModuleForEntities(Ctx, OS);
+  else if (EmbeddingsSubCmd)
+    return processModuleForEmbeddings(Ctx, OS);
+  else {
+    WithColor::error(errs(), ToolName)
+        << "Please specify a subcommand: triplets, entities, or embeddings\n";
+    return createStringError(errc::invalid_argument, "No subcommand specified");
+  }
+}
 
 } // namespace mir2vec
 
@@ -712,105 +388,10 @@ int main(int argc, char **argv) {
     InitializeAllAsmPrinters();
     static codegen::RegisterCodeGenFlags CGF;
 
-    // Parse MIR input file
-    SMDiagnostic Err;
-    LLVMContext Context;
-    std::unique_ptr<TargetMachine> TM;
-
-    auto MIR = createMIRParserFromFile(InputFilename, Err, Context);
-    if (!MIR) {
-      Err.print(ToolName, errs());
-      return 1;
-    }
-
-    auto SetDataLayout = [&](StringRef DataLayoutTargetTriple,
-                             StringRef OldDLStr) -> std::optional<std::string> {
-      std::string IRTargetTriple = DataLayoutTargetTriple.str();
-      Triple TheTriple = Triple(IRTargetTriple);
-      if (TheTriple.getTriple().empty())
-        TheTriple.setTriple(sys::getDefaultTargetTriple());
-      auto TMOrErr = codegen::createTargetMachineForTriple(TheTriple.str());
-      if (!TMOrErr) {
-        Err.print(ToolName, errs());
-        exit(1);
-      }
-      TM = std::move(*TMOrErr);
-      return TM->createDataLayout().getStringRepresentation();
-    };
-
-    std::unique_ptr<Module> M = MIR->parseIRModule(SetDataLayout);
-    if (!M) {
-      Err.print(ToolName, errs());
-      return 1;
-    }
-
-    // Parse machine functions
-    auto MMI = std::make_unique<MachineModuleInfo>(TM.get());
-    if (!MMI || MIR->parseMachineFunctions(*M, *MMI)) {
-      Err.print(ToolName, errs());
-      return 1;
-    }
-
-    // Create MIR2Vec tool
-    MIR2VecTool Tool(*MMI);
-
-    // Initialize vocabulary. For triplet/entity generation, only layout is
-    // needed For embedding generation, the full vocabulary is needed.
-    //
-    // Note: Unlike IR2Vec, MIR2Vec vocabulary initialization requires
-    // target-specific information for generating the vocabulary layout. So, we
-    // always initialize the vocabulary in this case.
-    if (TripletsSubCmd || EntitiesSubCmd) {
-      if (!Tool.initializeVocabularyForLayout(*M)) {
-        WithColor::error(errs(), ToolName)
-            << "Failed to initialize MIR2Vec vocabulary for layout.\n";
-        return 1;
-      }
-    } else {
-      if (!Tool.initializeVocabulary(*M)) {
-        WithColor::error(errs(), ToolName)
-            << "Failed to initialize MIR2Vec vocabulary.\n";
-        return 1;
-      }
-    }
-    assert(Tool.getVocabulary() &&
-           "MIR2Vec vocabulary should be initialized at this point");
-    LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"
-                      << "Vocabulary dimension: "
-                      << Tool.getVocabulary()->getDimension() << "\n"
-                      << "Vocabulary size: "
-                      << Tool.getVocabulary()->getCanonicalSize() << "\n");
-
-    // Handle subcommands
-    if (TripletsSubCmd) {
-      Tool.writeTripletsToStream(*M, OS);
-    } else if (EntitiesSubCmd) {
-      Tool.writeEntitiesToStream(OS);
-    } else if (EmbeddingsSubCmd) {
-      if (!FunctionName.empty()) {
-        // Process single function
-        Function *F = M->getFunction(FunctionName);
-        if (!F) {
-          WithColor::error(errs(), ToolName)
-              << "Function '" << FunctionName << "' not found\n";
-          return 1;
-        }
-
-        MachineFunction *MF = MMI->getMachineFunction(*F);
-        if (!MF) {
-          WithColor::error(errs(), ToolName)
-              << "No MachineFunction for " << FunctionName << "\n";
-          return 1;
-        }
-
-        Tool.writeEmbeddingsToStream(*MF, OS);
-      } else {
-        // Process all functions
-        Tool.writeEmbeddingsToStream(*M, OS);
-      }
-    } else {
-      WithColor::error(errs(), ToolName)
-          << "Please specify a subcommand: triplets, entities, or embeddings\n";
+    if (Error Err = mir2vec::processModule(InputFilename, OS)) {
+      handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) {
+        WithColor::error(errs(), ToolName) << EIB.message() << "\n";
+      });
       return 1;
     }
 
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
new file mode 100644
index 0000000000000..56fd834d380a8
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
@@ -0,0 +1,534 @@
+//===- llvm-ir2vec.h - IR2Vec/MIR2Vec Tool Classes ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains the IR2VecTool and MIR2VecTool class definitions and
+/// implementations for the llvm-ir2vec embedding generation tool.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TOOLS_LLVM_IR2VEC_LLVM_MIR2VEC_H
+#define LLVM_TOOLS_LLVM_IR2VEC_LLVM_MIR2VEC_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Analysis/IR2Vec.h"
+#include "llvm/CodeGen/MIR2Vec.h"
+#include "llvm/CodeGen/MIRParser/MIRParser.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/WithColor.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Target/TargetMachine.h"
+#include <memory>
+#include <string>
+#include <vector>
+
+#define DEBUG_TYPE "ir2vec"
+
+namespace llvm {
+
+/// Tool name for error reporting
+static const char *ToolName = "llvm-ir2vec";
+
+/// Specifies the granularity at which embeddings are generated.
+enum EmbeddingLevel {
+  InstructionLevel, // Generate instruction-level embeddings
+  BasicBlockLevel,  // Generate basic block-level embeddings
+  FunctionLevel     // Generate function-level embeddings
+};
+
+/// Represents a single knowledge graph triplet (Head, Relation, Tail)
+/// where indices reference entities in an EntityList
+struct Triplet {
+  unsigned Head = 0;     ///< Index of the head entity in the entity list
+  unsigned Tail = 0;     ///< Index of the tail entity in the entity list
+  unsigned Relation = 0; ///< Relation type (see RelationType enum)
+};
+
+/// Result structure containing all generated triplets and metadata
+struct TripletResult {
+  unsigned MaxRelation =
+      0; ///< Highest relation index used (for ArgRelation + N)
+  std::vector<Triplet> Triplets; ///< Collection of all generated triplets
+};
+
+/// Entity mappings: [entity_name]
+using EntityList = std::vector<std::string>;
+
+namespace ir2vec {
+
+/// Relation types for triplet generation
+enum RelationType {
+  TypeRelation = 0, ///< Instruction to type relationship
+  NextRelation = 1, ///< Sequential instruction relationship
+  ArgRelation = 2   ///< Instruction to operand relationship (ArgRelation + N)
+};
+
+/// Helper class for collecting IR triplets and generating embeddings
+class IR2VecTool {
+private:
+  Module &M;
+  ModuleAnalysisManager MAM;
+  const Vocabulary *Vocab = nullptr;
+
+public:
+  explicit IR2VecTool(Module &M) : M(M) {}
+
+  /// Initialize the IR2Vec vocabulary analysis
+  bool initializeVocabulary() {
+    // Register and run the IR2Vec vocabulary analysis
+    // The vocabulary file path is specified via --ir2vec-vocab-path global
+    // option
+    MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+    MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
+    // This will throw an error if vocab is not found or invalid
+    Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
+    return Vocab->isValid();
+  }
+
+  /// Generate triplets for a single function
+  /// Returns a TripletResult with:
+  ///   - Triplets: vector of all (subject, object, relation) tuples
+  ///   - MaxRelation: highest Arg relation ID used, or NextRelation if none
+  TripletResult generateTriplets(const Function &F) const {
+    if (F.isDeclaration())
+      return {};
+
+    TripletResult Result;
+    Result.MaxRelation = 0;
+
+    unsigned MaxRelation = NextRelation;
+    unsigned PrevOpcode = 0;
+    bool HasPrevOpcode = false;
+
+    for (const BasicBlock &BB : F) {
+      for (const auto &I : BB.instructionsWithoutDebug()) {
+        unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
+        unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
+
+        // Add "Next" relationship with previous instruction
+        if (HasPrevOpcode) {
+          Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
+          LLVM_DEBUG(dbgs()
+                     << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
+                     << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                     << "Next\n");
+        }
+
+        // Add "Type" relationship
+        Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
+        LLVM_DEBUG(
+            dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                   << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
+                   << '\t' << "Type\n");
+
+        // Add "Arg" relationships
+        unsigned ArgIndex = 0;
+        for (const Use &U : I.operands()) {
+          unsigned OperandID = Vocabulary::getIndex(*U.get());
+          unsigned RelationID = ArgRelation + ArgIndex;
+          Result.Triplets.push_back({Opcode, OperandID, RelationID});
+
+          LLVM_DEBUG({
+            StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
+                Vocabulary::getOperandKind(U.get()));
+            dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                   << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
+          });
+
+          ++ArgIndex;
+        }
+        // Only update MaxRelation if there were operands
+        if (ArgIndex > 0)
+          MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
+        PrevOpcode = Opcode;
+        HasPrevOpcode = true;
+      }
+    }
+
+    Result.MaxRelation = MaxRelation;
+    return Result;
+  }
+
+  /// Get triplets for the entire module
+  TripletResult generateTriplets() const {
+    TripletResult Result;
+    Result.MaxRelation = NextRelation;
+
+    for (const Function &F : M.getFunctionDefs()) {
+      TripletResult FuncResult = generateTriplets(F);
+      Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+      Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+                             FuncResult.Triplets.end());
+    }
+
+    return Result;
+  }
+
+  /// Collect triplets for the module and dump output to stream
+  /// Output format: MAX_RELATION=N header followed by relationships
+  void writeTripletsToStream(raw_ostream &OS) const {
+    auto Result = generateTriplets();
+    OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+    for (const auto &T : Result.Triplets)
+      OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+  }
+
+  /// Generate entity mappings for the entire vocabulary
+  /// Returns EntityList containing all entity strings
+  static EntityList collectEntityMappings() {
+    auto EntityLen = Vocabulary::getCanonicalSize();
+    EntityList Result;
+    for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
+      Result.push_back(Vocabulary::getStringKey(EntityID).str());
+    return Result;
+  }
+
+  /// Dump entity ID to string mappings
+  static void writeEntitiesToStream(raw_ostream &OS) {
+    auto Entities = collectEntityMappings();
+    OS << Entities.size() << "\n";
+    for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+      OS << Entities[EntityID] << '\t' << EntityID << '\n';
+  }
+
+  /// Generate embeddings for the entire module
+  void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const {
+    if (!Vocab->isValid()) {
+      WithColor::error(errs(), ToolName)
+          << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+      return;
+    }
+
+    for (const Function &F : M.getFunctionDefs())
+      writeEmbeddingsToStream(F, OS, Level);
+  }
+
+  /// Generate embeddings for a single function
+  void writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
+                               EmbeddingLevel Level) const {
+    if (!Vocab || !Vocab->isValid()) {
+      WithColor::error(errs(), ToolName)
+          << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+      return;
+    }
+    if (F.isDeclaration()) {
+      OS << "Function " << F.getName() << " is a declaration, skipping.\n";
+      return;
+    }
+
+    // Create embedder for this function
+    auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+    if (!Emb) {
+      WithColor::error(errs(), ToolName)
+          << "Failed to create embedder for function " << F.getName() << "\n";
+      return;
+    }
+
+    OS << "Function: " << F.getName() << "\n";
+
+    // Generate embeddings based on the specified level
+    switch (Level) {
+    case FunctionLevel:
+      Emb->getFunctionVector().print(OS);
+      break;
+    case BasicBlockLevel:
+      for (const BasicBlock &BB : F) {
+        OS << BB.getName() << ":";
+        Emb->getBBVector(BB).print(OS);
+      }
+      break;
+    case InstructionLevel:
+      for (const Instruction &I : instructions(F)) {
+        OS << I;
+        Emb->getInstVector(I).print(OS);
+      }
+      break;
+    }
+  }
+};
+
+} // namespace ir2vec
+
+namespace mir2vec {
+
+/// Relation types for MIR2Vec triplet generation
+enum MIRRelationType {
+  MIRNextRelation = 0, ///< Sequential instruction relationship
+  MIRArgRelation = 1 ///< Instruction to operand relation (MIRArgRelation + N)
+};
+
+/// Helper class for MIR2Vec embedding generation
+class MIR2VecTool {
+private:
+  MachineModuleInfo &MMI;
+  std::unique_ptr<MIRVocabulary> Vocab;
+
+public:
+  explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
+
+  /// Initialize MIR2Vec vocabulary from file (for embeddings generation)
+  bool initializeVocabulary(const Module &M) {
+    MIR2VecVocabProvider Provider(MMI);
+    auto VocabOrErr = Provider.getVocabulary(M);
+    if (!VocabOrErr) {
+      WithColor::error(errs(), ToolName)
+          << "Failed to load MIR2Vec vocabulary - "
+          << toString(VocabOrErr.takeError()) << "\n";
+      return false;
+    }
+    Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
+    return true;
+  }
+
+  /// Initialize vocabulary with layout information only.
+  /// This creates a minimal vocabulary with correct layout but no actual
+  /// embeddings. Sufficient for generating training data and entity mappings.
+  ///
+  /// Note: Requires target-specific information from the first machine function
+  /// to determine the vocabulary layout (number of opcodes, register classes).
+  ///
+  /// FIXME: Use --target option to get target info directly, avoiding the need
+  /// to parse machine functions for pre-training operations.
+  bool initializeVocabularyForLayout(const Module &M) {
+    for (const Function &F : M.getFunctionDefs()) {
+      MachineFunction *MF = MMI.getMachineFunction(F);
+      if (!MF)
+        continue;
+
+      const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
+      const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
+      const MachineRegisterInfo &MRI = MF->getRegInfo();
+
+      auto VocabOrErr =
+          MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
+      if (!VocabOrErr) {
+        WithColor::error(errs(), ToolName)
+            << "Failed to create dummy vocabulary - "
+            << toString(VocabOrErr.takeError()) << "\n";
+        return false;
+      }
+      Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
+      return true;
+    }
+
+    WithColor::error(errs(), ToolName)
+        << "No machine functions found to initialize vocabulary\n";
+    return false;
+  }
+
+  /// Get triplets for a single machine function
+  /// Returns TripletResult containing MaxRelation and vector of Triplets
+  TripletResult generateTriplets(const MachineFunction &MF) const {
+    TripletResult Result;
+    Result.MaxRelation = MIRNextRelation;
+
+    if (!Vocab) {
+      WithColor::error(errs(), ToolName)
+          << "MIR Vocabulary must be initialized for triplet generation.\n";
+      return Result;
+    }
+
+    unsigned PrevOpcode = 0;
+    bool HasPrevOpcode = false;
+    for (const MachineBasicBlock &MBB : MF) {
+      for (const MachineInstr &MI : MBB) {
+        // Skip debug instructions
+        if (MI.isDebugInstr())
+          continue;
+
+        // Get opcode entity ID
+        unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
+
+        // Add "Next" relationship with previous instruction
+        if (HasPrevOpcode) {
+          Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
+          LLVM_DEBUG(dbgs()
+                     << Vocab->getStringKey(PrevOpcode) << '\t'
+                     << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
+        }
+
+        // Add "Arg" relationships for operands
+        unsigned ArgIndex = 0;
+        for (const MachineOperand &MO : MI.operands()) {
+          auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
+          unsigned RelationID = MIRArgRelation + ArgIndex;
+          Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
+          LLVM_DEBUG({
+            std::string OperandStr = Vocab->getStringKey(OperandID);
+            dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
+                   << '\t' << "Arg" << ArgIndex << '\n';
+          });
+
+          ++ArgIndex;
+        }
+
+        // Update MaxRelation if there were operands
+        if (ArgIndex > 0)
+          Result.MaxRelation =
+              std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
+
+        PrevOpcode = OpcodeID;
+        HasPrevOpcode = true;
+      }
+    }
+
+    return Result;
+  }
+
+  /// Get triplets for the entire module
+  /// Returns TripletResult containing aggregated MaxRelation and all Triplets
+  TripletResult generateTriplets(const Module &M) const {
+    TripletResult Result;
+    Result.MaxRelation = MIRNextRelation;
+
+    for (const Function &F : M.getFunctionDefs()) {
+      MachineFunction *MF = MMI.getMachineFunction(F);
+      if (!MF) {
+        WithColor::warning(errs(), ToolName)
+            << "No MachineFunction for " << F.getName() << "\n";
+        continue;
+      }
+
+      TripletResult FuncResult = generateTriplets(*MF);
+      Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+      Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+                             FuncResult.Triplets.end());
+    }
+
+    return Result;
+  }
+
+  /// Collect triplets for the module and write to output stream
+  /// Output format: MAX_RELATION=N header followed by relationships
+  void writeTripletsToStream(const Module &M, raw_ostream &OS) const {
+    auto Result = generateTriplets(M);
+    OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+    for (const auto &T : Result.Triplets)
+      OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+  }
+
+  /// Generate entity mappings for the entire vocabulary
+  EntityList collectEntityMappings() const {
+    if (!Vocab) {
+      WithColor::error(errs(), ToolName)
+          << "Vocabulary must be initialized for entity mappings.\n";
+      return {};
+    }
+
+    const unsigned EntityCount = Vocab->getCanonicalSize();
+    EntityList Result;
+    for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
+      Result.push_back(Vocab->getStringKey(EntityID));
+
+    return Result;
+  }
+
+  /// Generate entity mappings and write to output stream
+  void writeEntitiesToStream(raw_ostream &OS) const {
+    auto Entities = collectEntityMappings();
+    if (Entities.empty())
+      return;
+
+    OS << Entities.size() << "\n";
+    for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+      OS << Entities[EntityID] << '\t' << EntityID << '\n';
+  }
+
+  /// Generate embeddings for all machine functions in the module
+  void writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
+                               EmbeddingLevel Level) const {
+    if (!Vocab) {
+      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+      return;
+    }
+
+    for (const Function &F : M.getFunctionDefs()) {
+      MachineFunction *MF = MMI.getMachineFunction(F);
+      if (!MF) {
+        WithColor::warning(errs(), ToolName)
+            << "No MachineFunction for " << F.getName() << "\n";
+        continue;
+      }
+
+      writeEmbeddingsToStream(*MF, OS, Level);
+    }
+  }
+
+  /// Generate embeddings for a specific machine function
+  void writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
+                               EmbeddingLevel Level) const {
+    if (!Vocab) {
+      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+      return;
+    }
+
+    auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+    if (!Emb) {
+      WithColor::error(errs(), ToolName)
+          << "Failed to create embedder for " << MF.getName() << "\n";
+      return;
+    }
+
+    OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+
+    // Generate embeddings based on the specified level
+    switch (Level) {
+    case FunctionLevel:
+      OS << "Function vector: ";
+      Emb->getMFunctionVector().print(OS);
+      break;
+    case BasicBlockLevel:
+      OS << "Basic block vectors:\n";
+      for (const MachineBasicBlock &MBB : MF) {
+        OS << "MBB " << MBB.getName() << ": ";
+        Emb->getMBBVector(MBB).print(OS);
+      }
+      break;
+    case InstructionLevel:
+      OS << "Instruction vectors:\n";
+      for (const MachineBasicBlock &MBB : MF) {
+        for (const MachineInstr &MI : MBB) {
+          OS << MI << " -> ";
+          Emb->getMInstVector(MI).print(OS);
+        }
+      }
+      break;
+    }
+  }
+
+  /// Get the MIR vocabulary instance
+  const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
+};
+
+/// Helper structure to hold MIR context
+struct MIRContext {
+  LLVMContext Context; // CRITICAL: Must be first for proper destruction order
+  std::unique_ptr<Module> M;
+  std::unique_ptr<MachineModuleInfo> MMI;
+  std::unique_ptr<TargetMachine> TM;
+};
+
+} // namespace mir2vec
+
+} // namespace llvm
+
+#endif // LLVM_TOOLS_LLVM_IR2VEC_LLVM_MIR2VEC_H
\ No newline at end of file

>From 41e4e13eafc0621a92ed053371aa2bb7112396d0 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Thu, 18 Dec 2025 12:47:49 +0530
Subject: [PATCH 2/5] Nit commit - header guard naming convention compliance

---
 llvm/tools/llvm-ir2vec/llvm-ir2vec.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
index 56fd834d380a8..cb4572996b1f4 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
@@ -12,8 +12,8 @@
 ///
 //===----------------------------------------------------------------------===//
 
-#ifndef LLVM_TOOLS_LLVM_IR2VEC_LLVM_MIR2VEC_H
-#define LLVM_TOOLS_LLVM_IR2VEC_LLVM_MIR2VEC_H
+#ifndef LLVM_TOOLS_LLVM_IR2VEC_LLVM_IR2VEC_H
+#define LLVM_TOOLS_LLVM_IR2VEC_LLVM_IR2VEC_H
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Analysis/IR2Vec.h"
@@ -531,4 +531,4 @@ struct MIRContext {
 
 } // namespace llvm
 
-#endif // LLVM_TOOLS_LLVM_IR2VEC_LLVM_MIR2VEC_H
\ No newline at end of file
+#endif // LLVM_TOOLS_LLVM_IR2VEC_LLVM_IR2VEC_H
\ No newline at end of file

>From 0978b4c957717df4171135601b61392af3784880 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Thu, 18 Dec 2025 16:01:18 +0530
Subject: [PATCH 3/5] Work Commit - Moved function definitions out of header
 file

---
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 378 +++++++++++++++++++++++++
 llvm/tools/llvm-ir2vec/llvm-ir2vec.h   | 377 ++----------------------
 2 files changed, 400 insertions(+), 355 deletions(-)

diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 8b52c385ff524..a515797c0c505 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -147,6 +147,169 @@ static cl::opt<EmbeddingLevel>
 
 namespace ir2vec {
 
+IR2VecTool::IR2VecTool(Module &M) : M(M) {}
+
+bool IR2VecTool::initializeVocabulary() {
+  // Register and run the IR2Vec vocabulary analysis
+  // The vocabulary file path is specified via --ir2vec-vocab-path global
+  // option
+  MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+  MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
+  // This will throw an error if vocab is not found or invalid
+  Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
+  return Vocab->isValid();
+}
+
+TripletResult IR2VecTool::generateTriplets(const Function &F) const {
+  if (F.isDeclaration())
+    return {};
+
+  TripletResult Result;
+  Result.MaxRelation = 0;
+
+  unsigned MaxRelation = NextRelation;
+  unsigned PrevOpcode = 0;
+  bool HasPrevOpcode = false;
+
+  for (const BasicBlock &BB : F) {
+    for (const auto &I : BB.instructionsWithoutDebug()) {
+      unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
+      unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
+
+      // Add "Next" relationship with previous instruction
+      if (HasPrevOpcode) {
+        Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
+        LLVM_DEBUG(dbgs()
+                   << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
+                   << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                   << "Next\n");
+      }
+
+      // Add "Type" relationship
+      Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
+      LLVM_DEBUG(
+          dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                 << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
+                 << '\t' << "Type\n");
+
+      // Add "Arg" relationships
+      unsigned ArgIndex = 0;
+      for (const Use &U : I.operands()) {
+        unsigned OperandID = Vocabulary::getIndex(*U.get());
+        unsigned RelationID = ArgRelation + ArgIndex;
+        Result.Triplets.push_back({Opcode, OperandID, RelationID});
+
+        LLVM_DEBUG({
+          StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
+              Vocabulary::getOperandKind(U.get()));
+          dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                 << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
+        });
+
+        ++ArgIndex;
+      }
+      // Only update MaxRelation if there were operands
+      if (ArgIndex > 0)
+        MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
+      PrevOpcode = Opcode;
+      HasPrevOpcode = true;
+    }
+  }
+
+  Result.MaxRelation = MaxRelation;
+  return Result;
+}
+
+TripletResult IR2VecTool::generateTriplets() const {
+  TripletResult Result;
+  Result.MaxRelation = NextRelation;
+
+  for (const Function &F : M.getFunctionDefs()) {
+    TripletResult FuncResult = generateTriplets(F);
+    Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+    Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+                           FuncResult.Triplets.end());
+  }
+
+  return Result;
+}
+
+void IR2VecTool::writeTripletsToStream(raw_ostream &OS) const {
+  auto Result = generateTriplets();
+  OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+  for (const auto &T : Result.Triplets)
+    OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+}
+
+EntityList IR2VecTool::collectEntityMappings() {
+  auto EntityLen = Vocabulary::getCanonicalSize();
+  EntityList Result;
+  for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
+    Result.push_back(Vocabulary::getStringKey(EntityID).str());
+  return Result;
+}
+
+void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
+  auto Entities = collectEntityMappings();
+  OS << Entities.size() << "\n";
+  for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+    OS << Entities[EntityID] << '\t' << EntityID << '\n';
+}
+
+void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
+                                         EmbeddingLevel Level) const {
+  if (!Vocab || !Vocab->isValid()) { // null before initializeVocabulary()
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+    return;
+  }
+
+  for (const Function &F : M.getFunctionDefs())
+    writeEmbeddingsToStream(F, OS, Level);
+}
+
+void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
+                                         EmbeddingLevel Level) const {
+  if (!Vocab || !Vocab->isValid()) {
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+    return;
+  }
+  if (F.isDeclaration()) {
+    OS << "Function " << F.getName() << " is a declaration, skipping.\n";
+    return;
+  }
+
+  // Create embedder for this function
+  auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+  if (!Emb) {
+    WithColor::error(errs(), ToolName)
+        << "Failed to create embedder for function " << F.getName() << "\n";
+    return;
+  }
+
+  OS << "Function: " << F.getName() << "\n";
+
+  // Generate embeddings based on the specified level
+  switch (Level) {
+  case FunctionLevel:
+    Emb->getFunctionVector().print(OS);
+    break;
+  case BasicBlockLevel:
+    for (const BasicBlock &BB : F) {
+      OS << BB.getName() << ":";
+      Emb->getBBVector(BB).print(OS);
+    }
+    break;
+  case InstructionLevel:
+    for (const Instruction &I : instructions(F)) {
+      OS << I;
+      Emb->getInstVector(I).print(OS);
+    }
+    break;
+  }
+}
+
 /// Process the module and generate output based on selected subcommand
 Error processModule(Module &M, raw_ostream &OS) {
   IR2VecTool Tool(M);
@@ -180,6 +343,221 @@ Error processModule(Module &M, raw_ostream &OS) {
 
 namespace mir2vec {
 
+MIR2VecTool::MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
+
+bool MIR2VecTool::initializeVocabulary(const Module &M) {
+  MIR2VecVocabProvider Provider(MMI);
+  auto VocabOrErr = Provider.getVocabulary(M);
+  if (!VocabOrErr) {
+    WithColor::error(errs(), ToolName)
+        << "Failed to load MIR2Vec vocabulary - "
+        << toString(VocabOrErr.takeError()) << "\n";
+    return false;
+  }
+  Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
+  return true;
+}
+
+bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
+  for (const Function &F : M.getFunctionDefs()) {
+    MachineFunction *MF = MMI.getMachineFunction(F);
+    if (!MF)
+      continue;
+
+    const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
+    const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
+    const MachineRegisterInfo &MRI = MF->getRegInfo();
+
+    auto VocabOrErr =
+        MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
+    if (!VocabOrErr) {
+      WithColor::error(errs(), ToolName)
+          << "Failed to create dummy vocabulary - "
+          << toString(VocabOrErr.takeError()) << "\n";
+      return false;
+    }
+    Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
+    return true;
+  }
+
+  WithColor::error(errs(), ToolName)
+      << "No machine functions found to initialize vocabulary\n";
+  return false;
+}
+
+TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
+  TripletResult Result;
+  Result.MaxRelation = MIRNextRelation;
+
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName)
+        << "MIR Vocabulary must be initialized for triplet generation.\n";
+    return Result;
+  }
+
+  unsigned PrevOpcode = 0;
+  bool HasPrevOpcode = false;
+  for (const MachineBasicBlock &MBB : MF) {
+    for (const MachineInstr &MI : MBB) {
+      // Skip debug instructions
+      if (MI.isDebugInstr())
+        continue;
+
+      // Get opcode entity ID
+      unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
+
+      // Add "Next" relationship with previous instruction
+      if (HasPrevOpcode) {
+        Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
+        LLVM_DEBUG(dbgs()
+                   << Vocab->getStringKey(PrevOpcode) << '\t'
+                   << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
+      }
+
+      // Add "Arg" relationships for operands
+      unsigned ArgIndex = 0;
+      for (const MachineOperand &MO : MI.operands()) {
+        auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
+        unsigned RelationID = MIRArgRelation + ArgIndex;
+        Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
+        LLVM_DEBUG({
+          std::string OperandStr = Vocab->getStringKey(OperandID);
+          dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
+                 << '\t' << "Arg" << ArgIndex << '\n';
+        });
+
+        ++ArgIndex;
+      }
+
+      // Update MaxRelation if there were operands
+      if (ArgIndex > 0)
+        Result.MaxRelation =
+            std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
+
+      PrevOpcode = OpcodeID;
+      HasPrevOpcode = true;
+    }
+  }
+
+  return Result;
+}
+
+TripletResult MIR2VecTool::generateTriplets(const Module &M) const {
+  TripletResult Result;
+  Result.MaxRelation = MIRNextRelation;
+
+  for (const Function &F : M.getFunctionDefs()) {
+    MachineFunction *MF = MMI.getMachineFunction(F);
+    if (!MF) {
+      WithColor::warning(errs(), ToolName)
+          << "No MachineFunction for " << F.getName() << "\n";
+      continue;
+    }
+
+    TripletResult FuncResult = generateTriplets(*MF);
+    Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+    Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+                           FuncResult.Triplets.end());
+  }
+
+  return Result;
+}
+
+void MIR2VecTool::writeTripletsToStream(const Module &M,
+                                        raw_ostream &OS) const {
+  auto Result = generateTriplets(M);
+  OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+  for (const auto &T : Result.Triplets)
+    OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+}
+
+EntityList MIR2VecTool::collectEntityMappings() const {
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary must be initialized for entity mappings.\n";
+    return {};
+  }
+
+  const unsigned EntityCount = Vocab->getCanonicalSize();
+  EntityList Result;
+  for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
+    Result.push_back(Vocab->getStringKey(EntityID));
+
+  return Result;
+}
+
+void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
+  auto Entities = collectEntityMappings();
+  if (Entities.empty())
+    return;
+
+  OS << Entities.size() << "\n";
+  for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+    OS << Entities[EntityID] << '\t' << EntityID << '\n';
+}
+
+void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
+                                          EmbeddingLevel Level) const {
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+    return;
+  }
+
+  for (const Function &F : M.getFunctionDefs()) {
+    MachineFunction *MF = MMI.getMachineFunction(F);
+    if (!MF) {
+      WithColor::warning(errs(), ToolName)
+          << "No MachineFunction for " << F.getName() << "\n";
+      continue;
+    }
+
+    writeEmbeddingsToStream(*MF, OS, Level);
+  }
+}
+
+void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
+                                          EmbeddingLevel Level) const {
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+    return;
+  }
+
+  auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+  if (!Emb) {
+    WithColor::error(errs(), ToolName)
+        << "Failed to create embedder for " << MF.getName() << "\n";
+    return;
+  }
+
+  OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+
+  // Generate embeddings based on the specified level
+  switch (Level) {
+  case FunctionLevel:
+    OS << "Function vector: ";
+    Emb->getMFunctionVector().print(OS);
+    break;
+  case BasicBlockLevel:
+    OS << "Basic block vectors:\n";
+    for (const MachineBasicBlock &MBB : MF) {
+      OS << "MBB " << MBB.getName() << ": ";
+      Emb->getMBBVector(MBB).print(OS);
+    }
+    break;
+  case InstructionLevel:
+    OS << "Instruction vectors:\n";
+    for (const MachineBasicBlock &MBB : MF) {
+      for (const MachineInstr &MI : MBB) {
+        OS << MI << " -> ";
+        Emb->getMInstVector(MI).print(OS);
+      }
+    }
+    break;
+  }
+}
+
+const MIRVocabulary *MIR2VecTool::getVocabulary() const { return Vocab.get(); }
+
 /// Setup MIR context from input file
 Error setupMIRContext(const std::string &InputFile, MIRContext &Ctx) {
   SMDiagnostic Err;
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
index cb4572996b1f4..9bcce9741c917 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
@@ -7,8 +7,8 @@
 //===----------------------------------------------------------------------===//
 ///
 /// \file
-/// This file contains the IR2VecTool and MIR2VecTool class definitions and
-/// implementations for the llvm-ir2vec embedding generation tool.
+/// This file contains the IR2VecTool and MIR2VecTool class definitions for
+/// the llvm-ir2vec embedding generation tool.
 ///
 //===----------------------------------------------------------------------===//
 
@@ -90,180 +90,37 @@ class IR2VecTool {
   const Vocabulary *Vocab = nullptr;
 
 public:
-  explicit IR2VecTool(Module &M) : M(M) {}
+  explicit IR2VecTool(Module &M);
 
   /// Initialize the IR2Vec vocabulary analysis
-  bool initializeVocabulary() {
-    // Register and run the IR2Vec vocabulary analysis
-    // The vocabulary file path is specified via --ir2vec-vocab-path global
-    // option
-    MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
-    MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
-    // This will throw an error if vocab is not found or invalid
-    Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
-    return Vocab->isValid();
-  }
+  bool initializeVocabulary();
 
   /// Generate triplets for a single function
   /// Returns a TripletResult with:
   ///   - Triplets: vector of all (subject, object, relation) tuples
   ///   - MaxRelation: highest Arg relation ID used, or NextRelation if none
-  TripletResult generateTriplets(const Function &F) const {
-    if (F.isDeclaration())
-      return {};
-
-    TripletResult Result;
-    Result.MaxRelation = 0;
-
-    unsigned MaxRelation = NextRelation;
-    unsigned PrevOpcode = 0;
-    bool HasPrevOpcode = false;
-
-    for (const BasicBlock &BB : F) {
-      for (const auto &I : BB.instructionsWithoutDebug()) {
-        unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
-        unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
-
-        // Add "Next" relationship with previous instruction
-        if (HasPrevOpcode) {
-          Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
-          LLVM_DEBUG(dbgs()
-                     << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
-                     << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                     << "Next\n");
-        }
-
-        // Add "Type" relationship
-        Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
-        LLVM_DEBUG(
-            dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                   << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
-                   << '\t' << "Type\n");
-
-        // Add "Arg" relationships
-        unsigned ArgIndex = 0;
-        for (const Use &U : I.operands()) {
-          unsigned OperandID = Vocabulary::getIndex(*U.get());
-          unsigned RelationID = ArgRelation + ArgIndex;
-          Result.Triplets.push_back({Opcode, OperandID, RelationID});
-
-          LLVM_DEBUG({
-            StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
-                Vocabulary::getOperandKind(U.get()));
-            dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                   << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
-          });
-
-          ++ArgIndex;
-        }
-        // Only update MaxRelation if there were operands
-        if (ArgIndex > 0)
-          MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
-        PrevOpcode = Opcode;
-        HasPrevOpcode = true;
-      }
-    }
-
-    Result.MaxRelation = MaxRelation;
-    return Result;
-  }
+  TripletResult generateTriplets(const Function &F) const;
 
   /// Get triplets for the entire module
-  TripletResult generateTriplets() const {
-    TripletResult Result;
-    Result.MaxRelation = NextRelation;
-
-    for (const Function &F : M.getFunctionDefs()) {
-      TripletResult FuncResult = generateTriplets(F);
-      Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
-      Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
-                             FuncResult.Triplets.end());
-    }
-
-    return Result;
-  }
+  TripletResult generateTriplets() const;
 
   /// Collect triplets for the module and dump output to stream
   /// Output format: MAX_RELATION=N header followed by relationships
-  void writeTripletsToStream(raw_ostream &OS) const {
-    auto Result = generateTriplets();
-    OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
-    for (const auto &T : Result.Triplets)
-      OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
-  }
+  void writeTripletsToStream(raw_ostream &OS) const;
 
   /// Generate entity mappings for the entire vocabulary
   /// Returns EntityList containing all entity strings
-  static EntityList collectEntityMappings() {
-    auto EntityLen = Vocabulary::getCanonicalSize();
-    EntityList Result;
-    for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
-      Result.push_back(Vocabulary::getStringKey(EntityID).str());
-    return Result;
-  }
+  static EntityList collectEntityMappings();
 
   /// Dump entity ID to string mappings
-  static void writeEntitiesToStream(raw_ostream &OS) {
-    auto Entities = collectEntityMappings();
-    OS << Entities.size() << "\n";
-    for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
-      OS << Entities[EntityID] << '\t' << EntityID << '\n';
-  }
+  static void writeEntitiesToStream(raw_ostream &OS);
 
   /// Generate embeddings for the entire module
-  void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const {
-    if (!Vocab->isValid()) {
-      WithColor::error(errs(), ToolName)
-          << "Vocabulary is not valid. IR2VecTool not initialized.\n";
-      return;
-    }
-
-    for (const Function &F : M.getFunctionDefs())
-      writeEmbeddingsToStream(F, OS, Level);
-  }
+  void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;
 
   /// Generate embeddings for a single function
   void writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
-                               EmbeddingLevel Level) const {
-    if (!Vocab || !Vocab->isValid()) {
-      WithColor::error(errs(), ToolName)
-          << "Vocabulary is not valid. IR2VecTool not initialized.\n";
-      return;
-    }
-    if (F.isDeclaration()) {
-      OS << "Function " << F.getName() << " is a declaration, skipping.\n";
-      return;
-    }
-
-    // Create embedder for this function
-    auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
-    if (!Emb) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to create embedder for function " << F.getName() << "\n";
-      return;
-    }
-
-    OS << "Function: " << F.getName() << "\n";
-
-    // Generate embeddings based on the specified level
-    switch (Level) {
-    case FunctionLevel:
-      Emb->getFunctionVector().print(OS);
-      break;
-    case BasicBlockLevel:
-      for (const BasicBlock &BB : F) {
-        OS << BB.getName() << ":";
-        Emb->getBBVector(BB).print(OS);
-      }
-      break;
-    case InstructionLevel:
-      for (const Instruction &I : instructions(F)) {
-        OS << I;
-        Emb->getInstVector(I).print(OS);
-      }
-      break;
-    }
-  }
+                               EmbeddingLevel Level) const;
 };
 
 } // namespace ir2vec
@@ -283,21 +140,10 @@ class MIR2VecTool {
   std::unique_ptr<MIRVocabulary> Vocab;
 
 public:
-  explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
+  explicit MIR2VecTool(MachineModuleInfo &MMI);
 
   /// Initialize MIR2Vec vocabulary from file (for embeddings generation)
-  bool initializeVocabulary(const Module &M) {
-    MIR2VecVocabProvider Provider(MMI);
-    auto VocabOrErr = Provider.getVocabulary(M);
-    if (!VocabOrErr) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to load MIR2Vec vocabulary - "
-          << toString(VocabOrErr.takeError()) << "\n";
-      return false;
-    }
-    Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
-    return true;
-  }
+  bool initializeVocabulary(const Module &M);
 
   /// Initialize vocabulary with layout information only.
   /// This creates a minimal vocabulary with correct layout but no actual
@@ -308,215 +154,36 @@ class MIR2VecTool {
   ///
   /// FIXME: Use --target option to get target info directly, avoiding the need
   /// to parse machine functions for pre-training operations.
-  bool initializeVocabularyForLayout(const Module &M) {
-    for (const Function &F : M.getFunctionDefs()) {
-      MachineFunction *MF = MMI.getMachineFunction(F);
-      if (!MF)
-        continue;
-
-      const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
-      const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
-      const MachineRegisterInfo &MRI = MF->getRegInfo();
-
-      auto VocabOrErr =
-          MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
-      if (!VocabOrErr) {
-        WithColor::error(errs(), ToolName)
-            << "Failed to create dummy vocabulary - "
-            << toString(VocabOrErr.takeError()) << "\n";
-        return false;
-      }
-      Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
-      return true;
-    }
-
-    WithColor::error(errs(), ToolName)
-        << "No machine functions found to initialize vocabulary\n";
-    return false;
-  }
+  bool initializeVocabularyForLayout(const Module &M);
 
   /// Get triplets for a single machine function
   /// Returns TripletResult containing MaxRelation and vector of Triplets
-  TripletResult generateTriplets(const MachineFunction &MF) const {
-    TripletResult Result;
-    Result.MaxRelation = MIRNextRelation;
-
-    if (!Vocab) {
-      WithColor::error(errs(), ToolName)
-          << "MIR Vocabulary must be initialized for triplet generation.\n";
-      return Result;
-    }
-
-    unsigned PrevOpcode = 0;
-    bool HasPrevOpcode = false;
-    for (const MachineBasicBlock &MBB : MF) {
-      for (const MachineInstr &MI : MBB) {
-        // Skip debug instructions
-        if (MI.isDebugInstr())
-          continue;
-
-        // Get opcode entity ID
-        unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
-
-        // Add "Next" relationship with previous instruction
-        if (HasPrevOpcode) {
-          Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
-          LLVM_DEBUG(dbgs()
-                     << Vocab->getStringKey(PrevOpcode) << '\t'
-                     << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
-        }
-
-        // Add "Arg" relationships for operands
-        unsigned ArgIndex = 0;
-        for (const MachineOperand &MO : MI.operands()) {
-          auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
-          unsigned RelationID = MIRArgRelation + ArgIndex;
-          Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
-          LLVM_DEBUG({
-            std::string OperandStr = Vocab->getStringKey(OperandID);
-            dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
-                   << '\t' << "Arg" << ArgIndex << '\n';
-          });
-
-          ++ArgIndex;
-        }
-
-        // Update MaxRelation if there were operands
-        if (ArgIndex > 0)
-          Result.MaxRelation =
-              std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
-
-        PrevOpcode = OpcodeID;
-        HasPrevOpcode = true;
-      }
-    }
-
-    return Result;
-  }
+  TripletResult generateTriplets(const MachineFunction &MF) const;
 
   /// Get triplets for the entire module
   /// Returns TripletResult containing aggregated MaxRelation and all Triplets
-  TripletResult generateTriplets(const Module &M) const {
-    TripletResult Result;
-    Result.MaxRelation = MIRNextRelation;
-
-    for (const Function &F : M.getFunctionDefs()) {
-      MachineFunction *MF = MMI.getMachineFunction(F);
-      if (!MF) {
-        WithColor::warning(errs(), ToolName)
-            << "No MachineFunction for " << F.getName() << "\n";
-        continue;
-      }
-
-      TripletResult FuncResult = generateTriplets(*MF);
-      Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
-      Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
-                             FuncResult.Triplets.end());
-    }
-
-    return Result;
-  }
+  TripletResult generateTriplets(const Module &M) const;
 
   /// Collect triplets for the module and write to output stream
   /// Output format: MAX_RELATION=N header followed by relationships
-  void writeTripletsToStream(const Module &M, raw_ostream &OS) const {
-    auto Result = generateTriplets(M);
-    OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
-    for (const auto &T : Result.Triplets)
-      OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
-  }
+  void writeTripletsToStream(const Module &M, raw_ostream &OS) const;
 
   /// Generate entity mappings for the entire vocabulary
-  EntityList collectEntityMappings() const {
-    if (!Vocab) {
-      WithColor::error(errs(), ToolName)
-          << "Vocabulary must be initialized for entity mappings.\n";
-      return {};
-    }
-
-    const unsigned EntityCount = Vocab->getCanonicalSize();
-    EntityList Result;
-    for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
-      Result.push_back(Vocab->getStringKey(EntityID));
-
-    return Result;
-  }
+  EntityList collectEntityMappings() const;
 
   /// Generate entity mappings and write to output stream
-  void writeEntitiesToStream(raw_ostream &OS) const {
-    auto Entities = collectEntityMappings();
-    if (Entities.empty())
-      return;
-
-    OS << Entities.size() << "\n";
-    for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
-      OS << Entities[EntityID] << '\t' << EntityID << '\n';
-  }
+  void writeEntitiesToStream(raw_ostream &OS) const;
 
   /// Generate embeddings for all machine functions in the module
   void writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
-                               EmbeddingLevel Level) const {
-    if (!Vocab) {
-      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
-      return;
-    }
-
-    for (const Function &F : M.getFunctionDefs()) {
-      MachineFunction *MF = MMI.getMachineFunction(F);
-      if (!MF) {
-        WithColor::warning(errs(), ToolName)
-            << "No MachineFunction for " << F.getName() << "\n";
-        continue;
-      }
-
-      writeEmbeddingsToStream(*MF, OS, Level);
-    }
-  }
+                               EmbeddingLevel Level) const;
 
   /// Generate embeddings for a specific machine function
   void writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
-                               EmbeddingLevel Level) const {
-    if (!Vocab) {
-      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
-      return;
-    }
-
-    auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
-    if (!Emb) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to create embedder for " << MF.getName() << "\n";
-      return;
-    }
-
-    OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
-
-    // Generate embeddings based on the specified level
-    switch (Level) {
-    case FunctionLevel:
-      OS << "Function vector: ";
-      Emb->getMFunctionVector().print(OS);
-      break;
-    case BasicBlockLevel:
-      OS << "Basic block vectors:\n";
-      for (const MachineBasicBlock &MBB : MF) {
-        OS << "MBB " << MBB.getName() << ": ";
-        Emb->getMBBVector(MBB).print(OS);
-      }
-      break;
-    case InstructionLevel:
-      OS << "Instruction vectors:\n";
-      for (const MachineBasicBlock &MBB : MF) {
-        for (const MachineInstr &MI : MBB) {
-          OS << MI << " -> ";
-          Emb->getMInstVector(MI).print(OS);
-        }
-      }
-      break;
-    }
-  }
+                               EmbeddingLevel Level) const;
 
   /// Get the MIR vocabulary instance
-  const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
+  const MIRVocabulary *getVocabulary() const;
 };
 
 /// Helper structure to hold MIR context

>From 68109b999603d2165427e7ace22830a65e0596b8 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Thu, 18 Dec 2025 16:07:39 +0530
Subject: [PATCH 4/5] nit commit - formatting fixup

---
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index a515797c0c505..1dfcbaedbd16a 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -179,10 +179,10 @@ TripletResult IR2VecTool::generateTriplets(const Function &F) const {
       // Add "Next" relationship with previous instruction
       if (HasPrevOpcode) {
         Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
-        LLVM_DEBUG(dbgs()
-                   << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
-                   << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                   << "Next\n");
+        LLVM_DEBUG(dbgs() << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1)
+                          << '\t'
+                          << Vocabulary::getVocabKeyForOpcode(Opcode + 1)
+                          << '\t' << "Next\n");
       }
 
       // Add "Type" relationship
@@ -368,8 +368,7 @@ bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
     const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
     const MachineRegisterInfo &MRI = MF->getRegInfo();
 
-    auto VocabOrErr =
-        MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
+    auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
     if (!VocabOrErr) {
       WithColor::error(errs(), ToolName)
           << "Failed to create dummy vocabulary - "
@@ -409,9 +408,8 @@ TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
       // Add "Next" relationship with previous instruction
       if (HasPrevOpcode) {
         Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
-        LLVM_DEBUG(dbgs()
-                   << Vocab->getStringKey(PrevOpcode) << '\t'
-                   << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
+        LLVM_DEBUG(dbgs() << Vocab->getStringKey(PrevOpcode) << '\t'
+                          << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
       }
 
       // Add "Arg" relationships for operands
@@ -422,8 +420,8 @@ TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
         Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
         LLVM_DEBUG({
           std::string OperandStr = Vocab->getStringKey(OperandID);
-          dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
-                 << '\t' << "Arg" << ArgIndex << '\n';
+          dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr << '\t'
+                 << "Arg" << ArgIndex << '\n';
         });
 
         ++ArgIndex;

>From c82d055d7d8097181eab853495ecbe7e7663d9ef Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Thu, 18 Dec 2025 19:20:22 +0530
Subject: [PATCH 5/5] Work Commit - moving constructors and getVocabulary
 function to header file

---
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 6 ------
 llvm/tools/llvm-ir2vec/llvm-ir2vec.h   | 6 +++---
 2 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 1dfcbaedbd16a..6b70e09518fa7 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -147,8 +147,6 @@ static cl::opt<EmbeddingLevel>
 
 namespace ir2vec {
 
-IR2VecTool::IR2VecTool(Module &M) : M(M) {}
-
 bool IR2VecTool::initializeVocabulary() {
   // Register and run the IR2Vec vocabulary analysis
   // The vocabulary file path is specified via --ir2vec-vocab-path global
@@ -343,8 +341,6 @@ Error processModule(Module &M, raw_ostream &OS) {
 
 namespace mir2vec {
 
-MIR2VecTool::MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
-
 bool MIR2VecTool::initializeVocabulary(const Module &M) {
   MIR2VecVocabProvider Provider(MMI);
   auto VocabOrErr = Provider.getVocabulary(M);
@@ -554,8 +550,6 @@ void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
   }
 }
 
-const MIRVocabulary *MIR2VecTool::getVocabulary() const { return Vocab.get(); }
-
 /// Setup MIR context from input file
 Error setupMIRContext(const std::string &InputFile, MIRContext &Ctx) {
   SMDiagnostic Err;
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
index 9bcce9741c917..566c362edbd22 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h
@@ -90,7 +90,7 @@ class IR2VecTool {
   const Vocabulary *Vocab = nullptr;
 
 public:
-  explicit IR2VecTool(Module &M);
+  explicit IR2VecTool(Module &M) : M(M) {}
 
   /// Initialize the IR2Vec vocabulary analysis
   bool initializeVocabulary();
@@ -140,7 +140,7 @@ class MIR2VecTool {
   std::unique_ptr<MIRVocabulary> Vocab;
 
 public:
-  explicit MIR2VecTool(MachineModuleInfo &MMI);
+  explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
 
   /// Initialize MIR2Vec vocabulary from file (for embeddings generation)
   bool initializeVocabulary(const Module &M);
@@ -183,7 +183,7 @@ class MIR2VecTool {
                                EmbeddingLevel Level) const;
 
   /// Get the MIR vocabulary instance
-  const MIRVocabulary *getVocabulary() const;
+  const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
 };
 
 /// Helper structure to hold MIR context



More information about the llvm-commits mailing list