[llvm-branch-commits] [llvm] [IR2Vec] Add embeddings mode to llvm-ir2vec tool (PR #147844)
S. VenkataKeerthy via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jul 11 12:54:38 PDT 2025
https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/147844
>From bf757c03868bf5e85966440408e41f5343727384 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Wed, 9 Jul 2025 22:44:03 +0000
Subject: [PATCH] IR2Vec Tool Enhancements
---
llvm/test/tools/llvm-ir2vec/embeddings.ll | 73 +++++++++
llvm/test/tools/llvm-ir2vec/triplets.ll | 2 +-
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 185 ++++++++++++++++++++--
3 files changed, 248 insertions(+), 12 deletions(-)
create mode 100644 llvm/test/tools/llvm-ir2vec/embeddings.ll
diff --git a/llvm/test/tools/llvm-ir2vec/embeddings.ll b/llvm/test/tools/llvm-ir2vec/embeddings.ll
new file mode 100644
index 0000000000000..d5eed749193ac
--- /dev/null
+++ b/llvm/test/tools/llvm-ir2vec/embeddings.ll
@@ -0,0 +1,73 @@
+; RUN: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-DEFAULT
+; RUN: llvm-ir2vec --mode=embeddings --level=func --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL
+; RUN: llvm-ir2vec --mode=embeddings --level=func --function=abc --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL-ABC
+; RUN: not llvm-ir2vec --mode=embeddings --level=func --function=def --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=CHECK-FUNC-DEF
+; RUN: llvm-ir2vec --mode=embeddings --level=bb --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-BB-LEVEL
+; RUN: llvm-ir2vec --mode=embeddings --level=bb --function=abc_repeat --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-BB-LEVEL-ABC-REPEAT
+; RUN: llvm-ir2vec --mode=embeddings --level=inst --function=abc_repeat --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-INST-LEVEL-ABC-REPEAT
+
+define dso_local noundef float @abc(i32 noundef %a, float noundef %b) #0 {
+entry:
+ %a.addr = alloca i32, align 4
+ %b.addr = alloca float, align 4
+ store i32 %a, ptr %a.addr, align 4
+ store float %b, ptr %b.addr, align 4
+ %0 = load i32, ptr %a.addr, align 4
+ %1 = load i32, ptr %a.addr, align 4
+ %mul = mul nsw i32 %0, %1
+ %conv = sitofp i32 %mul to float
+ %2 = load float, ptr %b.addr, align 4
+ %add = fadd float %conv, %2
+ ret float %add
+}
+
+define dso_local noundef float @abc_repeat(i32 noundef %a, float noundef %b) #0 {
+entry:
+ %a.addr = alloca i32, align 4
+ %b.addr = alloca float, align 4
+ store i32 %a, ptr %a.addr, align 4
+ store float %b, ptr %b.addr, align 4
+ %0 = load i32, ptr %a.addr, align 4
+ %1 = load i32, ptr %a.addr, align 4
+ %mul = mul nsw i32 %0, %1
+ %conv = sitofp i32 %mul to float
+ %2 = load float, ptr %b.addr, align 4
+ %add = fadd float %conv, %2
+ ret float %add
+}
+
+; CHECK-DEFAULT: Function: abc
+; CHECK-DEFAULT-NEXT: [ 878.00 889.00 900.00 ]
+; CHECK-DEFAULT-NEXT: Function: abc_repeat
+; CHECK-DEFAULT-NEXT: [ 878.00 889.00 900.00 ]
+
+; CHECK-FUNC-LEVEL: Function: abc
+; CHECK-FUNC-LEVEL-NEXT: [ 878.00 889.00 900.00 ]
+; CHECK-FUNC-LEVEL-NEXT: Function: abc_repeat
+; CHECK-FUNC-LEVEL-NEXT: [ 878.00 889.00 900.00 ]
+
+; CHECK-FUNC-LEVEL-ABC: Function: abc
+; CHECK-FUNC-LEVEL-ABC-NEXT: [ 878.00 889.00 900.00 ]
+
+; CHECK-FUNC-DEF: Error: Function 'def' not found
+
+; CHECK-BB-LEVEL: Function: abc
+; CHECK-BB-LEVEL-NEXT: entry: [ 878.00 889.00 900.00 ]
+; CHECK-BB-LEVEL-NEXT: Function: abc_repeat
+; CHECK-BB-LEVEL-NEXT: entry: [ 878.00 889.00 900.00 ]
+
+; CHECK-BB-LEVEL-ABC-REPEAT: Function: abc_repeat
+; CHECK-BB-LEVEL-ABC-REPEAT-NEXT: entry: [ 878.00 889.00 900.00 ]
+
+; CHECK-INST-LEVEL-ABC-REPEAT: Function: abc_repeat
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %a.addr = alloca i32, align 4 [ 91.00 92.00 93.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %b.addr = alloca float, align 4 [ 91.00 92.00 93.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: store i32 %a, ptr %a.addr, align 4 [ 97.00 98.00 99.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: store float %b, ptr %b.addr, align 4 [ 97.00 98.00 99.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %0 = load i32, ptr %a.addr, align 4 [ 94.00 95.00 96.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %1 = load i32, ptr %a.addr, align 4 [ 94.00 95.00 96.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %mul = mul nsw i32 %0, %1 [ 49.00 50.00 51.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %conv = sitofp i32 %mul to float [ 130.00 131.00 132.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %2 = load float, ptr %b.addr, align 4 [ 94.00 95.00 96.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %add = fadd float %conv, %2 [ 40.00 41.00 42.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: ret float %add [ 1.00 2.00 3.00 ]
diff --git a/llvm/test/tools/llvm-ir2vec/triplets.ll b/llvm/test/tools/llvm-ir2vec/triplets.ll
index fa5aaa895406f..d1ef5b388e258 100644
--- a/llvm/test/tools/llvm-ir2vec/triplets.ll
+++ b/llvm/test/tools/llvm-ir2vec/triplets.ll
@@ -1,4 +1,4 @@
-; RUN: llvm-ir2vec %s | FileCheck %s -check-prefix=TRIPLETS
+; RUN: llvm-ir2vec --mode=triplets %s | FileCheck %s -check-prefix=TRIPLETS
define i32 @simple_add(i32 %a, i32 %b) {
entry:
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 35e1c995fa4cc..ab2b734da233e 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -9,12 +9,18 @@
/// \file
/// This file implements the IR2Vec embedding generation tool.
///
-/// Currently supports triplet generation for vocabulary training.
-/// Future updates will support embedding generation using trained vocabulary.
+/// This tool provides two main functionalities:
///
-/// Usage: llvm-ir2vec input.bc -o triplets.txt
+/// 1. Triplet Generation Mode (--mode=triplets):
+/// Generates triplets (opcode, type, operands) for vocabulary training.
+/// Usage: llvm-ir2vec --mode=triplets input.bc -o triplets.txt
///
-/// TODO: Add embedding generation mode with vocabulary support
+/// 2. Embedding Generation Mode (--mode=embeddings):
+/// Generates IR2Vec embeddings using a trained vocabulary.
+/// Usage: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=vocab.json
+/// --level=func input.bc -o embeddings.txt
+/// Levels: --level=inst (instructions), --level=bb (basic blocks),
+/// --level=func (functions). See IR2Vec.cpp for more embedding options.
///
//===----------------------------------------------------------------------===//
@@ -24,6 +30,8 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/CommandLine.h"
@@ -34,7 +42,7 @@
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
-using namespace ir2vec;
+using namespace llvm::ir2vec;
#define DEBUG_TYPE "ir2vec"
@@ -50,16 +58,63 @@ static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
cl::init("-"),
cl::cat(IR2VecToolCategory));
+enum ToolMode {
+ TripletMode, // Generate triplets for vocabulary training
+ EmbeddingMode // Generate embeddings using trained vocabulary
+};
+
+static cl::opt<ToolMode>
+ Mode("mode", cl::desc("Tool operation mode:"),
+ cl::values(clEnumValN(TripletMode, "triplets",
+ "Generate triplets for vocabulary training"),
+ clEnumValN(EmbeddingMode, "embeddings",
+ "Generate embeddings using trained vocabulary")),
+ cl::init(EmbeddingMode), cl::cat(IR2VecToolCategory));
+
+static cl::opt<std::string>
+ FunctionName("function", cl::desc("Process specific function only"),
+ cl::value_desc("name"), cl::Optional, cl::init(""),
+ cl::cat(IR2VecToolCategory));
+
+enum EmbeddingLevel {
+ InstructionLevel, // Generate instruction-level embeddings
+ BasicBlockLevel, // Generate basic block-level embeddings
+ FunctionLevel // Generate function-level embeddings
+};
+
+static cl::opt<EmbeddingLevel>
+ Level("level", cl::desc("Embedding generation level (for embedding mode):"),
+ cl::values(clEnumValN(InstructionLevel, "inst",
+ "Generate instruction-level embeddings"),
+ clEnumValN(BasicBlockLevel, "bb",
+ "Generate basic block-level embeddings"),
+ clEnumValN(FunctionLevel, "func",
+ "Generate function-level embeddings")),
+ cl::init(FunctionLevel), cl::cat(IR2VecToolCategory));
+
namespace {
-/// Helper class for collecting IR information and generating triplets
+/// Helper class for collecting IR information and generating embeddings
class IR2VecTool {
private:
Module &M;
+ ModuleAnalysisManager MAM;
+ const Vocabulary *Vocab = nullptr;
public:
explicit IR2VecTool(Module &M) : M(M) {}
+ /// Initialize the IR2Vec vocabulary analysis
+ bool initializeVocabulary() {
+ // Register and run the IR2Vec vocabulary analysis
+ // The vocabulary file path is specified via --ir2vec-vocab-path global
+ // option
+ MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+ MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
+ Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
+ return Vocab->isValid();
+ }
+
/// Generate triplets for the entire module
void generateTriplets(raw_ostream &OS) const {
for (const Function &F : M)
@@ -81,6 +136,70 @@ class IR2VecTool {
OS << LocalOutput;
}
+ /// Generate embeddings for every function in the module. Vocab stays null
+ /// until initializeVocabulary() runs, so check the pointer before using it.
+ void generateEmbeddings(raw_ostream &OS) const {
+ if (!Vocab || !Vocab->isValid()) {
+ OS << "Error: Vocabulary is not valid. IR2VecTool not initialized.\n";
+ return;
+ }
+ for (const Function &F : M)
+ generateEmbeddings(F, OS);
+ }
+
+ /// Generate embeddings for a single function
+ void generateEmbeddings(const Function &F, raw_ostream &OS) const {
+ if (F.isDeclaration()) {
+ OS << "Function " << F.getName() << " is a declaration, skipping.\n";
+ return;
+ }
+
+ // Create embedder for this function
+ assert(Vocab->isValid() && "Vocabulary is not valid");
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, F, *Vocab);
+ if (!Emb) {
+ OS << "Error: Failed to create embedder for function " << F.getName()
+ << "\n";
+ return;
+ }
+
+ OS << "Function: " << F.getName() << "\n";
+
+ // Generate embeddings based on the specified level
+ switch (Level) {
+ case FunctionLevel: {
+ Emb->getFunctionVector().print(OS);
+ break;
+ }
+ case BasicBlockLevel: {
+ const auto &BBVecMap = Emb->getBBVecMap();
+ for (const BasicBlock &BB : F) {
+ auto It = BBVecMap.find(&BB);
+ if (It != BBVecMap.end()) {
+ OS << BB.getName() << ":";
+ It->second.print(OS);
+ }
+ }
+ break;
+ }
+ case InstructionLevel: {
+ const auto &InstMap = Emb->getInstVecMap();
+ for (const BasicBlock &BB : F) {
+ for (const Instruction &I : BB) {
+ auto It = InstMap.find(&I);
+ if (It != InstMap.end()) {
+ I.print(OS);
+ It->second.print(OS);
+ }
+ }
+ }
+ break;
+ }
+ }
+
+ // OS << "\n";
+ }
+
private:
/// Process a single basic block for triplet generation
void traverseBasicBlock(const BasicBlock &BB, raw_string_ostream &OS) const {
@@ -105,8 +224,42 @@ class IR2VecTool {
Error processModule(Module &M, raw_ostream &OS) {
IR2VecTool Tool(M);
- Tool.generateTriplets(OS);
+ if (Mode == EmbeddingMode) {
+ // Initialize vocabulary for embedding generation
+ // Note: Requires --ir2vec-vocab-path option to be set
+ if (!Tool.initializeVocabulary())
+ return createStringError(
+ errc::invalid_argument,
+ "Failed to initialize IR2Vec vocabulary. "
+ "Make sure to specify --ir2vec-vocab-path for embedding mode.");
+
+ if (!FunctionName.empty()) {
+ // Process single function
+ if (const Function *F = M.getFunction(FunctionName))
+ Tool.generateEmbeddings(*F, OS);
+ else
+ return createStringError(errc::invalid_argument,
+ "Function '%s' not found",
+ FunctionName.c_str());
+ } else {
+ // Process all functions
+ Tool.generateEmbeddings(OS);
+ }
+ } else {
+ // Triplet generation mode - no vocabulary needed
+ if (!FunctionName.empty())
+ // Process single function
+ if (const Function *F = M.getFunction(FunctionName))
+ Tool.generateTriplets(*F, OS);
+ else
+ return createStringError(errc::invalid_argument,
+ "Function '%s' not found",
+ FunctionName.c_str());
+ else
+ // Process all functions
+ Tool.generateTriplets(OS);
+ }
return Error::success();
}
@@ -117,11 +270,21 @@ int main(int argc, char **argv) {
cl::HideUnrelatedOptions(IR2VecToolCategory);
cl::ParseCommandLineOptions(
argc, argv,
- "IR2Vec - Triplet Generation Tool\n"
- "Generates triplets for vocabulary training from LLVM IR.\n"
- "Future updates will support embedding generation.\n\n"
+ "IR2Vec - Embedding Generation Tool\n"
+ "Generates embeddings for a given LLVM IR and "
+ "supports triplet generation for vocabulary "
+ "training and embedding generation.\n\n"
"Usage:\n"
- " llvm-ir2vec input.bc -o triplets.txt\n");
+ " Triplet mode: llvm-ir2vec --mode=triplets input.bc\n"
+ " Embedding mode: llvm-ir2vec --mode=embeddings "
+ "--ir2vec-vocab-path=vocab.json --level=func input.bc\n"
+ " Levels: --level=inst (instructions), --level=bb (basic blocks), "
+ "--level=func (functions)\n");
+
+ // Validate command line options
+ if (Mode == TripletMode && Level != FunctionLevel) {
+ errs() << "Warning: --level option is ignored in triplet mode\n";
+ }
// Parse the input LLVM IR file
SMDiagnostic Err;
More information about the llvm-branch-commits
mailing list