[llvm] Refactoring llvm-ir2vec.cpp for better separation of concerns in the Tooling classes (PR #170078)

Nishant Sachdeva via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 12 01:09:07 PST 2025


https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/170078

>From 834a74d72331e4dd97da84ad703be6b165cb0ae9 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Fri, 28 Nov 2025 19:27:28 +0530
Subject: [PATCH] Refactoring llvm-ir2vec.cpp for better separation of concerns
 in the Tooling classes

---
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 230 ++++++++++++++++---------
 1 file changed, 146 insertions(+), 84 deletions(-)

diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 7b8d3f093a3d1..4cb8742c7654d 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -58,6 +58,7 @@
 #include "llvm/Analysis/IR2Vec.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
@@ -151,6 +152,24 @@ static cl::opt<EmbeddingLevel>
           cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
           cl::cat(CommonCategory));
 
+/// Represents a single knowledge graph triplet, stored and printed in
+/// (Head, Tail, Relation) order; Head and Tail index an EntityList
+struct Triplet {
+  unsigned Head = 0;     ///< Index of the head entity in the entity list
+  unsigned Tail = 0;     ///< Index of the tail entity in the entity list
+  unsigned Relation = 0; ///< Relation type (see RelationType enum)
+};
+
+/// Result structure containing all generated triplets and metadata
+struct TripletResult {
+  unsigned MaxRelation =
+      0; ///< Highest relation index used (for ArgRelation + N)
+  std::vector<Triplet> Triplets; ///< Collection of all generated triplets
+};
+
+/// Entity names indexed by entity ID: EntityList[ID] is that entity's name
+using EntityList = std::vector<std::string>;
+
 namespace ir2vec {
 
 /// Relation types for triplet generation
@@ -182,32 +201,18 @@ class IR2VecTool {
     return Vocab->isValid();
   }
 
-  /// Generate triplets for the module
-  /// Output format: MAX_RELATION=N header followed by relationships
-  void generateTriplets(raw_ostream &OS) const {
-    unsigned MaxRelation = NextRelation; // Track maximum relation ID
-    std::string Relationships;
-    raw_string_ostream RelOS(Relationships);
-
-    for (const Function &F : M) {
-      unsigned FuncMaxRelation = generateTriplets(F, RelOS);
-      MaxRelation = std::max(MaxRelation, FuncMaxRelation);
-    }
-
-    RelOS.flush();
-
-    // Write metadata header followed by relationships
-    OS << "MAX_RELATION=" << MaxRelation << '\n';
-    OS << Relationships;
-  }
-
   /// Generate triplets for a single function
-  /// Returns the maximum relation ID used in this function
-  unsigned generateTriplets(const Function &F, raw_ostream &OS) const {
+  /// Returns a TripletResult (empty, MaxRelation 0 for a declaration) with:
+  ///   - Triplets: vector of all (head, tail, relation) tuples
+  ///   - MaxRelation: highest Arg relation ID used, or NextRelation if none
+  TripletResult generateTriplets(const Function &F) const {
     if (F.isDeclaration())
-      return 0;
+      return {};
 
-    unsigned MaxRelation = 1;
+    TripletResult Result;
+    Result.MaxRelation = 0;
+
+    unsigned MaxRelation = NextRelation;
     unsigned PrevOpcode = 0;
     bool HasPrevOpcode = false;
 
@@ -218,7 +223,7 @@ class IR2VecTool {
 
         // Add "Next" relationship with previous instruction
         if (HasPrevOpcode) {
-          OS << PrevOpcode << '\t' << Opcode << '\t' << NextRelation << '\n';
+          Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation});
           LLVM_DEBUG(dbgs()
                      << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
                      << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
@@ -226,7 +231,7 @@ class IR2VecTool {
         }
 
         // Add "Type" relationship
-        OS << Opcode << '\t' << TypeID << '\t' << TypeRelation << '\n';
+        Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
         LLVM_DEBUG(
             dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
                    << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
@@ -237,7 +242,7 @@ class IR2VecTool {
         for (const Use &U : I.operands()) {
           unsigned OperandID = Vocabulary::getIndex(*U.get());
           unsigned RelationID = ArgRelation + ArgIndex;
-          OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
+          Result.Triplets.push_back({Opcode, OperandID, RelationID});
 
           LLVM_DEBUG({
             StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
@@ -249,23 +254,57 @@ class IR2VecTool {
           ++ArgIndex;
         }
         // Only update MaxRelation if there were operands
-        if (ArgIndex > 0) {
+        if (ArgIndex > 0)
           MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
-        }
         PrevOpcode = Opcode;
         HasPrevOpcode = true;
       }
     }
 
-    return MaxRelation;
+    Result.MaxRelation = MaxRelation;
+    return Result;
   }
 
-  /// Dump entity ID to string mappings
-  static void generateEntityMappings(raw_ostream &OS) {
+  /// Get triplets for the entire module
+  TripletResult generateTriplets() const {
+    TripletResult Result;
+    Result.MaxRelation = NextRelation;
+
+    for (const Function &F : M.getFunctionDefs()) {
+      TripletResult FuncResult = generateTriplets(F);
+      Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+      Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+                             FuncResult.Triplets.end());
+    }
+
+    return Result;
+  }
+
+  /// Collect triplets for the module and dump output to stream
+  /// Output format: MAX_RELATION=N header followed by relationships
+  void generateTriplets(raw_ostream &OS) const {
+    auto Result = generateTriplets();
+    OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+    for (const auto &T : Result.Triplets)
+      OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+  }
+
+  /// Generate entity mappings for the entire vocabulary
+  /// Returns EntityList containing all entity strings
+  static EntityList collectEntityMappings() {
     auto EntityLen = Vocabulary::getCanonicalSize();
-    OS << EntityLen << "\n";
+    EntityList Result;
     for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
-      OS << Vocabulary::getStringKey(EntityID) << '\t' << EntityID << '\n';
+      Result.push_back(Vocabulary::getStringKey(EntityID).str());
+    return Result;
+  }
+
+  /// Dump entity ID to string mappings
+  static void generateEntityMappings(raw_ostream &OS) {
+    auto Entities = collectEntityMappings();
+    OS << Entities.size() << "\n";
+    for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+      OS << Entities[EntityID] << '\t' << EntityID << '\n';
   }
 
   /// Generate embeddings for the entire module
@@ -276,19 +315,23 @@ class IR2VecTool {
       return;
     }
 
-    for (const Function &F : M)
+    for (const Function &F : M.getFunctionDefs())
       generateEmbeddings(F, OS);
   }
 
   /// Generate embeddings for a single function
   void generateEmbeddings(const Function &F, raw_ostream &OS) const {
+    if (!Vocab || !Vocab->isValid()) {
+      WithColor::error(errs(), ToolName)
+          << "Vocabulary is not valid. IR2VecTool not initialized.\n";
+      return;
+    }
     if (F.isDeclaration()) {
       OS << "Function " << F.getName() << " is a declaration, skipping.\n";
       return;
     }
 
     // Create embedder for this function
-    assert(Vocab->isValid() && "Vocabulary is not valid");
     auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
     if (!Emb) {
       WithColor::error(errs(), ToolName)
@@ -312,11 +355,9 @@ class IR2VecTool {
       break;
     }
     case InstructionLevel: {
-      for (const BasicBlock &BB : F) {
-        for (const Instruction &I : BB) {
-          I.print(OS);
-          Emb->getInstVector(I).print(OS);
-        }
+      for (const Instruction &I : instructions(F)) {
+        OS << I;
+        Emb->getInstVector(I).print(OS);
       }
       break;
     }
@@ -324,6 +365,7 @@ class IR2VecTool {
   }
 };
 
+/// Process the module and generate output based on selected subcommand
 Error processModule(Module &M, raw_ostream &OS) {
   IR2VecTool Tool(M);
 
@@ -422,46 +464,20 @@ class MIR2VecTool {
     return false;
   }
 
-  /// Generate triplets for the module
-  /// Output format: MAX_RELATION=N header followed by relationships
-  void generateTriplets(const Module &M, raw_ostream &OS) const {
-    unsigned MaxRelation = MIRNextRelation; // Track maximum relation ID
-    std::string Relationships;
-    raw_string_ostream RelOS(Relationships);
-
-    for (const Function &F : M.getFunctionDefs()) {
-
-      MachineFunction *MF = MMI.getMachineFunction(F);
-      if (!MF) {
-        WithColor::warning(errs(), ToolName)
-            << "No MachineFunction for " << F.getName() << "\n";
-        continue;
-      }
-
-      unsigned FuncMaxRelation = generateTriplets(*MF, RelOS);
-      MaxRelation = std::max(MaxRelation, FuncMaxRelation);
-    }
-
-    RelOS.flush();
-
-    // Write metadata header followed by relationships
-    OS << "MAX_RELATION=" << MaxRelation << '\n';
-    OS << Relationships;
-  }
-
-  /// Generate triplets for a single machine function
-  /// Returns the maximum relation ID used in this function
-  unsigned generateTriplets(const MachineFunction &MF, raw_ostream &OS) const {
-    unsigned MaxRelation = MIRNextRelation;
-    unsigned PrevOpcode = 0;
-    bool HasPrevOpcode = false;
+  /// Get triplets for a single machine function
+  /// Returns TripletResult containing MaxRelation and vector of Triplets
+  TripletResult generateTriplets(const MachineFunction &MF) const {
+    TripletResult Result;
+    Result.MaxRelation = MIRNextRelation;
 
     if (!Vocab) {
       WithColor::error(errs(), ToolName)
           << "MIR Vocabulary must be initialized for triplet generation.\n";
-      return MaxRelation;
+      return Result;
     }
 
+    unsigned PrevOpcode = 0;
+    bool HasPrevOpcode = false;
     for (const MachineBasicBlock &MBB : MF) {
       for (const MachineInstr &MI : MBB) {
         // Skip debug instructions
@@ -473,8 +489,7 @@ class MIR2VecTool {
 
         // Add "Next" relationship with previous instruction
         if (HasPrevOpcode) {
-          OS << PrevOpcode << '\t' << OpcodeID << '\t' << MIRNextRelation
-             << '\n';
+          Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation});
           LLVM_DEBUG(dbgs()
                      << Vocab->getStringKey(PrevOpcode) << '\t'
                      << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
@@ -485,7 +500,7 @@ class MIR2VecTool {
         for (const MachineOperand &MO : MI.operands()) {
           auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
           unsigned RelationID = MIRArgRelation + ArgIndex;
-          OS << OpcodeID << '\t' << OperandID << '\t' << RelationID << '\n';
+          Result.Triplets.push_back({OpcodeID, OperandID, RelationID});
           LLVM_DEBUG({
             std::string OperandStr = Vocab->getStringKey(OperandID);
             dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
@@ -497,28 +512,74 @@ class MIR2VecTool {
 
         // Update MaxRelation if there were operands
         if (ArgIndex > 0)
-          MaxRelation = std::max(MaxRelation, MIRArgRelation + ArgIndex - 1);
+          Result.MaxRelation =
+              std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1);
 
         PrevOpcode = OpcodeID;
         HasPrevOpcode = true;
       }
     }
 
-    return MaxRelation;
+    return Result;
   }
 
-  /// Generate entity mappings with vocabulary
-  void generateEntityMappings(raw_ostream &OS) const {
+  /// Get triplets for the entire module
+  /// Returns TripletResult containing aggregated MaxRelation and all Triplets
+  TripletResult generateTriplets(const Module &M) const {
+    TripletResult Result;
+    Result.MaxRelation = MIRNextRelation;
+
+    for (const Function &F : M.getFunctionDefs()) {
+      MachineFunction *MF = MMI.getMachineFunction(F);
+      if (!MF) {
+        WithColor::warning(errs(), ToolName)
+            << "No MachineFunction for " << F.getName() << "\n";
+        continue;
+      }
+
+      TripletResult FuncResult = generateTriplets(*MF);
+      Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation);
+      Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(),
+                             FuncResult.Triplets.end());
+    }
+
+    return Result;
+  }
+
+  /// Generate triplets for the module and write to output stream
+  /// Output format: MAX_RELATION=N header followed by relationships
+  void generateTriplets(const Module &M, raw_ostream &OS) const {
+    auto Result = generateTriplets(M);
+    OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
+    for (const auto &T : Result.Triplets)
+      OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+  }
+
+  /// Collect entity names for the entire vocabulary (empty if Vocab is unset)
+  EntityList collectEntityMappings() const {
     if (!Vocab) {
       WithColor::error(errs(), ToolName)
           << "Vocabulary must be initialized for entity mappings.\n";
-      return;
+      return {};
     }
 
     const unsigned EntityCount = Vocab->getCanonicalSize();
-    OS << EntityCount << "\n";
+    EntityList Result;
     for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
-      OS << Vocab->getStringKey(EntityID) << '\t' << EntityID << '\n';
+      Result.push_back(Vocab->getStringKey(EntityID));
+
+    return Result;
+  }
+
+  /// Write entity mappings to the output stream; no output if there are none
+  void generateEntityMappings(raw_ostream &OS) const {
+    auto Entities = collectEntityMappings();
+    if (Entities.empty())
+      return;
+
+    OS << Entities.size() << "\n";
+    for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
+      OS << Entities[EntityID] << '\t' << EntityID << '\n';
   }
 
   /// Generate embeddings for all machine functions in the module
@@ -585,6 +646,7 @@ class MIR2VecTool {
     }
   }
 
+  /// Get the MIR vocabulary instance
   const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
 };
 



More information about the llvm-commits mailing list