[llvm] Refactoring llvm-ir2vec.cpp for better separation of concerns in the Tooling classes (PR #170078)

Nishant Sachdeva via llvm-commits llvm-commits at lists.llvm.org
Sun Nov 30 23:02:42 PST 2025


https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/170078

>From 6f57fd49eedeb0a2743c29cbc8d9762add560993 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Fri, 28 Nov 2025 19:27:28 +0530
Subject: [PATCH] Refactoring llvm-ir2vec.cpp for better separation of concerns
 in the Tooling classes

---
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 433 +++++++++++++++++--------
 1 file changed, 290 insertions(+), 143 deletions(-)

diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 7402782bfd404..ba94205193495 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -64,6 +64,7 @@
 #include "llvm/IR/PassInstrumentation.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Type.h"
+#include "llvm/IR/InstIterator.h"
 #include "llvm/IRReader/IRReader.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -151,6 +152,30 @@ static cl::opt<EmbeddingLevel>
           cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
           cl::cat(CommonCategory));
 
+/// Entity names, indexed by entity ID.
+using EntityList = std::vector<std::string>;
+
+/// (basic block name, embedding) pairs, in the order the blocks are visited.
+using BBVecList = std::vector<std::pair<std::string, ir2vec::Embedding>>;
+
+/// (printed instruction, embedding) pairs, in the order the instructions are
+/// visited.
+using InstVecList = std::vector<std::pair<std::string, ir2vec::Embedding>>;
+
+/// One (head, tail, relation) triplet, printed as "Head\tTail\tRelation".
+struct Triplet {
+  unsigned Head = 0;
+  unsigned Tail = 0;
+  unsigned Relation = 0;
+};
+
+/// Triplets collected for a function or a module, plus the maximum relation
+/// ID to report in the MAX_RELATION=N header.
+struct TripletResult {
+  unsigned MaxRelation = 0;
+  std::vector<Triplet> Triplets;
+};
+
 namespace ir2vec {
 
 /// Relation types for triplet generation
@@ -182,32 +206,17 @@ class IR2VecTool {
     return Vocab->isValid();
   }
 
-  /// Generate triplets for the module
-  /// Output format: MAX_RELATION=N header followed by relationships
-  void generateTriplets(raw_ostream &OS) const {
-    unsigned MaxRelation = NextRelation; // Track maximum relation ID
-    std::string Relationships;
-    raw_string_ostream RelOS(Relationships);
+  /// Collect the triplets for a single function.
+  /// Declarations yield no triplets and MaxRelation = 0. Otherwise MaxRelation
+  /// is the larger of NextRelation and the highest Arg relation ID used in F.
+  TripletResult getTriplets(const Function &F) const {
+    TripletResult Result;
+    Result.MaxRelation = 0;
 
-    for (const Function &F : M) {
-      unsigned FuncMaxRelation = generateTriplets(F, RelOS);
-      MaxRelation = std::max(MaxRelation, FuncMaxRelation);
-    }
-
-    RelOS.flush();
-
-    // Write metadata header followed by relationships
-    OS << "MAX_RELATION=" << MaxRelation << '\n';
-    OS << Relationships;
-  }
-
-  /// Generate triplets for a single function
-  /// Returns the maximum relation ID used in this function
-  unsigned generateTriplets(const Function &F, raw_ostream &OS) const {
     if (F.isDeclaration())
-      return 0;
+      return Result;
 
-    unsigned MaxRelation = 1;
+    unsigned MaxRelation = NextRelation;
     unsigned PrevOpcode = 0;
     bool HasPrevOpcode = false;
 
@@ -216,56 +222,163 @@ class IR2VecTool {
         unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
         unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
 
         // Add "Next" relationship with previous instruction
         if (HasPrevOpcode) {
-          OS << PrevOpcode << '\t' << Opcode << '\t' << NextRelation << '\n';
+          Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
           LLVM_DEBUG(dbgs()
                      << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
                      << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
                      << "Next\n");
         }
 
         // Add "Type" relationship
-        OS << Opcode << '\t' << TypeID << '\t' << TypeRelation << '\n';
+        Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
         LLVM_DEBUG(
             dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
                    << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
                    << '\t' << "Type\n");
 
         // Add "Arg" relationships
         unsigned ArgIndex = 0;
         for (const Use &U : I.operands()) {
           unsigned OperandID = Vocabulary::getIndex(*U.get());
           unsigned RelationID = ArgRelation + ArgIndex;
-          OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
+          Result.Triplets.push_back({Opcode, OperandID, RelationID});
 
           LLVM_DEBUG({
             StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
                 Vocabulary::getOperandKind(U.get()));
             dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
                    << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
           });
 
           ++ArgIndex;
         }
         // Only update MaxRelation if there were operands
         if (ArgIndex > 0) {
           MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
         }
         PrevOpcode = Opcode;
         HasPrevOpcode = true;
       }
     }
 
-    return MaxRelation;
+    Result.MaxRelation = MaxRelation;
+    return Result;
   }
 
