[llvm] [IR2Vec][llvm-ir2vec] Revamp triplet generation and add entity mapping mode (PR #149214)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 29 07:05:42 PDT 2025


================
@@ -111,29 +128,101 @@ class IR2VecTool {
     // option
     MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
     MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
+    // This will throw an error if vocab is not found or invalid
     Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
     return Vocab->isValid();
   }
 
-  /// Generate triplets for the entire module
+  /// Generate triplets for the module
+  /// Output format: MAX_RELATION=N header followed by relationships
+  /// Each relationship is one line of the form "<head>\t<tail>\t<relation>",
+  /// where all three fields are numeric IDs. N is the largest relation ID
+  /// returned by any per-function call, and is never less than NextRelation.
   void generateTriplets(raw_ostream &OS) const {
-    for (const Function &F : M)
-      generateTriplets(F, OS);
+    unsigned MaxRelation = NextRelation; // Track maximum relation ID
+    // Relationships are buffered in memory rather than streamed to OS: the
+    // header has to come first, but MaxRelation is only known once every
+    // function in the module has been visited.
+    std::string Relationships;
+    raw_string_ostream RelOS(Relationships);
+
+    for (const Function &F : M) {
+      // Declarations produce no triplets and return 0, which std::max ignores.
+      unsigned FuncMaxRelation = generateTriplets(F, RelOS);
+      MaxRelation = std::max(MaxRelation, FuncMaxRelation);
+    }
+
+    RelOS.flush();
+
+    // Write metadata header followed by relationships
+    OS << "MAX_RELATION=" << MaxRelation << '\n';
+    OS << Relationships;
   }
 
   /// Generate triplets for a single function
-  void generateTriplets(const Function &F, raw_ostream &OS) const {
+  /// Returns the maximum relation ID used in this function
+  unsigned generateTriplets(const Function &F, raw_ostream &OS) const {
     if (F.isDeclaration())
-      return;
+      return 0;
+
+    unsigned MaxRelation = 1;
+    unsigned PrevOpcode = 0;
+    bool HasPrevOpcode = false;
+
+    for (const BasicBlock &BB : F) {
+      for (const auto &I : BB.instructionsWithoutDebug()) {
+        unsigned Opcode = Vocabulary::getNumericID(I.getOpcode());
+        unsigned TypeID = Vocabulary::getNumericID(I.getType()->getTypeID());
+
+        // Add "Next" relationship with previous instruction
+        if (HasPrevOpcode) {
+          OS << PrevOpcode << '\t' << Opcode << '\t' << NextRelation << '\n';
+          LLVM_DEBUG(dbgs()
+                     << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
+                     << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                     << "Next\n");
+        }
 
-    std::string LocalOutput;
-    raw_string_ostream LocalOS(LocalOutput);
+        // Add "Type" relationship
+        OS << Opcode << '\t' << TypeID << '\t' << TypeRelation << '\n';
+        LLVM_DEBUG(
+            dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                   << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
+                   << '\t' << "Type\n");
+
+        // Add "Arg" relationships
+        unsigned ArgIndex = 0;
+        for (const Use &U : I.operands()) {
+          unsigned OperandID = Vocabulary::getNumericID(U.get());
+          unsigned RelationID = ArgRelation + ArgIndex;
+          OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
+
+          LLVM_DEBUG({
+            StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
+                Vocabulary::getOperandKind(U.get()));
+            dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                   << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
+          });
+
+          ArgIndex++;
----------------
mtrofin wrote:

Nit: `++ArgIndex` here (prefer pre-increment over `ArgIndex++`).

https://github.com/llvm/llvm-project/pull/149214


More information about the llvm-commits mailing list