[llvm] [IR2Vec][llvm-ir2vec] Revamp triplet generation and add entity mapping mode (PR #149214)
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 29 07:05:42 PDT 2025
================
@@ -111,29 +128,101 @@ class IR2VecTool {
// option
MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
+ // This will throw an error if vocab is not found or invalid
Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
return Vocab->isValid();
}
- /// Generate triplets for the entire module
+ /// Generate triplets for the module
+ /// Output format: MAX_RELATION=N header followed by relationships
void generateTriplets(raw_ostream &OS) const {
- for (const Function &F : M)
- generateTriplets(F, OS);
+ unsigned MaxRelation = NextRelation; // Track maximum relation ID
+ std::string Relationships;
+ raw_string_ostream RelOS(Relationships);
+
+ for (const Function &F : M) {
+ unsigned FuncMaxRelation = generateTriplets(F, RelOS);
+ MaxRelation = std::max(MaxRelation, FuncMaxRelation);
+ }
+
+ RelOS.flush();
+
+ // Write metadata header followed by relationships
+ OS << "MAX_RELATION=" << MaxRelation << '\n';
+ OS << Relationships;
}
/// Generate triplets for a single function
- void generateTriplets(const Function &F, raw_ostream &OS) const {
+ /// Returns the maximum relation ID used in this function
+ unsigned generateTriplets(const Function &F, raw_ostream &OS) const {
if (F.isDeclaration())
- return;
+ return 0;
+
+ unsigned MaxRelation = 1;
+ unsigned PrevOpcode = 0;
+ bool HasPrevOpcode = false;
+
+ for (const BasicBlock &BB : F) {
+ for (const auto &I : BB.instructionsWithoutDebug()) {
+ unsigned Opcode = Vocabulary::getNumericID(I.getOpcode());
+ unsigned TypeID = Vocabulary::getNumericID(I.getType()->getTypeID());
+
+ // Add "Next" relationship with previous instruction
+ if (HasPrevOpcode) {
+ OS << PrevOpcode << '\t' << Opcode << '\t' << NextRelation << '\n';
+ LLVM_DEBUG(dbgs()
+ << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
+ << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+ << "Next\n");
+ }
- std::string LocalOutput;
- raw_string_ostream LocalOS(LocalOutput);
+ // Add "Type" relationship
+ OS << Opcode << '\t' << TypeID << '\t' << TypeRelation << '\n';
+ LLVM_DEBUG(
+ dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+ << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
+ << '\t' << "Type\n");
+
+ // Add "Arg" relationships
+ unsigned ArgIndex = 0;
+ for (const Use &U : I.operands()) {
+ unsigned OperandID = Vocabulary::getNumericID(U.get());
+ unsigned RelationID = ArgRelation + ArgIndex;
+ OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
+
+ LLVM_DEBUG({
+ StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
+ Vocabulary::getOperandKind(U.get()));
+ dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+ << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
+ });
+
+ ArgIndex++;
----------------
mtrofin wrote:
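Nit: prefer pre-increment here, i.e. `++ArgIndex` instead of `ArgIndex++` (see "Prefer Preincrement" in the LLVM Coding Standards).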
https://github.com/llvm/llvm-project/pull/149214
More information about the llvm-commits mailing list