-  /// Dump entity ID to string mappings
-  static void generateEntityMappings(raw_ostream &OS) {
+  /// Collect the triplets for every function in the module.
+  /// MaxRelation is the maximum over all functions and never less than
+  /// NextRelation.
+  TripletResult getTriplets() const {
+    TripletResult Result;
+    Result.MaxRelation = NextRelation;
+
+    for (const Function &F : M) {
+      TripletResult FuncResult = getTriplets(F);
+      Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+      Result.Triplets.insert(Result.Triplets.end(),
+                             FuncResult.Triplets.begin(),
+                             FuncResult.Triplets.end());
+    }
+
+    return Result;
+  }
+
+  /// Generate triplets for the module
+  /// Output format: MAX_RELATION=N header followed by relationships
+  void generateTriplets(raw_ostream &OS) const {
+    TripletResult Result = getTriplets();
+    OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+    for (const Triplet &T : Result.Triplets)
+      OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+  }
+
+  /// Collect the string key of every vocabulary entity, indexed by entity ID.
+  static EntityList collectEntityMappings() {
     auto EntityLen = Vocabulary::getCanonicalSize();
-    OS << EntityLen << "\n";
+    EntityList Result;
+    Result.reserve(EntityLen);
+
     for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
-      OS << Vocabulary::getStringKey(EntityID) << '\t' << EntityID << '\n';
+      Result.push_back(Vocabulary::getStringKey(EntityID).str());
+
+    return Result;
+  }
+
+  /// Dump entity ID to string mappings
+  static void generateEntityMappings(raw_ostream &OS) {
+    EntityList Entities = collectEntityMappings();
+    OS << Entities.size() << "\n";
+    for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+      OS << Entities[EntityID] << '\t' << EntityID << '\n';
+  }
+
+  /// Create an embedder for F, reporting an error if that fails.
+  /// Requires an initialized vocabulary. Defined before its users because
+  /// its return type is deduced.
+  auto createEmbedder(const Function &F) const {
+    auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+    if (!Emb)
+      WithColor::error(errs(), ToolName)
+          << "Failed to create embedder for function " << F.getName() << "\n";
+    return Emb;
+  }
+
+  /// Compute the function-level embedding of F.
+  /// Returns a default-constructed Embedding for declarations or if no
+  /// embedder could be created.
+  ir2vec::Embedding getFunctionEmbedding(const Function &F) const {
+    assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+    if (F.isDeclaration())
+      return {};
+
+    auto Emb = createEmbedder(F);
+    return Emb ? Emb->getFunctionVector() : Embedding{};
+  }
+
+  /// Compute (name, embedding) pairs for every basic block of F, in the order
+  /// they appear in F. Returns an empty list for declarations or if no
+  /// embedder could be created.
+  BBVecList getBBEmbeddings(const Function &F) const {
+    assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+    BBVecList Result;
+
+    if (F.isDeclaration())
+      return Result;
+
+    auto Emb = createEmbedder(F);
+    if (!Emb)
+      return Result;
+
+    for (const BasicBlock &BB : F)
+      Result.emplace_back(BB.getName().str(), Emb->getBBVector(BB));
+
+    return Result;
+  }
+
+  /// Compute (printed instruction, embedding) pairs for every instruction of
+  /// F. Returns an empty list for declarations or if no embedder could be
+  /// created.
+  InstVecList getInstEmbeddings(const Function &F) const {
+    assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+    InstVecList Result;
+
+    if (F.isDeclaration())
+      return Result;
+
+    auto Emb = createEmbedder(F);
+    if (!Emb)
+      return Result;
+
+    for (const Instruction &I : instructions(F)) {
+      std::string InstStr;
+      raw_string_ostream(InstStr) << I;
+      Result.emplace_back(std::move(InstStr), Emb->getInstVector(I));
+    }
+    return Result;
   }
 
   /// Generate embeddings for the entire module
@@ -282,44 +371,33 @@ class IR2VecTool {
 
   /// Generate embeddings for a single function
   void generateEmbeddings(const Function &F, raw_ostream &OS) const {
+    assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
     if (F.isDeclaration()) {
       OS << "Function " << F.getName() << " is a declaration, skipping.\n";
       return;
     }
 
-    // Create embedder for this function
-    assert(Vocab->isValid() && "Vocabulary is not valid");
-    auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
-    if (!Emb) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to create embedder for function " << F.getName() << "\n";
-      return;
-    }
-
     OS << "Function: " << F.getName() << "\n";
 
+    // Print each (name, embedding) pair as "<name><Separator><vector>".
+    auto PrintNamedEmbeddings = [&](const auto &List, StringRef Separator) {
+      for (const auto &[Name, Emb] : List) {
+        OS << Name << Separator;
+        Emb.print(OS);
+      }
+    };
+
     // Generate embeddings based on the specified level
     switch (Level) {
-    case FunctionLevel: {
-      Emb->getFunctionVector().print(OS);
-      break;
-    }
-    case BasicBlockLevel: {
-      for (const BasicBlock &BB : F) {
-        OS << BB.getName() << ":";
-        Emb->getBBVector(BB).print(OS);
-      }
-      break;
-    }
-    case InstructionLevel: {
-      for (const BasicBlock &BB : F) {
-        for (const Instruction &I : BB) {
-          I.print(OS);
-          Emb->getInstVector(I).print(OS);
-        }
-      }
-      break;
-    }
+    case FunctionLevel:
+      getFunctionEmbedding(F).print(OS);
+      break;
+    case BasicBlockLevel:
+      PrintNamedEmbeddings(getBBEmbeddings(F), ":");
+      break;
+    case InstructionLevel:
+      PrintNamedEmbeddings(getInstEmbeddings(F), "");
+      break;
     }
   }
 };
