[llvm] Refactoring llvm-ir2vec.cpp for better separation of concerns in the Tooling classes (PR #170078)

Nishant Sachdeva via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 10 01:23:31 PST 2025


https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/170078

>From e918ecdd0d5591802c68e1b78d5a5c34d5fa3d4f Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Fri, 28 Nov 2025 19:27:28 +0530
Subject: [PATCH 1/5] Refactoring llvm-ir2vec.cpp for better separation of
 concerns in the Tooling classes

---
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 431 +++++++++++++++++--------
 1 file changed, 290 insertions(+), 141 deletions(-)

diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 7b8d3f093a3d1..a9af7ab005982 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -64,6 +64,7 @@
 #include "llvm/IR/PassInstrumentation.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Type.h"
+#include "llvm/IR/InstIterator.h"
 #include "llvm/IRReader/IRReader.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -151,6 +152,29 @@ static cl::opt<EmbeddingLevel>
           cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
           cl::cat(CommonCategory));
 
+/// Entity mappings: [entity_name]
+using EntityList = std::vector<std::string>;
+
+/// Basic block embeddings: [{bb_name, Embedding}]
+using BBVecList = std::vector<std::pair<std::string, ir2vec::Embedding>>;
+
+/// Instruction embeddings: [{instruction_string, Embedding}]
+using InstVecList = std::vector<std::pair<std::string, ir2vec::Embedding>>;
+
+/// Function embeddings: [Embedding]
+using FuncVecList = std::vector<ir2vec::Embedding>;
+
+struct Triplet {
+  unsigned Head;
+  unsigned Tail;
+  unsigned Relation;
+};
+
+struct TripletResult {
+  unsigned MaxRelation;
+  std::vector<Triplet> Triplets;
+};
+
 namespace ir2vec {
 
 /// Relation types for triplet generation
@@ -182,32 +206,14 @@ class IR2VecTool {
     return Vocab->isValid();
   }
 
-  /// Generate triplets for the module
-  /// Output format: MAX_RELATION=N header followed by relationships
-  void generateTriplets(raw_ostream &OS) const {
-    unsigned MaxRelation = NextRelation; // Track maximum relation ID
-    std::string Relationships;
-    raw_string_ostream RelOS(Relationships);
+  TripletResult getTriplets(const Function &F) const {
+    TripletResult Result;
+    Result.MaxRelation = 0;
 
-    for (const Function &F : M) {
-      unsigned FuncMaxRelation = generateTriplets(F, RelOS);
-      MaxRelation = std::max(MaxRelation, FuncMaxRelation);
-    }
-
-    RelOS.flush();
-
-    // Write metadata header followed by relationships
-    OS << "MAX_RELATION=" << MaxRelation << '\n';
-    OS << Relationships;
-  }
-
-  /// Generate triplets for a single function
-  /// Returns the maximum relation ID used in this function
-  unsigned generateTriplets(const Function &F, raw_ostream &OS) const {
     if (F.isDeclaration())
-      return 0;
+      return Result;
 
-    unsigned MaxRelation = 1;
+    unsigned MaxRelation = NextRelation;
     unsigned PrevOpcode = 0;
     bool HasPrevOpcode = false;
 
@@ -216,56 +222,139 @@ class IR2VecTool {
         unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
         unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
 
-        // Add "Next" relationship with previous instruction
         if (HasPrevOpcode) {
-          OS << PrevOpcode << '\t' << Opcode << '\t' << NextRelation << '\n';
+          Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
           LLVM_DEBUG(dbgs()
-                     << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
-                     << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                     << "Next\n");
+                    << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
+                    << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                    << "Next\n");
         }
 
-        // Add "Type" relationship
-        OS << Opcode << '\t' << TypeID << '\t' << TypeRelation << '\n';
+        Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
         LLVM_DEBUG(
             dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                   << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
-                   << '\t' << "Type\n");
+                  << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
+                  << '\t' << "Type\n");
 
-        // Add "Arg" relationships
         unsigned ArgIndex = 0;
         for (const Use &U : I.operands()) {
           unsigned OperandID = Vocabulary::getIndex(*U.get());
           unsigned RelationID = ArgRelation + ArgIndex;
-          OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
+          Result.Triplets.push_back({Opcode, OperandID, RelationID});
 
           LLVM_DEBUG({
             StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
                 Vocabulary::getOperandKind(U.get()));
             dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                   << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
+                  << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
           });
 
           ++ArgIndex;
         }
-        // Only update MaxRelation if there were operands
+
         if (ArgIndex > 0) {
           MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
         }
+
         PrevOpcode = Opcode;
         HasPrevOpcode = true;
       }
     }
 
-    return MaxRelation;
+    Result.MaxRelation = MaxRelation;
+    return Result;
   }
 