@@ -423,49 +499,22 @@ class MIR2VecTool {
         << "No machine functions found to initialize vocabulary\n";
     return false;
   }
 
-  /// Generate triplets for the module
-  /// Output format: MAX_RELATION=N header followed by relationships
-  void generateTriplets(const Module &M, raw_ostream &OS) const {
-    unsigned MaxRelation = MIRNextRelation; // Track maximum relation ID
-    std::string Relationships;
-    raw_string_ostream RelOS(Relationships);
-
-    for (const Function &F : M) {
-      if (F.isDeclaration())
-        continue;
-
-      MachineFunction *MF = MMI.getMachineFunction(F);
-      if (!MF) {
-        WithColor::warning(errs(), ToolName)
-            << "No MachineFunction for " << F.getName() << "\n";
-        continue;
-      }
-
-      unsigned FuncMaxRelation = generateTriplets(*MF, RelOS);
-      MaxRelation = std::max(MaxRelation, FuncMaxRelation);
-    }
-
-    RelOS.flush();
-
-    // Write metadata header followed by relationships
-    OS << "MAX_RELATION=" << MaxRelation << '\n';
-    OS << Relationships;
-  }
-
-  /// Generate triplets for a single machine function
-  /// Returns the maximum relation ID used in this function
-  unsigned generateTriplets(const MachineFunction &MF, raw_ostream &OS) const {
-    unsigned MaxRelation = MIRNextRelation;
-    unsigned PrevOpcode = 0;
-    bool HasPrevOpcode = false;
+  /// Get triplets for a single machine function
+  /// Returns TripletResult containing MaxRelation and vector of Triplets
+  TripletResult getTriplets(const MachineFunction &MF) const {
+    TripletResult Result;
+    Result.MaxRelation = MIRNextRelation;
 
     if (!Vocab) {
       WithColor::error(errs(), ToolName)
           << "MIR Vocabulary must be initialized for triplet generation.\n";
-      return MaxRelation;
+      return Result;
     }
 
+    unsigned PrevOpcode = 0;
+    bool HasPrevOpcode = false;
+
     for (const MachineBasicBlock &MBB : MF) {
       for (const MachineInstr &MI : MBB) {
         // Skip debug instructions
@@ -477,8 +526,7 @@ class MIR2VecTool {
 
         // Add "Next" relationship with previous instruction
         if (HasPrevOpcode) {
-          OS << PrevOpcode << '\t' << OpcodeID << '\t' << MIRNextRelation
-             << '\n';
+          Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
           LLVM_DEBUG(dbgs()
                      << Vocab->getStringKey(PrevOpcode) << '\t'
                      << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
@@ -489,7 +537,7 @@ class MIR2VecTool {
         for (const MachineOperand &MO : MI.operands()) {
           auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
           unsigned RelationID = MIRArgRelation + ArgIndex;
-          OS << OpcodeID << '\t' << OperandID << '\t' << RelationID << '\n';
+          Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
           LLVM_DEBUG({
             std::string OperandStr = Vocab->getStringKey(OperandID);
             dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
@@ -501,28 +549,86 @@ class MIR2VecTool {
 
         // Update MaxRelation if there were operands
         if (ArgIndex > 0)
-          MaxRelation = std::max(MaxRelation, MIRArgRelation + ArgIndex - 1);
+          Result.MaxRelation =
+              std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
 
         PrevOpcode = OpcodeID;
         HasPrevOpcode = true;
       }
     }
 
-    return MaxRelation;
+    return Result;
   }
+
+  /// Get triplets for the entire module
+  /// Returns TripletResult containing aggregated MaxRelation and all Triplets
+  TripletResult getTriplets(const Module &M) const {
+    TripletResult Result;
+    Result.MaxRelation = MIRNextRelation;
 
-  /// Generate entity mappings with vocabulary
-  void generateEntityMappings(raw_ostream &OS) const {
+    for (const Function &F : M) {
+      if (F.isDeclaration())
+        continue;
+
+      MachineFunction *MF = MMI.getMachineFunction(F);
+      if (!MF) {
+        WithColor::warning(errs(), ToolName)
+            << "No MachineFunction for " << F.getName() << "\n";
+        continue;
+      }
+
+      TripletResult FuncResult = getTriplets(*MF);
+      Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+      Result.Triplets.insert(Result.Triplets.end(),
+                             FuncResult.Triplets.begin(),
+                             FuncResult.Triplets.end());
+    }
+
+    return Result;
+  }
+
+  /// Generate triplets for the module and write to output stream
+  /// Output format: MAX_RELATION=N header followed by relationships
+  void generateTriplets(const Module &M, raw_ostream &OS) const {
+    TripletResult Result = getTriplets(M);
+    OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+    for (const Triplet &T : Result.Triplets)
+      OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+  }
+
+  /// Collect entity mappings
+  /// Returns EntityList containing all entity strings
+  EntityList collectEntityMappings() const {
     if (!Vocab) {
       WithColor::error(errs(), ToolName)
           << "Vocabulary must be initialized for entity mappings.\n";
-      return;
+      return {};
     }
 
     const unsigned EntityCount = Vocab->getCanonicalSize();
-    OS << EntityCount << "\n";
+    EntityList Result;
+    Result.reserve(EntityCount);
+
     for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
-      OS << Vocab->getStringKey(EntityID) << '\t' << EntityID << '\n';
+      Result.push_back(Vocab->getStringKey(EntityID));
+
+    return Result;
+  }
+
+  /// Generate entity mappings and write to output stream
+  void generateEntityMappings(raw_ostream &OS) const {
+    // Test Vocab directly instead of Entities.empty(): a present but empty
+    // vocabulary must still print its "0" count line.
+    if (!Vocab) {
+      WithColor::error(errs(), ToolName)
+          << "Vocabulary must be initialized for entity mappings.\n";
+      return;
+    }
+
+    EntityList Entities = collectEntityMappings();
+    OS << Entities.size() << "\n";
+    for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+      OS << Entities[EntityID] << '\t' << EntityID << '\n';
   }
 
   /// Generate embeddings for all machine functions in the module
@@ -547,48 +649,110 @@ class MIR2VecTool {
       }
     }
 
-  /// Generate embeddings for a specific machine function
-  void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const {
+  /// Create a symbolic MIR embedder for MF, reporting an error if that fails.
+  /// Requires an initialized vocabulary. Defined before its users because
+  /// its return type is deduced.
+  auto createEmbedder(MachineFunction &MF) const {
+    auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+    if (!Emb)
+      WithColor::error(errs(), ToolName)
+          << "Failed to create embedder for " << MF.getName() << "\n";
+    return Emb;
+  }
+
+  /// Get machine function embedding
+  /// Returns Embedding for the entire machine function, or a default Embedding
+  /// if the vocabulary is missing or no embedder could be created.
+  ir2vec::Embedding getMFunctionEmbedding(MachineFunction &MF) const {
     if (!Vocab) {
       WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
-      return;
+      return {};
     }
 
-    auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
-    if (!Emb) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to create embedder for " << MF.getName() << "\n";
+    auto Emb = createEmbedder(MF);
+    return Emb ? Emb->getMFunctionVector() : ir2vec::Embedding{};
+  }
+
+  /// Get machine basic block embeddings
+  /// Returns BBVecList containing (name, embedding) pairs for all MBBs
+  BBVecList getMBBEmbeddings(MachineFunction &MF) const {
+    BBVecList Result;
+
+    if (!Vocab) {
+      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+      return Result;
+    }
+
+    auto Emb = createEmbedder(MF);
+    if (!Emb)
+      return Result;
+
+    for (const MachineBasicBlock &MBB : MF)
+      Result.emplace_back(MBB.getName().str(), Emb->getMBBVector(MBB));
+
+    return Result;
+  }
+
+  /// Get machine instruction embeddings
+  /// Returns InstVecList containing (instruction_string, embedding) pairs
+  InstVecList getMInstEmbeddings(MachineFunction &MF) const {
+    InstVecList Result;
+
+    if (!Vocab) {
+      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+      return Result;
+    }
+
+    auto Emb = createEmbedder(MF);
+    if (!Emb)
+      return Result;
+
+    for (const MachineBasicBlock &MBB : MF) {
+      for (const MachineInstr &MI : MBB) {
+        std::string InstStr;
+        raw_string_ostream(InstStr) << MI;
+        Result.emplace_back(std::move(InstStr), Emb->getMInstVector(MI));
+      }
+    }
+
+    return Result;
+  }
+
+  /// Generate embeddings for a specific machine function
+  void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const {
+    if (!Vocab) {
+      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
       return;
     }
 
     OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
 
+    // Print a "<Label>:" header, then one "<Prefix><name><Separator><vector>"
+    // entry per (name, embedding) pair.
+    auto PrintNamedEmbeddings = [&](const auto &List, StringRef Label,
+                                    StringRef Prefix, StringRef Separator) {
+      OS << Label << ":\n";
+      for (const auto &[Name, Emb] : List) {
+        OS << Prefix << Name << Separator;
+        Emb.print(OS);
+      }
+    };
+
     // Generate embeddings based on the specified level
     switch (Level) {
-    case FunctionLevel: {
+    case FunctionLevel:
       OS << "Function vector: ";
-      Emb->getMFunctionVector().print(OS);
+      getMFunctionEmbedding(MF).print(OS);
       break;
-    }
-    case BasicBlockLevel: {
-      OS << "Basic block vectors:\n";
-      for (const MachineBasicBlock &MBB : MF) {
-        OS << "MBB " << MBB.getName() << ": ";
-        Emb->getMBBVector(MBB).print(OS);
-      }
+    case BasicBlockLevel:
+      PrintNamedEmbeddings(getMBBEmbeddings(MF), "Basic block vectors", "MBB ",
+                           ": ");
       break;
-    }
-    case InstructionLevel: {
-      OS << "Instruction vectors:\n";
-      for (const MachineBasicBlock &MBB : MF) {
-        for (const MachineInstr &MI : MBB) {
-          OS << MI << " -> ";
-          Emb->getMInstVector(MI).print(OS);
-        }
-      }
+    case InstructionLevel:
+      PrintNamedEmbeddings(getMInstEmbeddings(MF), "Instruction vectors", "",
+                           " -> ");
       break;
     }
-    }
   }
 
   const MIRVocabulary *getVocabulary() const { return Vocab.get(); }



More information about the llvm-commits mailing list