-  /// Dump entity ID to string mappings
-  static void generateEntityMappings(raw_ostream &OS) {
+  TripletResult getTriplets() const {
+    TripletResult Result;
+    Result.MaxRelation = NextRelation;
+
+    for (const Function &F : M) {
+      TripletResult FuncResult = getTriplets(F);
+      Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+      Result.Triplets.insert(Result.Triplets.end(),
+                            FuncResult.Triplets.begin(),
+                            FuncResult.Triplets.end());
+    }
+
+    return Result;
+  }
+
+  /// Generate triplets for the module
+  /// Output format: MAX_RELATION=N header followed by relationships
+  void generateTriplets(raw_ostream &OS) const {
+    auto Result = getTriplets();
+    OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+    for (const auto &T : Result.Triplets) {
+      OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+    }
+  }
+
+  static EntityList collectEntityMappings() {
     auto EntityLen = Vocabulary::getCanonicalSize();
-    OS << EntityLen << "\n";
+    EntityList Result;
+    Result.reserve(EntityLen);
+
     for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
-      OS << Vocabulary::getStringKey(EntityID) << '\t' << EntityID << '\n';
+      Result.push_back(Vocabulary::getStringKey(EntityID).str());
+
+    return Result;
+  }
+
+  /// Dump entity ID to string mappings
+  static void generateEntityMappings(raw_ostream &OS) {
+    auto Entities = collectEntityMappings();
+    OS << Entities.size() << "\n";
+    for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+      OS << Entities[EntityID] << '\t' << EntityID << '\n';
+  }
+
+  ir2vec::Embedding getFunctionEmbedding(const Function &F) const {
+    assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+    if (F.isDeclaration())
+      return {};
+
+    auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+    return Emb ? Emb->getFunctionVector() : Embedding{};
+  }
+
+  BBVecList getBBEmbeddings(const Function &F) const {
+    assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+    BBVecList Result;
+
+    if (F.isDeclaration())
+      return Result;
+
+    auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+    if (!Emb)
+      return Result;
+
+    for (const BasicBlock &BB : F)
+      Result.push_back({BB.getName().str(), Emb->getBBVector(BB)});
+
+    return Result;
+  }
+
+  InstVecList getInstEmbeddings(const Function &F) const {
+    assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+    InstVecList Result;
+
+    if (F.isDeclaration())
+      return Result;
+
+    auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+    if (!Emb)
+      return Result;
+
+    for (const Instruction &I : instructions(F)) {
+      std::string InstStr;
+      raw_string_ostream(InstStr) << I;
+      Result.push_back({InstStr, Emb->getInstVector(I)});
+    }
+    return Result;
   }
 
   /// Generate embeddings for the entire module
@@ -282,44 +371,31 @@ class IR2VecTool {
 
   /// Generate embeddings for a single function
   void generateEmbeddings(const Function &F, raw_ostream &OS) const {
+    assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
     if (F.isDeclaration()) {
       OS << "Function " << F.getName() << " is a declaration, skipping.\n";
       return;
     }
 
-    // Create embedder for this function
-    assert(Vocab->isValid() && "Vocabulary is not valid");
-    auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
-    if (!Emb) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to create embedder for function " << F.getName() << "\n";
-      return;
-    }
-
     OS << "Function: " << F.getName() << "\n";
 
-    // Generate embeddings based on the specified level
-    switch (Level) {
-    case FunctionLevel: {
-      Emb->getFunctionVector().print(OS);
-      break;
-    }
-    case BasicBlockLevel: {
-      for (const BasicBlock &BB : F) {
-        OS << BB.getName() << ":";
-        Emb->getBBVector(BB).print(OS);
-      }
-      break;
-    }
-    case InstructionLevel: {
-      for (const BasicBlock &BB : F) {
-        for (const Instruction &I : BB) {
-          I.print(OS);
-          Emb->getInstVector(I).print(OS);
-        }
+    auto printListLevel = [&](const auto& list, const char* suffix_str) {
+      for (const auto& [name, embedding] : list) {
+        OS << name << suffix_str;
+        embedding.print(OS);
       }
-      break;
-    }
+    };
+
+    switch (Level) {
+      case EmbeddingLevel::FunctionLevel:
+        getFunctionEmbedding(F).print(OS);
+        break;
+      case EmbeddingLevel::BasicBlockLevel:
+        printListLevel(getBBEmbeddings(F), ":");
+        break;
+      case EmbeddingLevel::InstructionLevel:
+        printListLevel(getInstEmbeddings(F), "");
+        break;
     }
   }
 };
@@ -421,47 +497,22 @@ class MIR2VecTool {
         << "No machine functions found to initialize vocabulary\n";
     return false;
   }
-
-  /// Generate triplets for the module
-  /// Output format: MAX_RELATION=N header followed by relationships
-  void generateTriplets(const Module &M, raw_ostream &OS) const {
-    unsigned MaxRelation = MIRNextRelation; // Track maximum relation ID
-    std::string Relationships;
-    raw_string_ostream RelOS(Relationships);
-
-    for (const Function &F : M.getFunctionDefs()) {
-
-      MachineFunction *MF = MMI.getMachineFunction(F);
-      if (!MF) {
-        WithColor::warning(errs(), ToolName)
-            << "No MachineFunction for " << F.getName() << "\n";
-        continue;
-      }
-
-      unsigned FuncMaxRelation = generateTriplets(*MF, RelOS);
-      MaxRelation = std::max(MaxRelation, FuncMaxRelation);
-    }
-
-    RelOS.flush();
-
-    // Write metadata header followed by relationships
-    OS << "MAX_RELATION=" << MaxRelation << '\n';
-    OS << Relationships;
-  }
-
-  /// Generate triplets for a single machine function
-  /// Returns the maximum relation ID used in this function
-  unsigned generateTriplets(const MachineFunction &MF, raw_ostream &OS) const {
-    unsigned MaxRelation = MIRNextRelation;
-    unsigned PrevOpcode = 0;
-    bool HasPrevOpcode = false;
+  
+  /// Get triplets for a single machine function
+  /// Returns TripletResult containing MaxRelation and vector of Triplets
+  TripletResult getTriplets(const MachineFunction &MF) const {
+    TripletResult Result;
+    Result.MaxRelation = MIRNextRelation;
 
     if (!Vocab) {
       WithColor::error(errs(), ToolName)
           << "MIR Vocabulary must be initialized for triplet generation.\n";
-      return MaxRelation;
+      return Result;
     }
 
+    unsigned PrevOpcode = 0;
+    bool HasPrevOpcode = false;
+
     for (const MachineBasicBlock &MBB : MF) {
       for (const MachineInstr &MI : MBB) {
         // Skip debug instructions
@@ -473,8 +524,7 @@ class MIR2VecTool {
 
         // Add "Next" relationship with previous instruction
         if (HasPrevOpcode) {
-          OS << PrevOpcode << '\t' << OpcodeID << '\t' << MIRNextRelation
-             << '\n';
+          Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
           LLVM_DEBUG(dbgs()
                      << Vocab->getStringKey(PrevOpcode) << '\t'
                      << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
@@ -485,7 +535,7 @@ class MIR2VecTool {
         for (const MachineOperand &MO : MI.operands()) {
           auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
           unsigned RelationID = MIRArgRelation + ArgIndex;
-          OS << OpcodeID << '\t' << OperandID << '\t' << RelationID << '\n';
+          Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
           LLVM_DEBUG({
             std::string OperandStr = Vocab->getStringKey(OperandID);
             dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
@@ -497,28 +547,82 @@ class MIR2VecTool {
 
         // Update MaxRelation if there were operands
         if (ArgIndex > 0)
-          MaxRelation = std::max(MaxRelation, MIRArgRelation + ArgIndex - 1);
+          Result.MaxRelation = std::max(Result.MaxRelation, 
+                                       MIRArgRelation + ArgIndex - 1);
 
         PrevOpcode = OpcodeID;
         HasPrevOpcode = true;
       }
     }
 
-    return MaxRelation;
+    return Result;
   }
+  
+  /// Get triplets for the entire module
+  /// Returns TripletResult containing aggregated MaxRelation and all Triplets
+  TripletResult getTriplets(const Module &M) const {
+    TripletResult Result;
+    Result.MaxRelation = MIRNextRelation;
 
-  /// Generate entity mappings with vocabulary
-  void generateEntityMappings(raw_ostream &OS) const {
+    for (const Function &F : M) {
+      if (F.isDeclaration())
+        continue;
+
+      MachineFunction *MF = MMI.getMachineFunction(F);
+      if (!MF) {
+        WithColor::warning(errs(), ToolName)
+            << "No MachineFunction for " << F.getName() << "\n";
+        continue;
+      }
+
+      TripletResult FuncResult = getTriplets(*MF);
+      Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+      Result.Triplets.insert(Result.Triplets.end(),
+                            FuncResult.Triplets.begin(),
+                            FuncResult.Triplets.end());
+    }
+
+    return Result;
+  }
+
+  /// Generate triplets for the module and write to output stream
+  /// Output format: MAX_RELATION=N header followed by relationships
+  void generateTriplets(const Module &M, raw_ostream &OS) const {
+    auto Result = getTriplets(M);
+    OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+    for (const auto &T : Result.Triplets) {
+      OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+    }
+  }
+
+  /// Collect entity mappings
+  /// Returns EntityList containing all entity strings
+  EntityList collectEntityMappings() const {
     if (!Vocab) {
       WithColor::error(errs(), ToolName)
           << "Vocabulary must be initialized for entity mappings.\n";
-      return;
+      return {};
     }
 
     const unsigned EntityCount = Vocab->getCanonicalSize();
-    OS << EntityCount << "\n";
+    EntityList Result;
+    Result.reserve(EntityCount);
+
     for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
-      OS << Vocab->getStringKey(EntityID) << '\t' << EntityID << '\n';
+      Result.push_back(Vocab->getStringKey(EntityID));
+
+    return Result;
+  }
+
+  /// Generate entity mappings and write to output stream
+  void generateEntityMappings(raw_ostream &OS) const {
+    auto Entities = collectEntityMappings();
+    if (Entities.empty())
+      return;
+
+    OS << Entities.size() << "\n";
+    for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+      OS << Entities[EntityID] << '\t' << EntityID << '\n';
   }
 
   /// Generate embeddings for all machine functions in the module
@@ -541,48 +645,93 @@ class MIR2VecTool {
     }
   }
 
-  /// Generate embeddings for a specific machine function
-  void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const {
+  /// Get machine function embedding
+  /// Returns Embedding for the entire machine function
+  ir2vec::Embedding getMFunctionEmbedding(MachineFunction &MF) const {
     if (!Vocab) {
       WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
-      return;
+      return {};
     }
 
     auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
-    if (!Emb) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to create embedder for " << MF.getName() << "\n";
+    return Emb ? Emb->getMFunctionVector() : ir2vec::Embedding{};
+  }
+
+  /// Get machine basic block embeddings
+  /// Returns BBVecList containing (name, embedding) pairs for all MBBs
+  BBVecList getMBBEmbeddings(MachineFunction &MF) const {
+    BBVecList Result;
+
+    if (!Vocab) {
+      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+      return Result;
+    }
+
+    auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+    if (!Emb)
+      return Result;
+
+    for (const MachineBasicBlock &MBB : MF)
+      Result.push_back({MBB.getName().str(), Emb->getMBBVector(MBB)});
+
+    return Result;
+  }
+
+  /// Get machine instruction embeddings
+  /// Returns InstVecList containing (instruction_string, embedding) pairs
+  InstVecList getMInstEmbeddings(MachineFunction &MF) const {
+    InstVecList Result;
+
+    if (!Vocab) {
+      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+      return Result;
+    }
+
+    auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+    if (!Emb)
+      return Result;
+
+    for (const MachineBasicBlock &MBB : MF) {
+      for (const MachineInstr &MI : MBB) {
+        std::string InstStr;
+        raw_string_ostream(InstStr) << MI;
+        Result.push_back({InstStr, Emb->getMInstVector(MI)});
+      }
+    }
+
+    return Result;
+  }
+
+  /// Generate embeddings for a specific machine function
+  void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const {
+    if (!Vocab) {
+      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
       return;
     }
 
     OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
 
+    auto printListLevel = [&](const auto& list, const char* label, const char* prefix_str, const char* suffix_str) {
+      OS << label << ":\n";
+      for (const auto& [name, embedding] : list) {
+        OS << prefix_str << name << suffix_str;
+        embedding.print(OS);
+      }
+    };
+
     // Generate embeddings based on the specified level
     switch (Level) {
-    case FunctionLevel: {
+    case FunctionLevel:
       OS << "Function vector: ";
-      Emb->getMFunctionVector().print(OS);
+      getMFunctionEmbedding(MF).print(OS);
       break;
-    }
-    case BasicBlockLevel: {
-      OS << "Basic block vectors:\n";
-      for (const MachineBasicBlock &MBB : MF) {
-        OS << "MBB " << MBB.getName() << ": ";
-        Emb->getMBBVector(MBB).print(OS);
-      }
+    case BasicBlockLevel:
+      printListLevel(getMBBEmbeddings(MF), "Basic block vectors", "MBB ", ": ");
       break;
-    }
-    case InstructionLevel: {
-      OS << "Instruction vectors:\n";
-      for (const MachineBasicBlock &MBB : MF) {
-        for (const MachineInstr &MI : MBB) {
-          OS << MI << " -> ";
-          Emb->getMInstVector(MI).print(OS);
-        }
-      }
+    case InstructionLevel:
+      printListLevel(getMInstEmbeddings(MF), "Instruction vectors", "", " -> ");
       break;
     }
-    }
   }
 
   const MIRVocabulary *getVocabulary() const { return Vocab.get(); }

>From 51f359ea0158fb4959e55091fd940780e7ac6c66 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Tue, 2 Dec 2025 12:42:28 +0530
Subject: [PATCH 2/5] Debug Commit - Clang Format fixup

---
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 59 +++++++++++++-------------
 1 file changed, 29 insertions(+), 30 deletions(-)

diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index a9af7ab005982..2511983c5d490 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -58,13 +58,13 @@
 #include "llvm/Analysis/IR2Vec.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassInstrumentation.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Type.h"
-#include "llvm/IR/InstIterator.h"
 #include "llvm/IRReader/IRReader.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -225,16 +225,16 @@ class IR2VecTool {
         if (HasPrevOpcode) {
           Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
           LLVM_DEBUG(dbgs()
-                    << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
-                    << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                    << "Next\n");
+                     << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
+                     << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                     << "Next\n");
         }
 
         Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
         LLVM_DEBUG(
             dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                  << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
-                  << '\t' << "Type\n");
+                   << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
+                   << '\t' << "Type\n");
 
         unsigned ArgIndex = 0;
         for (const Use &U : I.operands()) {
@@ -246,7 +246,7 @@ class IR2VecTool {
             StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
                 Vocabulary::getOperandKind(U.get()));
             dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                  << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
+                   << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
           });
 
           ++ArgIndex;
@@ -272,9 +272,8 @@ class IR2VecTool {
     for (const Function &F : M) {
       TripletResult FuncResult = getTriplets(F);
       Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
-      Result.Triplets.insert(Result.Triplets.end(),
-                            FuncResult.Triplets.begin(),
-                            FuncResult.Triplets.end());
+      Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+                             FuncResult.Triplets.end());
     }
 
     return Result;
@@ -379,23 +378,23 @@ class IR2VecTool {
 
     OS << "Function: " << F.getName() << "\n";
 
-    auto printListLevel = [&](const auto& list, const char* suffix_str) {
-      for (const auto& [name, embedding] : list) {
+    auto printListLevel = [&](const auto &list, const char *suffix_str) {
+      for (const auto &[name, embedding] : list) {
         OS << name << suffix_str;
         embedding.print(OS);
       }
     };
 
     switch (Level) {
-      case EmbeddingLevel::FunctionLevel:
-        getFunctionEmbedding(F).print(OS);
-        break;
-      case EmbeddingLevel::BasicBlockLevel:
-        printListLevel(getBBEmbeddings(F), ":");
-        break;
-      case EmbeddingLevel::InstructionLevel:
-        printListLevel(getInstEmbeddings(F), "");
-        break;
+    case EmbeddingLevel::FunctionLevel:
+      getFunctionEmbedding(F).print(OS);
+      break;
+    case EmbeddingLevel::BasicBlockLevel:
+      printListLevel(getBBEmbeddings(F), ":");
+      break;
+    case EmbeddingLevel::InstructionLevel:
+      printListLevel(getInstEmbeddings(F), "");
+      break;
     }
   }
 };
@@ -497,7 +496,7 @@ class MIR2VecTool {
         << "No machine functions found to initialize vocabulary\n";
     return false;
   }
-  
+
   /// Get triplets for a single machine function
   /// Returns TripletResult containing MaxRelation and vector of Triplets
   TripletResult getTriplets(const MachineFunction &MF) const {
@@ -547,8 +546,8 @@ class MIR2VecTool {
 
         // Update MaxRelation if there were operands
         if (ArgIndex > 0)
-          Result.MaxRelation = std::max(Result.MaxRelation, 
-                                       MIRArgRelation + ArgIndex - 1);
+          Result.MaxRelation =
+              std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
 
         PrevOpcode = OpcodeID;
         HasPrevOpcode = true;
@@ -557,7 +556,7 @@ class MIR2VecTool {
 
     return Result;
   }
-  
+
   /// Get triplets for the entire module
   /// Returns TripletResult containing aggregated MaxRelation and all Triplets
   TripletResult getTriplets(const Module &M) const {
@@ -577,9 +576,8 @@ class MIR2VecTool {
 
       TripletResult FuncResult = getTriplets(*MF);
       Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
-      Result.Triplets.insert(Result.Triplets.end(),
-                            FuncResult.Triplets.begin(),
-                            FuncResult.Triplets.end());
+      Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+                             FuncResult.Triplets.end());
     }
 
     return Result;
@@ -711,9 +709,10 @@ class MIR2VecTool {
 
     OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
 
-    auto printListLevel = [&](const auto& list, const char* label, const char* prefix_str, const char* suffix_str) {
+    auto printListLevel = [&](const auto &list, const char *label,
+                              const char *prefix_str, const char *suffix_str) {
       OS << label << ":\n";
-      for (const auto& [name, embedding] : list) {
+      for (const auto &[name, embedding] : list) {
         OS << prefix_str << name << suffix_str;
         embedding.print(OS);
       }

>From 2786c97c17deefcbef3ba92f3061874e166211c4 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Tue, 2 Dec 2025 16:44:25 +0530
Subject: [PATCH 3/5] Debug Commit - Resolving comments. To be squashed later

---
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 186 +++++++++++++------------
 1 file changed, 98 insertions(+), 88 deletions(-)

diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 2511983c5d490..d3afbbc6f87e6 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -152,18 +152,6 @@ static cl::opt<EmbeddingLevel>
           cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
           cl::cat(CommonCategory));
 
-/// Entity mappings: [entity_name]
-using EntityList = std::vector<std::string>;
-
-/// Basic block embeddings: [{bb_name, Embedding}]
-using BBVecList = std::vector<std::pair<std::string, ir2vec::Embedding>>;
-
-/// Instruction embeddings: [{instruction_string, Embedding}]
-using InstVecList = std::vector<std::pair<std::string, ir2vec::Embedding>>;
-
-/// Function embeddings: [Embedding]
-using FuncVecList = std::vector<ir2vec::Embedding>;
-
 struct Triplet {
   unsigned Head;
   unsigned Tail;
@@ -176,6 +164,18 @@ struct TripletResult {
 };
 
 namespace ir2vec {
+/// Entity mappings: [entity_name]
+using EntityList = std::vector<std::string>;
+
+/// Basic block embeddings: [{bb_ptr, Embedding}]
+using BBVecList = std::vector<std::pair<const BasicBlock *, ir2vec::Embedding>>;
+
+/// Instruction embeddings: [{instruction_ptr, Embedding}]
+using InstVecList =
+    std::vector<std::pair<const Instruction *, ir2vec::Embedding>>;
+
+/// Function embeddings: [Embedding]
+using FuncVecList = std::vector<ir2vec::Embedding>;
 
 /// Relation types for triplet generation
 enum RelationType {
@@ -206,13 +206,13 @@ class IR2VecTool {
     return Vocab->isValid();
   }
 
-  TripletResult getTriplets(const Function &F) const {
+  TripletResult generateTriplets(const Function &F) const {
+    if (F.isDeclaration())
+      return {};
+
     TripletResult Result;
     Result.MaxRelation = 0;
 
-    if (F.isDeclaration())
-      return Result;
-
     unsigned MaxRelation = NextRelation;
     unsigned PrevOpcode = 0;
     bool HasPrevOpcode = false;
@@ -222,6 +222,7 @@ class IR2VecTool {
         unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
         unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
 
+        // Add "Next" relationship with previous instruction
         if (HasPrevOpcode) {
           Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
           LLVM_DEBUG(dbgs()
@@ -230,12 +231,14 @@ class IR2VecTool {
                      << "Next\n");
         }
 
+        // Add "Type" relationship
         Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
         LLVM_DEBUG(
             dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
                    << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
                    << '\t' << "Type\n");
 
+        // Add "Arg" relationships
         unsigned ArgIndex = 0;
         for (const Use &U : I.operands()) {
           unsigned OperandID = Vocabulary::getIndex(*U.get());
@@ -251,10 +254,9 @@ class IR2VecTool {
 
           ++ArgIndex;
         }
-
-        if (ArgIndex > 0) {
+        // Only update MaxRelation if there were operands
+        if (ArgIndex > 0)
           MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
-        }
 
         PrevOpcode = Opcode;
         HasPrevOpcode = true;
@@ -265,12 +267,14 @@ class IR2VecTool {
     return Result;
   }
 
-  TripletResult getTriplets() const {
+  TripletResult generateTriplets() const {
     TripletResult Result;
     Result.MaxRelation = NextRelation;
 
     for (const Function &F : M) {
-      TripletResult FuncResult = getTriplets(F);
+      if (F.isDeclaration())
+        continue;
+      TripletResult FuncResult = generateTriplets(F);
       Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
       Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
                              FuncResult.Triplets.end());
@@ -282,18 +286,15 @@ class IR2VecTool {
   /// Generate triplets for the module
   /// Output format: MAX_RELATION=N header followed by relationships
   void generateTriplets(raw_ostream &OS) const {
-    auto Result = getTriplets();
+    auto Result = generateTriplets();
     OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
-    for (const auto &T : Result.Triplets) {
+    for (const auto &T : Result.Triplets)
       OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
-    }
   }
 
   static EntityList collectEntityMappings() {
     auto EntityLen = Vocabulary::getCanonicalSize();
     EntityList Result;
-    Result.reserve(EntityLen);
-
     for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
       Result.push_back(Vocabulary::getStringKey(EntityID).str());
 
@@ -315,23 +316,27 @@ class IR2VecTool {
       return {};
 
     auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
-    return Emb ? Emb->getFunctionVector() : Embedding{};
+    if (!Emb) {
+      WithColor::error(errs(), ToolName)
+          << "Failed to create embedder for " << F.getName() << "\n";
+      return {};
+    }
+    return Emb->getFunctionVector();
   }
 
   BBVecList getBBEmbeddings(const Function &F) const {
     assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
 
-    BBVecList Result;
-
     if (F.isDeclaration())
-      return Result;
+      return {};
 
     auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
     if (!Emb)
-      return Result;
+      return {};
 
+    BBVecList Result;
     for (const BasicBlock &BB : F)
-      Result.push_back({BB.getName().str(), Emb->getBBVector(BB)});
+      Result.push_back({&BB, Emb->getBBVector(BB)});
 
     return Result;
   }
@@ -339,20 +344,17 @@ class IR2VecTool {
   InstVecList getInstEmbeddings(const Function &F) const {
     assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
 
-    InstVecList Result;
-
     if (F.isDeclaration())
-      return Result;
+      return {};
 
     auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
     if (!Emb)
-      return Result;
+      return {};
+
+    InstVecList Result;
+    for (const Instruction &I : instructions(F))
+      Result.push_back({&I, Emb->getInstVector(I)});
 
-    for (const Instruction &I : instructions(F)) {
-      std::string InstStr;
-      raw_string_ostream(InstStr) << I;
-      Result.push_back({InstStr, Emb->getInstVector(I)});
-    }
     return Result;
   }
 
@@ -378,22 +380,21 @@ class IR2VecTool {
 
     OS << "Function: " << F.getName() << "\n";
 
-    auto printListLevel = [&](const auto &list, const char *suffix_str) {
-      for (const auto &[name, embedding] : list) {
-        OS << name << suffix_str;
-        embedding.print(OS);
-      }
-    };
-
     switch (Level) {
     case EmbeddingLevel::FunctionLevel:
       getFunctionEmbedding(F).print(OS);
       break;
     case EmbeddingLevel::BasicBlockLevel:
-      printListLevel(getBBEmbeddings(F), ":");
+      for (const auto &[BB, embedding] : getBBEmbeddings(F)) {
+        OS << BB->getName() << ":";
+        embedding.print(OS);
+      }
       break;
     case EmbeddingLevel::InstructionLevel:
-      printListLevel(getInstEmbeddings(F), "");
+      for (const auto &[I, embedding] : getInstEmbeddings(F)) {
+        OS << *I;
+        embedding.print(OS);
+      }
       break;
     }
   }
@@ -430,6 +431,19 @@ Error processModule(Module &M, raw_ostream &OS) {
 } // namespace ir2vec
 
 namespace mir2vec {
+/// Entity mappings: [entity_name]
+using EntityList = std::vector<std::string>;
+
+/// Machine basic block embeddings: [{mbb_ptr, Embedding}]
+using MBBVecList =
+    std::vector<std::pair<const MachineBasicBlock *, ir2vec::Embedding>>;
+
+/// Machine instruction embeddings: [{minstr_ptr, Embedding}]
+using MInstVecList =
+    std::vector<std::pair<const MachineInstr *, ir2vec::Embedding>>;
+
+/// Function embeddings: [Embedding]
+using FuncVecList = std::vector<ir2vec::Embedding>;
 
 /// Relation types for MIR2Vec triplet generation
 enum MIRRelationType {
@@ -499,7 +513,7 @@ class MIR2VecTool {
 
   /// Get triplets for a single machine function
   /// Returns TripletResult containing MaxRelation and vector of Triplets
-  TripletResult getTriplets(const MachineFunction &MF) const {
+  TripletResult generateTriplets(const MachineFunction &MF) const {
     TripletResult Result;
     Result.MaxRelation = MIRNextRelation;
 
@@ -559,7 +573,7 @@ class MIR2VecTool {
 
   /// Get triplets for the entire module
   /// Returns TripletResult containing aggregated MaxRelation and all Triplets
-  TripletResult getTriplets(const Module &M) const {
+  TripletResult generateTriplets(const Module &M) const {
     TripletResult Result;
     Result.MaxRelation = MIRNextRelation;
 
@@ -574,7 +588,7 @@ class MIR2VecTool {
         continue;
       }
 
-      TripletResult FuncResult = getTriplets(*MF);
+      TripletResult FuncResult = generateTriplets(*MF);
       Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
       Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
                              FuncResult.Triplets.end());
@@ -586,11 +600,10 @@ class MIR2VecTool {
   /// Generate triplets for the module and write to output stream
   /// Output format: MAX_RELATION=N header followed by relationships
   void generateTriplets(const Module &M, raw_ostream &OS) const {
-    auto Result = getTriplets(M);
+    auto Result = generateTriplets(M);
     OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
-    for (const auto &T : Result.Triplets) {
+    for (const auto &T : Result.Triplets)
       OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
-    }
   }
 
   /// Collect entity mappings
@@ -604,8 +617,6 @@ class MIR2VecTool {
 
     const unsigned EntityCount = Vocab->getCanonicalSize();
     EntityList Result;
-    Result.reserve(EntityCount);
-
     for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
       Result.push_back(Vocab->getStringKey(EntityID));
 
@@ -652,49 +663,49 @@ class MIR2VecTool {
     }
 
     auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
-    return Emb ? Emb->getMFunctionVector() : ir2vec::Embedding{};
+    if (!Emb) {
+      WithColor::error(errs(), ToolName)
+          << "Failed to create embedder for " << MF.getName() << "\n";
+      return {};
+    }
+    return Emb->getMFunctionVector();
   }
 
   /// Get machine basic block embeddings
-  /// Returns BBVecList containing (name, embedding) pairs for all MBBs
-  BBVecList getMBBEmbeddings(MachineFunction &MF) const {
-    BBVecList Result;
-
+  /// Returns MBBVecList containing (name, embedding) pairs for all MBBs
+  MBBVecList getMBBEmbeddings(MachineFunction &MF) const {
     if (!Vocab) {
       WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
-      return Result;
+      return {};
     }
 
     auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
     if (!Emb)
-      return Result;
+      return {};
 
+    MBBVecList Result;
     for (const MachineBasicBlock &MBB : MF)
-      Result.push_back({MBB.getName().str(), Emb->getMBBVector(MBB)});
+      Result.push_back({&MBB, Emb->getMBBVector(MBB)});
 
     return Result;
   }
 
   /// Get machine instruction embeddings
-  /// Returns InstVecList containing (instruction_string, embedding) pairs
-  InstVecList getMInstEmbeddings(MachineFunction &MF) const {
-    InstVecList Result;
-
+  /// Returns MInstVecList containing (instruction_string, embedding) pairs
+  MInstVecList getMInstEmbeddings(MachineFunction &MF) const {
     if (!Vocab) {
       WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
-      return Result;
+      return {};
     }
 
     auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
     if (!Emb)
-      return Result;
+      return {};
 
+    MInstVecList Result;
     for (const MachineBasicBlock &MBB : MF) {
-      for (const MachineInstr &MI : MBB) {
-        std::string InstStr;
-        raw_string_ostream(InstStr) << MI;
-        Result.push_back({InstStr, Emb->getMInstVector(MI)});
-      }
+      for (const MachineInstr &MI : MBB)
+        Result.push_back({&MI, Emb->getMInstVector(MI)});
     }
 
     return Result;
@@ -709,15 +720,6 @@ class MIR2VecTool {
 
     OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
 
-    auto printListLevel = [&](const auto &list, const char *label,
-                              const char *prefix_str, const char *suffix_str) {
-      OS << label << ":\n";
-      for (const auto &[name, embedding] : list) {
-        OS << prefix_str << name << suffix_str;
-        embedding.print(OS);
-      }
-    };
-
     // Generate embeddings based on the specified level
     switch (Level) {
     case FunctionLevel:
@@ -725,10 +727,18 @@ class MIR2VecTool {
       getMFunctionEmbedding(MF).print(OS);
       break;
     case BasicBlockLevel:
-      printListLevel(getMBBEmbeddings(MF), "Basic block vectors", "MBB ", ": ");
+      OS << "Basic block vectors:\n";
+      for (const auto &[MBB, embedding] : getMBBEmbeddings(MF)) {
+        OS << "MBB " << MBB->getName() << ": ";
+        embedding.print(OS);
+      }
       break;
     case InstructionLevel:
-      printListLevel(getMInstEmbeddings(MF), "Instruction vectors", "", " -> ");
+      OS << "Instruction vectors:\n";
+      for (const auto &[MI, embedding] : getMInstEmbeddings(MF)) {
+        OS << *MI << " -> ";
+        embedding.print(OS);
+      }
       break;
     }
   }

>From 555f68de511ab324fb4e2618e3626e49ce0b5f43 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Mon, 8 Dec 2025 16:37:58 +0530
Subject: [PATCH 4/5] Debug Commit - Resolving comments. To be squashed later

---
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 252 +++++++++++--------------
 1 file changed, 109 insertions(+), 143 deletions(-)

diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index d3afbbc6f87e6..5b62142652990 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -152,31 +152,25 @@ static cl::opt<EmbeddingLevel>
           cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
           cl::cat(CommonCategory));
 
+/// Represents a single knowledge graph triplet (Head, Relation, Tail)
+/// where indices reference entities in an EntityList
 struct Triplet {
-  unsigned Head;
-  unsigned Tail;
-  unsigned Relation;
+  unsigned Head = 0;     ///< Index of the head entity in the entity list
+  unsigned Tail = 0;     ///< Index of the tail entity in the entity list
+  unsigned Relation = 0; ///< Relation ID (RelationType or MIRRelationType)
 };
 
+/// Result structure containing all generated triplets and metadata
 struct TripletResult {
-  unsigned MaxRelation;
-  std::vector<Triplet> Triplets;
+  unsigned MaxRelation =
+      0; ///< Highest relation index used (for ArgRelation + N)
+  std::vector<Triplet> Triplets; ///< Collection of all generated triplets
 };
 
-namespace ir2vec {
 /// Entity mappings: [entity_name]
 using EntityList = std::vector<std::string>;
 
-/// Basic block embeddings: [{bb_ptr, Embedding}]
-using BBVecList = std::vector<std::pair<const BasicBlock *, ir2vec::Embedding>>;
-
-/// Instruction embeddings: [{instruction_ptr, Embedding}]
-using InstVecList =
-    std::vector<std::pair<const Instruction *, ir2vec::Embedding>>;
-
-/// Function embeddings: [Embedding]
-using FuncVecList = std::vector<ir2vec::Embedding>;
-
+namespace ir2vec {
 /// Relation types for triplet generation
 enum RelationType {
   TypeRelation = 0, ///< Instruction to type relationship
@@ -206,6 +200,8 @@ class IR2VecTool {
     return Vocab->isValid();
   }
 
+  /// Generate triplets for a single function
+  /// Returns the maximum relation ID used in this function
   TripletResult generateTriplets(const Function &F) const {
     if (F.isDeclaration())
       return {};
@@ -267,13 +263,12 @@ class IR2VecTool {
     return Result;
   }
 
+  /// Get triplets for the entire module
   TripletResult generateTriplets() const {
     TripletResult Result;
     Result.MaxRelation = NextRelation;
 
-    for (const Function &F : M) {
-      if (F.isDeclaration())
-        continue;
+    for (const Function &F : M.getFunctionDefs()) {
       TripletResult FuncResult = generateTriplets(F);
       Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
       Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
@@ -283,7 +278,7 @@ class IR2VecTool {
     return Result;
   }
 
-  /// Generate triplets for the module
+  /// Collect triplets for the module and dump output to stream
   /// Output format: MAX_RELATION=N header followed by relationships
   void generateTriplets(raw_ostream &OS) const {
     auto Result = generateTriplets();
@@ -292,6 +287,8 @@ class IR2VecTool {
       OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
   }
 
+  /// Generate entity mappings for the entire vocabulary
+  /// Returns EntityList containing all entity strings
   static EntityList collectEntityMappings() {
     auto EntityLen = Vocabulary::getCanonicalSize();
     EntityList Result;
@@ -309,52 +306,26 @@ class IR2VecTool {
       OS << Entities[EntityID] << '\t' << EntityID << '\n';
   }
 
-  ir2vec::Embedding getFunctionEmbedding(const Function &F) const {
-    assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
-
-    if (F.isDeclaration())
-      return {};
-
-    auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
-    if (!Emb) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to create embedder for " << F.getName() << "\n";
-      return {};
-    }
-    return Emb->getFunctionVector();
+  /// Get function-level embedding using provided embedder
+  ir2vec::Embedding getFunctionEmbedding(const Embedder &Emb) const {
+    return Emb.getFunctionVector();
   }
 
-  BBVecList getBBEmbeddings(const Function &F) const {
-    assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
-
-    if (F.isDeclaration())
-      return {};
-
-    auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
-    if (!Emb)
-      return {};
-
-    BBVecList Result;
+  /// Get basic block embeddings for the given function
+  BBEmbeddingsMap getBBEmbeddings(const Function &F,
+                                  const Embedder &Emb) const {
+    BBEmbeddingsMap Result;
     for (const BasicBlock &BB : F)
-      Result.push_back({&BB, Emb->getBBVector(BB)});
-
+      Result.try_emplace(&BB, Emb.getBBVector(BB));
     return Result;
   }
 
-  InstVecList getInstEmbeddings(const Function &F) const {
-    assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
-
-    if (F.isDeclaration())
-      return {};
-
-    auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
-    if (!Emb)
-      return {};
-
-    InstVecList Result;
+  /// Get instruction embeddings using provided embedder
+  InstEmbeddingsMap getInstEmbeddings(const Function &F,
+                                      const Embedder &Emb) const {
+    InstEmbeddingsMap Result;
     for (const Instruction &I : instructions(F))
-      Result.push_back({&I, Emb->getInstVector(I)});
-
+      Result.try_emplace(&I, Emb.getInstVector(I));
     return Result;
   }
 
@@ -366,40 +337,62 @@ class IR2VecTool {
       return;
     }
 
-    for (const Function &F : M)
+    for (const Function &F : M.getFunctionDefs())
       generateEmbeddings(F, OS);
   }
 
   /// Generate embeddings for a single function
   void generateEmbeddings(const Function &F, raw_ostream &OS) const {
-    assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+    if (!Vocab || !Vocab->isValid()) {
+      WithColor::error(errs(), ToolName)
+          << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+      return;
+    }
+
     if (F.isDeclaration()) {
       OS << "Function " << F.getName() << " is a declaration, skipping.\n";
       return;
     }
 
+    // Create embedder once for the function
+    auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+    if (!Emb) {
+      WithColor::error(errs(), ToolName)
+          << "Failed to create embedder for " << F.getName() << "\n";
+      return;
+    }
+
     OS << "Function: " << F.getName() << "\n";
 
     switch (Level) {
     case EmbeddingLevel::FunctionLevel:
-      getFunctionEmbedding(F).print(OS);
+      getFunctionEmbedding(*Emb).print(OS);
       break;
-    case EmbeddingLevel::BasicBlockLevel:
-      for (const auto &[BB, embedding] : getBBEmbeddings(F)) {
-        OS << BB->getName() << ":";
-        embedding.print(OS);
+    case EmbeddingLevel::BasicBlockLevel: {
+      const auto EmbMap = getBBEmbeddings(F, *Emb);
+      for (const BasicBlock &BB : F) {
+        if (auto It = EmbMap.find(&BB); It != EmbMap.end()) {
+          OS << BB.getName() << ":";
+          It->second.print(OS);
+        }
       }
       break;
-    case EmbeddingLevel::InstructionLevel:
-      for (const auto &[I, embedding] : getInstEmbeddings(F)) {
-        OS << *I;
-        embedding.print(OS);
+    }
+    case EmbeddingLevel::InstructionLevel: {
+      const auto EmbMap = getInstEmbeddings(F, *Emb);
+      for (const Instruction &I : instructions(F)) {
+        if (auto It = EmbMap.find(&I); It != EmbMap.end()) {
+          OS << I;
+          It->second.print(OS);
+        }
       }
       break;
     }
+    }
   }
 };
 
+/// Process the module and generate output based on selected subcommand
 Error processModule(Module &M, raw_ostream &OS) {
   IR2VecTool Tool(M);
 
@@ -431,20 +424,6 @@ Error processModule(Module &M, raw_ostream &OS) {
 } // namespace ir2vec
 
 namespace mir2vec {
-/// Entity mappings: [entity_name]
-using EntityList = std::vector<std::string>;
-
-/// Machine basic block embeddings: [{mbb_ptr, Embedding}]
-using MBBVecList =
-    std::vector<std::pair<const MachineBasicBlock *, ir2vec::Embedding>>;
-
-/// Machine instruction embeddings: [{minstr_ptr, Embedding}]
-using MInstVecList =
-    std::vector<std::pair<const MachineInstr *, ir2vec::Embedding>>;
-
-/// Function embeddings: [Embedding]
-using FuncVecList = std::vector<ir2vec::Embedding>;
-
 /// Relation types for MIR2Vec triplet generation
 enum MIRRelationType {
   MIRNextRelation = 0, ///< Sequential instruction relationship
@@ -577,10 +556,7 @@ class MIR2VecTool {
     TripletResult Result;
     Result.MaxRelation = MIRNextRelation;
 
-    for (const Function &F : M) {
-      if (F.isDeclaration())
-        continue;
-
+    for (const Function &F : M.getFunctionDefs()) {
       MachineFunction *MF = MMI.getMachineFunction(F);
       if (!MF) {
         WithColor::warning(errs(), ToolName)
@@ -606,8 +582,7 @@ class MIR2VecTool {
       OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
   }
 
-  /// Collect entity mappings
-  /// Returns EntityList containing all entity strings
+  /// Generate entity mappings for the entire vocabulary
   EntityList collectEntityMappings() const {
     if (!Vocab) {
       WithColor::error(errs(), ToolName)
@@ -642,7 +617,6 @@ class MIR2VecTool {
     }
 
     for (const Function &F : M.getFunctionDefs()) {
-
       MachineFunction *MF = MMI.getMachineFunction(F);
       if (!MF) {
         WithColor::warning(errs(), ToolName)
@@ -656,58 +630,31 @@ class MIR2VecTool {
 
   /// Get machine function embedding
   /// Returns Embedding for the entire machine function
-  ir2vec::Embedding getMFunctionEmbedding(MachineFunction &MF) const {
-    if (!Vocab) {
-      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
-      return {};
-    }
-
-    auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
-    if (!Emb) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to create embedder for " << MF.getName() << "\n";
-      return {};
-    }
-    return Emb->getMFunctionVector();
+  ir2vec::Embedding getMFunctionEmbedding(const MIREmbedder &Emb) const {
+    return Emb.getMFunctionVector();
   }
 
-  /// Get machine basic block embeddings
-  /// Returns MBBVecList containing (name, embedding) pairs for all MBBs
-  MBBVecList getMBBEmbeddings(MachineFunction &MF) const {
-    if (!Vocab) {
-      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
-      return {};
-    }
-
-    auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
-    if (!Emb)
-      return {};
-
-    MBBVecList Result;
+  /// Get machine basic block embeddings using provided embedder
+  /// Returns MachineBlockEmbeddingsMap containing MBB pointer to embedding
+  /// mappings
+  MachineBlockEmbeddingsMap getMBBEmbeddings(MachineFunction &MF,
+                                             const MIREmbedder &Emb) const {
+    MachineBlockEmbeddingsMap Result;
     for (const MachineBasicBlock &MBB : MF)
-      Result.push_back({&MBB, Emb->getMBBVector(MBB)});
-
+      Result.try_emplace(&MBB, Emb.getMBBVector(MBB));
     return Result;
   }
 
-  /// Get machine instruction embeddings
-  /// Returns MInstVecList containing (instruction_string, embedding) pairs
-  MInstVecList getMInstEmbeddings(MachineFunction &MF) const {
-    if (!Vocab) {
-      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
-      return {};
-    }
-
-    auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
-    if (!Emb)
-      return {};
-
-    MInstVecList Result;
+  /// Get machine instruction embeddings using provided embedder
+  /// Returns MachineInstEmbeddingsMap containing MI pointer to embedding
+  /// mappings
+  MachineInstEmbeddingsMap getMInstEmbeddings(MachineFunction &MF,
+                                              const MIREmbedder &Emb) const {
+    MachineInstEmbeddingsMap Result;
     for (const MachineBasicBlock &MBB : MF) {
       for (const MachineInstr &MI : MBB)
-        Result.push_back({&MI, Emb->getMInstVector(MI)});
+        Result.try_emplace(&MI, Emb.getMInstVector(MI));
     }
-
     return Result;
   }
 
@@ -718,31 +665,50 @@ class MIR2VecTool {
       return;
     }
 
+    // Create embedder once for the machine function
+    auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+    if (!Emb) {
+      WithColor::error(errs(), ToolName)
+          << "Failed to create embedder for " << MF.getName() << "\n";
+      return;
+    }
+
     OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
 
     // Generate embeddings based on the specified level
     switch (Level) {
     case FunctionLevel:
       OS << "Function vector: ";
-      getMFunctionEmbedding(MF).print(OS);
+      getMFunctionEmbedding(*Emb).print(OS);
       break;
-    case BasicBlockLevel:
+    case BasicBlockLevel: {
       OS << "Basic block vectors:\n";
-      for (const auto &[MBB, embedding] : getMBBEmbeddings(MF)) {
-        OS << "MBB " << MBB->getName() << ": ";
-        embedding.print(OS);
+      const auto MBBEmbMap = getMBBEmbeddings(MF, *Emb);
+      for (const MachineBasicBlock &MBB : MF) {
+        if (auto It = MBBEmbMap.find(&MBB); It != MBBEmbMap.end()) {
+          OS << "MBB " << MBB.getName() << ": ";
+          It->second.print(OS);
+        }
       }
       break;
-    case InstructionLevel:
+    }
+    case InstructionLevel: {
       OS << "Instruction vectors:\n";
-      for (const auto &[MI, embedding] : getMInstEmbeddings(MF)) {
-        OS << *MI << " -> ";
-        embedding.print(OS);
+      const auto MInstEmbMap = getMInstEmbeddings(MF, *Emb);
+      for (const MachineBasicBlock &MBB : MF) {
+        for (const MachineInstr &MI : MBB) {
+          if (auto It = MInstEmbMap.find(&MI); It != MInstEmbMap.end()) {
+            OS << MI << " -> ";
+            It->second.print(OS);
+          }
+        }
       }
       break;
     }
+    }
   }
 
+  /// Get the MIR vocabulary instance
   const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
 };
 

>From 087a84b52bf14d47b27201c6617413b37844870b Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 10 Dec 2025 14:51:35 +0530
Subject: [PATCH 5/5] Resolving nit comment. To be squashed later

---
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 5b62142652990..f106ee7693a1e 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -201,7 +201,9 @@ class IR2VecTool {
   }
 
   /// Generate triplets for a single function
-  /// Returns the maximum relation ID used in this function
+  /// Returns a TripletResult with:
+  ///   - Triplets: vector of all (Head, Tail, Relation) tuples
+  ///   - MaxRelation: highest Arg relation ID used, or NextRelation if none
   TripletResult generateTriplets(const Function &F) const {
     if (F.isDeclaration())
       return {};



More information about the llvm-commits mailing list