[llvm] Refactoring llvm-ir2vec.cpp into separate components for handling IR2Vec, MIR2Vec calls (PR #167656)

Nishant Sachdeva via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 17 01:14:53 PST 2025


https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/167656

>From 3c432003afe70cb322d8dc4caf5ed54188da86db Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Mon, 10 Nov 2025 16:31:12 +0530
Subject: [PATCH] Work Commit - Refactoring llvm-ir2vec.cpp into separate
 components for handling IR2Vec, MIR2Vec calls

commit-id:2c35cc07
---
 llvm/tools/CMakeLists.txt                     |   1 +
 llvm/tools/llvm-ir2vec/CMakeLists.txt         |   8 +-
 .../llvm-ir2vec/include/EmbeddingCommon.h     |  81 ++
 llvm/tools/llvm-ir2vec/include/IR2VecTool.h   |  93 +++
 llvm/tools/llvm-ir2vec/include/MIR2VecTool.h  | 153 ++++
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp        | 765 ------------------
 llvm/tools/llvm-ir2vec/src/IR2VecTool.cpp     | 302 +++++++
 llvm/tools/llvm-ir2vec/src/MIR2VecTool.cpp    | 489 +++++++++++
 llvm/tools/llvm-ir2vec/src/llvm-ir2vec.cpp    | 359 ++++++++
 9 files changed, 1485 insertions(+), 766 deletions(-)
 create mode 100644 llvm/tools/llvm-ir2vec/include/EmbeddingCommon.h
 create mode 100644 llvm/tools/llvm-ir2vec/include/IR2VecTool.h
 create mode 100644 llvm/tools/llvm-ir2vec/include/MIR2VecTool.h
 delete mode 100644 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
 create mode 100644 llvm/tools/llvm-ir2vec/src/IR2VecTool.cpp
 create mode 100644 llvm/tools/llvm-ir2vec/src/MIR2VecTool.cpp
 create mode 100644 llvm/tools/llvm-ir2vec/src/llvm-ir2vec.cpp

diff --git a/llvm/tools/CMakeLists.txt b/llvm/tools/CMakeLists.txt
index 729797aa43f0b..8e41448be8e6e 100644
--- a/llvm/tools/CMakeLists.txt
+++ b/llvm/tools/CMakeLists.txt
@@ -39,6 +39,7 @@ add_llvm_tool_subdirectory(llvm-config)
 add_llvm_tool_subdirectory(llvm-ctxprof-util)
 add_llvm_tool_subdirectory(llvm-lto)
 add_llvm_tool_subdirectory(llvm-profdata)
+add_llvm_tool_subdirectory(llvm-ir2vec)
 
 # Projects supported via LLVM_EXTERNAL_*_SOURCE_DIR need to be explicitly
 # specified.
diff --git a/llvm/tools/llvm-ir2vec/CMakeLists.txt b/llvm/tools/llvm-ir2vec/CMakeLists.txt
index 2bb6686392907..6b8b2f7ec8112 100644
--- a/llvm/tools/llvm-ir2vec/CMakeLists.txt
+++ b/llvm/tools/llvm-ir2vec/CMakeLists.txt
@@ -2,6 +2,7 @@ set(LLVM_LINK_COMPONENTS
   # Core LLVM components for IR processing
   Analysis
   Core
+  Demangle
   IRReader
   Support
   
@@ -17,8 +18,13 @@ set(LLVM_LINK_COMPONENTS
   TargetParser
   )
 
+# Add include directory
+include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
+
 add_llvm_tool(llvm-ir2vec
-  llvm-ir2vec.cpp
+  src/llvm-ir2vec.cpp
+  src/IR2VecTool.cpp
+  src/MIR2VecTool.cpp
   
   DEPENDS
   intrinsics_gen
diff --git a/llvm/tools/llvm-ir2vec/include/EmbeddingCommon.h b/llvm/tools/llvm-ir2vec/include/EmbeddingCommon.h
new file mode 100644
index 0000000000000..51f77b60aa009
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/include/EmbeddingCommon.h
@@ -0,0 +1,81 @@
+#ifndef LLVM_TOOLS_LLVM_IR2VEC_EMBEDDINGCOMMON_H
+#define LLVM_TOOLS_LLVM_IR2VEC_EMBEDDINGCOMMON_H
+
+#include "llvm/Analysis/IR2Vec.h"
+#include "llvm/IR/Function.h"
+#include "llvm/Demangle/Demangle.h"
+#include "llvm/CodeGen/CommandFlags.h"
+#include <cxxabi.h>
+#include <cstdlib>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace llvm {
+
+/// Return the demangled form of \p F's symbol name, or the name unchanged
+/// when it does not carry a recognised mangling prefix.
+///
+/// llvm::demangle() is used rather than __cxxabiv1::__cxa_demangle() because
+/// the latter also decodes bare type encodings, so a plain C function named
+/// "f" would be reported as "float", and because it is only available where
+/// the host C++ runtime ships <cxxabi.h>.
+static inline std::string getDemagledName(const Function *F) {
+  // The std::string temporary outlives the call, so the view stays valid.
+  return llvm::demangle(F->getName().str());
+}
+
+/// Return the base name of \p F (for "a::b<int>(int)" this is "b"); when the
+/// symbol is not an Itanium-mangled name, fall back to getDemagledName(F).
+static inline std::string getActualName(const Function *F) {
+  const std::string FunctionName = F->getName().str();
+  llvm::ItaniumPartialDemangler Mangler;
+  // partialDemangle() returns true on failure.
+  if (Mangler.partialDemangle(FunctionName.c_str()))
+    return getDemagledName(F);
+  // Let the demangler allocate: a caller-supplied buffer may be realloc'ed and
+  // is aliased by the returned pointer, so it must not be freed before the
+  // result has been copied out.
+  char *BaseName = Mangler.getFunctionBaseName(nullptr, nullptr);
+  std::string Result = BaseName ? std::string(BaseName) : std::string();
+  std::free(BaseName);
+  return Result;
+}
+ 
+using Embedding = ir2vec::Embedding;
+
+/// Embedding generation level (unscoped: enumerators are visible in llvm::)
+enum EmbeddingLevel {
+  InstructionLevel, ///< Generate instruction-level embeddings
+  BasicBlockLevel,  ///< Generate basic block-level embeddings
+  FunctionLevel     ///< Generate function-level embeddings
+};
+
+/// Triplet for vocabulary training (IR2Vec/MIR2Vec); fields default to 0.
+struct Triplet {
+  unsigned Head = 0;     ///< Head entity ID
+  unsigned Tail = 0;     ///< Tail entity ID
+  unsigned Relation = 0; ///< Relation ID
+};
+
+/// Result of triplet generation; MaxRelation is 0 until a generator sets it.
+struct TripletResult {
+  unsigned MaxRelation = 0;      ///< Maximum relation ID that was reported
+  std::vector<Triplet> Triplets; ///< Generated triplets
+};
+
+/// Entity mappings: entity_id -> entity_name
+using EntityMap = std::vector<std::string>;
+
+/// Basic block embeddings: bb_name -> Embedding
+using BBVecList = std::vector<std::pair<std::string, Embedding>>;
+
+/// Instruction embeddings: instruction_string -> Embedding
+using InstVecList = std::vector<std::pair<std::string, Embedding>>;
+
+/// Function embeddings: demangled_name -> (actual_name, Embedding)
+using FuncVecMap = std::unordered_map<std::string, std::pair<std::string, Embedding>>;
+
+} // namespace llvm
+
+#endif
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/include/IR2VecTool.h b/llvm/tools/llvm-ir2vec/include/IR2VecTool.h
new file mode 100644
index 0000000000000..039f1cfe879f6
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/include/IR2VecTool.h
@@ -0,0 +1,93 @@
+//===- IR2VecTool.h - IR2Vec Tool Interface ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TOOLS_LLVM_IR2VEC_IR2VECTOOL_H
+#define LLVM_TOOLS_LLVM_IR2VEC_IR2VECTOOL_H
+
+#include "EmbeddingCommon.h"
+#include "llvm/Analysis/IR2Vec.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/PassManager.h"
+#include <string>
+#include <vector>
+
+namespace llvm {
+namespace ir2vec {
+/// Relation types for triplet generation
+enum RelationType {
+  TypeRelation = 0, ///< Instruction to type relationship
+  NextRelation = 1, ///< Sequential instruction relationship
+  ArgRelation = 2   ///< Instruction to operand relationship (ArgRelation + N)
+};
+
+/// Core IR2Vec tool: triplets, entity mappings and embeddings for one Module
+class IR2VecTool {
+private:
+  Module &M;
+  ModuleAnalysisManager MAM;
+  const Vocabulary *Vocab = nullptr;
+
+public:
+  explicit IR2VecTool(Module &M) : M(M), Vocab(nullptr) {}
+
+  /// Initialize vocabulary
+  bool initializeVocabulary();
+
+  /// Generate triplets for module and write to stream
+  void generateTriplets(raw_ostream &OS) const;
+
+  /// Generate embeddings for module and write to stream
+  void generateEmbeddings(raw_ostream &OS) const;
+
+  /// Generate embeddings for single function and write to stream
+  void generateEmbeddings(const Function &F, raw_ostream &OS) const;
+
+  /// Get entity mappings
+  static EntityMap getEntityMappings();
+
+  /// Generate entity mappings (static - no module needed)
+  static void generateEntityMappings(raw_ostream &OS);
+
+  // Data structure methods for Python bindings
+
+  /// Get triplets
+  TripletResult getTriplets() const;
+
+  /// Get triplets for function
+  TripletResult getTriplets(const Function &F) const;
+
+  /// Get single function embedding, shaped like one FuncVecMap entry
+  std::pair<std::string, std::pair<std::string, Embedding>>
+  getFunctionEmbedding(const Function &F) const;
+
+  /// Get function embeddings
+  FuncVecMap getFunctionEmbeddings() const;
+
+  /// Get BB embeddings for a specific function
+  BBVecList getBBEmbeddings(const Function &F) const;
+
+  /// Get BB embeddings
+  BBVecList getBBEmbeddings() const;
+
+  /// Get instruction embeddings for a specific function
+  InstVecList getInstEmbeddings(const Function &F) const;
+
+  /// Get instruction embeddings
+  InstVecList getInstEmbeddings() const;
+
+  /// True only when a vocabulary is set (non-null) and reports itself valid
+  bool isVocabularyValid() const { return Vocab && Vocab->isValid(); }
+
+  Module &getModule() { return M; }
+};
+
+} // namespace ir2vec
+} // namespace llvm
+
+#endif
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/include/MIR2VecTool.h b/llvm/tools/llvm-ir2vec/include/MIR2VecTool.h
new file mode 100644
index 0000000000000..17e9e8556b19b
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/include/MIR2VecTool.h
@@ -0,0 +1,153 @@
+//===- MIR2VecTool.h - MIR2Vec Tool Interface -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file declares the MIR2Vec tool interface for Machine IR embedding
+/// generation, triplet generation, and entity mapping operations.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TOOLS_LLVM_IR2VEC_MIR2VECTOOL_H
+#define LLVM_TOOLS_LLVM_IR2VEC_MIR2VECTOOL_H
+
+#include "EmbeddingCommon.h"
+#include "llvm/CodeGen/MIR2Vec.h"
+#include "llvm/CodeGen/MIRParser/MIRParser.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Target/TargetMachine.h"
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace llvm {
+namespace mir2vec {
+
+/// Relation types for MIR2Vec triplet generation
+enum MIRRelationType {
+  MIRNextRelation = 0, ///< Sequential instruction relationship
+  MIRArgRelation = 1   ///< Instruction to operand N: MIRArgRelation + N
+};
+
+/// Context for MIR parsing and processing
+/// Manages lifetime of MIR parser, target machine, and module info
+struct MIRContext {
+  LLVMContext Context;
+  std::unique_ptr<MIRParser> Parser;
+  std::unique_ptr<Module> M;
+  std::unique_ptr<TargetMachine> TM;
+  std::unique_ptr<MachineModuleInfo> MMI;
+
+  MIRContext() = default;
+
+  // Not copyable or movable: the LLVMContext member is neither.
+  MIRContext(const MIRContext &) = delete;
+  MIRContext &operator=(const MIRContext &) = delete;
+  MIRContext(MIRContext &&) = delete;
+  MIRContext &operator=(MIRContext &&) = delete;
+};
+
+/// Core MIR2Vec tool: triplets, entity mappings and embeddings for Machine IR
+class MIR2VecTool {
+private:
+  MachineModuleInfo &MMI;
+  std::unique_ptr<MIRVocabulary> Vocab;
+
+  /// Generate triplets for a single machine function (internal helper)
+  /// Returns the maximum relation ID used in this function
+  unsigned generateTripletsForMF(const MachineFunction &MF,
+                                 std::vector<Triplet> &Triplets) const;
+
+public:
+  explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
+
+  /// Initialize MIR2Vec vocabulary from file (for embeddings generation)
+  /// This loads a fully trained vocabulary with embeddings.
+  bool initializeVocabulary(const Module &M);
+
+  /// Initialize vocabulary with layout information only.
+  /// This creates a minimal vocabulary with correct layout but no actual
+  /// embeddings. Sufficient for generating training data and entity mappings.
+  ///
+  /// Note: Requires target-specific information from the first machine function
+  /// to determine the vocabulary layout (number of opcodes, register classes).
+  bool initializeVocabularyForLayout(const Module &M);
+
+  // ========================================================================
+  // Data structure methods (for Python bindings)
+  // ========================================================================
+
+  /// Get triplets for the entire module
+  TripletResult getTriplets(const Module &M) const;
+
+  /// Get triplets for a single machine function
+  TripletResult getTriplets(const MachineFunction &MF) const;
+
+  /// Get entity mappings
+  EntityMap getEntityMappings() const;
+
+  /// Get function-level embeddings for all functions in the module
+  FuncVecMap getFunctionEmbeddings(const Module &M) const;
+
+  /// Get function-level embedding for a single machine function
+  std::pair<std::string, Embedding>
+  getFunctionEmbedding(MachineFunction &MF) const;
+
+  /// Get basic block-level embeddings for a machine function
+  BBVecList getMBBEmbeddings(MachineFunction &MF) const;
+
+  /// Get instruction-level embeddings for a machine function
+  InstVecList getMInstEmbeddings(MachineFunction &MF) const;
+
+  // ========================================================================
+  // Stream output methods (for CLI tool)
+  // ========================================================================
+
+  /// Generate triplets for the module and write to stream
+  /// Output format: MAX_RELATION=N header followed by relationships
+  void generateTriplets(const Module &M, raw_ostream &OS) const;
+
+  /// Generate triplets for a single machine function and write to stream
+  void generateTriplets(const MachineFunction &MF, raw_ostream &OS) const;
+
+  /// Generate entity mappings and write to stream
+  void generateEntityMappings(raw_ostream &OS) const;
+
+  /// Generate embeddings for all machine functions in the module
+  void generateEmbeddings(const Module &M, raw_ostream &OS) const;
+
+  /// Generate embeddings for a specific machine function
+  void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const;
+
+  /// Get the vocabulary (for testing/debugging); null while none is set
+  const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
+
+  /// Get the MachineModuleInfo
+  MachineModuleInfo &getMachineModuleInfo() { return MMI; }
+};
+
+// ========================================================================
+// MIR Context Setup Functions
+// ========================================================================
+
+/// Initialize target backends (call once at startup)
+void initializeTargets();
+
+/// Setup MIR context from input file
+/// This parses the MIR file, creates target machine, and parses machine functions
+Error setupMIRContext(const std::string &InputFile, MIRContext &Ctx);
+
+} // namespace mir2vec
+} // namespace llvm
+
+#endif // LLVM_TOOLS_LLVM_IR2VEC_MIR2VECTOOL_H
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
deleted file mode 100644
index 7402782bfd404..0000000000000
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ /dev/null
@@ -1,765 +0,0 @@
-//===- llvm-ir2vec.cpp - IR2Vec/MIR2Vec Embedding Generation Tool --------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-///
-/// \file
-/// This file implements the IR2Vec and MIR2Vec embedding generation tool.
-///
-/// This tool supports two modes:
-/// - LLVM IR mode (-mode=llvm): Process LLVM IR
-/// - Machine IR mode (-mode=mir): Process Machine IR
-///
-/// Available subcommands:
-///
-/// 1. Triplet Generation (triplets):
-///    Generates numeric triplets (head, tail, relation) for vocabulary
-///    training. Output format: MAX_RELATION=N header followed by
-///    head\ttail\trelation lines. Relations: 0=Type, 1=Next, 2+=Arg0,Arg1,...
-///
-///    For LLVM IR:
-///      llvm-ir2vec triplets input.bc -o train2id.txt
-///
-///    For Machine IR:
-///      llvm-ir2vec triplets -mode=mir input.mir -o train2id.txt
-///
-/// 2. Entity Mappings (entities):
-///    Generates entity mappings for vocabulary training.
-///    Output format: <total_entities> header followed by entity\tid lines.
-///
-///    For LLVM IR:
-///      llvm-ir2vec entities input.bc -o entity2id.txt
-///
-///    For Machine IR:
-///      llvm-ir2vec entities -mode=mir input.mir -o entity2id.txt
-///
-/// 3. Embedding Generation (embeddings):
-///    Generates IR2Vec/MIR2Vec embeddings using a trained vocabulary.
-///
-///    For LLVM IR:
-///      llvm-ir2vec embeddings --ir2vec-vocab-path=vocab.json
-///        --ir2vec-kind=<kind> --level=<level> input.bc -o embeddings.txt
-///      Kind: --ir2vec-kind=symbolic (default), --ir2vec-kind=flow-aware
-///
-///    For Machine IR:
-///      llvm-ir2vec embeddings -mode=mir --mir2vec-vocab-path=vocab.json
-///        --level=<level> input.mir -o embeddings.txt
-///
-///    Levels: --level=inst (instructions), --level=bb (basic blocks),
-///    --level=func (functions) (See IR2Vec.cpp/MIR2Vec.cpp for more embedding
-///    generation options)
-///
-//===----------------------------------------------------------------------===//
-
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/Analysis/IR2Vec.h"
-#include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/Function.h"
-#include "llvm/IR/Instructions.h"
-#include "llvm/IR/LLVMContext.h"
-#include "llvm/IR/Module.h"
-#include "llvm/IR/PassInstrumentation.h"
-#include "llvm/IR/PassManager.h"
-#include "llvm/IR/Type.h"
-#include "llvm/IRReader/IRReader.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/Errc.h"
-#include "llvm/Support/InitLLVM.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/raw_ostream.h"
-
-#include "llvm/CodeGen/CommandFlags.h"
-#include "llvm/CodeGen/MIR2Vec.h"
-#include "llvm/CodeGen/MIRParser/MIRParser.h"
-#include "llvm/CodeGen/MachineFunction.h"
-#include "llvm/CodeGen/MachineModuleInfo.h"
-#include "llvm/CodeGen/TargetInstrInfo.h"
-#include "llvm/CodeGen/TargetRegisterInfo.h"
-#include "llvm/MC/TargetRegistry.h"
-#include "llvm/Support/TargetSelect.h"
-#include "llvm/Support/WithColor.h"
-#include "llvm/Target/TargetMachine.h"
-#include "llvm/TargetParser/Host.h"
-
-#define DEBUG_TYPE "ir2vec"
-
-namespace llvm {
-
-static const char *ToolName = "llvm-ir2vec";
-
-// Common option category for options shared between IR2Vec and MIR2Vec
-static cl::OptionCategory CommonCategory("Common Options",
-                                         "Options applicable to both IR2Vec "
-                                         "and MIR2Vec modes");
-
-enum IRKind {
-  LLVMIR = 0, ///< LLVM IR
-  MIR         ///< Machine IR
-};
-
-static cl::opt<IRKind>
-    IRMode("mode", cl::desc("Tool operation mode:"),
-           cl::values(clEnumValN(LLVMIR, "llvm", "Process LLVM IR"),
-                      clEnumValN(MIR, "mir", "Process Machine IR")),
-           cl::init(LLVMIR), cl::cat(CommonCategory));
-
-// Subcommands
-static cl::SubCommand
-    TripletsSubCmd("triplets", "Generate triplets for vocabulary training");
-static cl::SubCommand
-    EntitiesSubCmd("entities",
-                   "Generate entity mappings for vocabulary training");
-static cl::SubCommand
-    EmbeddingsSubCmd("embeddings",
-                     "Generate embeddings using trained vocabulary");
-
-// Common options
-static cl::opt<std::string> InputFilename(
-    cl::Positional, cl::desc("<input bitcode/MIR file or '-' for stdin>"),
-    cl::init("-"), cl::sub(TripletsSubCmd), cl::sub(EntitiesSubCmd),
-    cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
-
-static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
-                                           cl::value_desc("filename"),
-                                           cl::init("-"),
-                                           cl::cat(CommonCategory));
-
-// Embedding-specific options
-static cl::opt<std::string>
-    FunctionName("function", cl::desc("Process specific function only"),
-                 cl::value_desc("name"), cl::Optional, cl::init(""),
-                 cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
-
-enum EmbeddingLevel {
-  InstructionLevel, // Generate instruction-level embeddings
-  BasicBlockLevel,  // Generate basic block-level embeddings
-  FunctionLevel     // Generate function-level embeddings
-};
-
-static cl::opt<EmbeddingLevel>
-    Level("level", cl::desc("Embedding generation level:"),
-          cl::values(clEnumValN(InstructionLevel, "inst",
-                                "Generate instruction-level embeddings"),
-                     clEnumValN(BasicBlockLevel, "bb",
-                                "Generate basic block-level embeddings"),
-                     clEnumValN(FunctionLevel, "func",
-                                "Generate function-level embeddings")),
-          cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
-          cl::cat(CommonCategory));
-
-namespace ir2vec {
-
-/// Relation types for triplet generation
-enum RelationType {
-  TypeRelation = 0, ///< Instruction to type relationship
-  NextRelation = 1, ///< Sequential instruction relationship
-  ArgRelation = 2   ///< Instruction to operand relationship (ArgRelation + N)
-};
-
-/// Helper class for collecting IR triplets and generating embeddings
-class IR2VecTool {
-private:
-  Module &M;
-  ModuleAnalysisManager MAM;
-  const Vocabulary *Vocab = nullptr;
-
-public:
-  explicit IR2VecTool(Module &M) : M(M) {}
-
-  /// Initialize the IR2Vec vocabulary analysis
-  bool initializeVocabulary() {
-    // Register and run the IR2Vec vocabulary analysis
-    // The vocabulary file path is specified via --ir2vec-vocab-path global
-    // option
-    MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
-    MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
-    // This will throw an error if vocab is not found or invalid
-    Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
-    return Vocab->isValid();
-  }
-
-  /// Generate triplets for the module
-  /// Output format: MAX_RELATION=N header followed by relationships
-  void generateTriplets(raw_ostream &OS) const {
-    unsigned MaxRelation = NextRelation; // Track maximum relation ID
-    std::string Relationships;
-    raw_string_ostream RelOS(Relationships);
-
-    for (const Function &F : M) {
-      unsigned FuncMaxRelation = generateTriplets(F, RelOS);
-      MaxRelation = std::max(MaxRelation, FuncMaxRelation);
-    }
-
-    RelOS.flush();
-
-    // Write metadata header followed by relationships
-    OS << "MAX_RELATION=" << MaxRelation << '\n';
-    OS << Relationships;
-  }
-
-  /// Generate triplets for a single function
-  /// Returns the maximum relation ID used in this function
-  unsigned generateTriplets(const Function &F, raw_ostream &OS) const {
-    if (F.isDeclaration())
-      return 0;
-
-    unsigned MaxRelation = 1;
-    unsigned PrevOpcode = 0;
-    bool HasPrevOpcode = false;
-
-    for (const BasicBlock &BB : F) {
-      for (const auto &I : BB.instructionsWithoutDebug()) {
-        unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
-        unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
-
-        // Add "Next" relationship with previous instruction
-        if (HasPrevOpcode) {
-          OS << PrevOpcode << '\t' << Opcode << '\t' << NextRelation << '\n';
-          LLVM_DEBUG(dbgs()
-                     << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
-                     << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                     << "Next\n");
-        }
-
-        // Add "Type" relationship
-        OS << Opcode << '\t' << TypeID << '\t' << TypeRelation << '\n';
-        LLVM_DEBUG(
-            dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                   << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
-                   << '\t' << "Type\n");
-
-        // Add "Arg" relationships
-        unsigned ArgIndex = 0;
-        for (const Use &U : I.operands()) {
-          unsigned OperandID = Vocabulary::getIndex(*U.get());
-          unsigned RelationID = ArgRelation + ArgIndex;
-          OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
-
-          LLVM_DEBUG({
-            StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
-                Vocabulary::getOperandKind(U.get()));
-            dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
-                   << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
-          });
-
-          ++ArgIndex;
-        }
-        // Only update MaxRelation if there were operands
-        if (ArgIndex > 0) {
-          MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
-        }
-        PrevOpcode = Opcode;
-        HasPrevOpcode = true;
-      }
-    }
-
-    return MaxRelation;
-  }
-
-  /// Dump entity ID to string mappings
-  static void generateEntityMappings(raw_ostream &OS) {
-    auto EntityLen = Vocabulary::getCanonicalSize();
-    OS << EntityLen << "\n";
-    for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
-      OS << Vocabulary::getStringKey(EntityID) << '\t' << EntityID << '\n';
-  }
-
-  /// Generate embeddings for the entire module
-  void generateEmbeddings(raw_ostream &OS) const {
-    if (!Vocab->isValid()) {
-      WithColor::error(errs(), ToolName)
-          << "Vocabulary is not valid. IR2VecTool not initialized.\n";
-      return;
-    }
-
-    for (const Function &F : M)
-      generateEmbeddings(F, OS);
-  }
-
-  /// Generate embeddings for a single function
-  void generateEmbeddings(const Function &F, raw_ostream &OS) const {
-    if (F.isDeclaration()) {
-      OS << "Function " << F.getName() << " is a declaration, skipping.\n";
-      return;
-    }
-
-    // Create embedder for this function
-    assert(Vocab->isValid() && "Vocabulary is not valid");
-    auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
-    if (!Emb) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to create embedder for function " << F.getName() << "\n";
-      return;
-    }
-
-    OS << "Function: " << F.getName() << "\n";
-
-    // Generate embeddings based on the specified level
-    switch (Level) {
-    case FunctionLevel: {
-      Emb->getFunctionVector().print(OS);
-      break;
-    }
-    case BasicBlockLevel: {
-      for (const BasicBlock &BB : F) {
-        OS << BB.getName() << ":";
-        Emb->getBBVector(BB).print(OS);
-      }
-      break;
-    }
-    case InstructionLevel: {
-      for (const BasicBlock &BB : F) {
-        for (const Instruction &I : BB) {
-          I.print(OS);
-          Emb->getInstVector(I).print(OS);
-        }
-      }
-      break;
-    }
-    }
-  }
-};
-
-Error processModule(Module &M, raw_ostream &OS) {
-  IR2VecTool Tool(M);
-
-  if (EmbeddingsSubCmd) {
-    // Initialize vocabulary for embedding generation
-    // Note: Requires --ir2vec-vocab-path option to be set
-    auto VocabStatus = Tool.initializeVocabulary();
-    assert(VocabStatus && "Failed to initialize IR2Vec vocabulary");
-    (void)VocabStatus;
-
-    if (!FunctionName.empty()) {
-      // Process single function
-      if (const Function *F = M.getFunction(FunctionName))
-        Tool.generateEmbeddings(*F, OS);
-      else
-        return createStringError(errc::invalid_argument,
-                                 "Function '%s' not found",
-                                 FunctionName.c_str());
-    } else {
-      // Process all functions
-      Tool.generateEmbeddings(OS);
-    }
-  } else {
-    // Both triplets and entities use triplet generation
-    Tool.generateTriplets(OS);
-  }
-  return Error::success();
-}
-} // namespace ir2vec
-
-namespace mir2vec {
-
-/// Relation types for MIR2Vec triplet generation
-enum MIRRelationType {
-  MIRNextRelation = 0, ///< Sequential instruction relationship
-  MIRArgRelation = 1 ///< Instruction to operand relationship (ArgRelation + N)
-};
-
-/// Helper class for MIR2Vec embedding generation
-class MIR2VecTool {
-private:
-  MachineModuleInfo &MMI;
-  std::unique_ptr<MIRVocabulary> Vocab;
-
-public:
-  explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
-
-  /// Initialize MIR2Vec vocabulary from file (for embeddings generation)
-  bool initializeVocabulary(const Module &M) {
-    MIR2VecVocabProvider Provider(MMI);
-    auto VocabOrErr = Provider.getVocabulary(M);
-    if (!VocabOrErr) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to load MIR2Vec vocabulary - "
-          << toString(VocabOrErr.takeError()) << "\n";
-      return false;
-    }
-    Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
-    return true;
-  }
-
-  /// Initialize vocabulary with layout information only.
-  /// This creates a minimal vocabulary with correct layout but no actual
-  /// embeddings. Sufficient for generating training data and entity mappings.
-  ///
-  /// Note: Requires target-specific information from the first machine function
-  /// to determine the vocabulary layout (number of opcodes, register classes).
-  ///
-  /// FIXME: Use --target option to get target info directly, avoiding the need
-  /// to parse machine functions for pre-training operations.
-  bool initializeVocabularyForLayout(const Module &M) {
-    for (const Function &F : M) {
-      if (F.isDeclaration())
-        continue;
-
-      MachineFunction *MF = MMI.getMachineFunction(F);
-      if (!MF)
-        continue;
-
-      const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
-      const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
-      const MachineRegisterInfo &MRI = MF->getRegInfo();
-
-      auto VocabOrErr =
-          MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
-      if (!VocabOrErr) {
-        WithColor::error(errs(), ToolName)
-            << "Failed to create dummy vocabulary - "
-            << toString(VocabOrErr.takeError()) << "\n";
-        return false;
-      }
-      Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
-      return true;
-    }
-
-    WithColor::error(errs(), ToolName)
-        << "No machine functions found to initialize vocabulary\n";
-    return false;
-  }
-
-  /// Generate triplets for the module
-  /// Output format: MAX_RELATION=N header followed by relationships
-  void generateTriplets(const Module &M, raw_ostream &OS) const {
-    unsigned MaxRelation = MIRNextRelation; // Track maximum relation ID
-    std::string Relationships;
-    raw_string_ostream RelOS(Relationships);
-
-    for (const Function &F : M) {
-      if (F.isDeclaration())
-        continue;
-
-      MachineFunction *MF = MMI.getMachineFunction(F);
-      if (!MF) {
-        WithColor::warning(errs(), ToolName)
-            << "No MachineFunction for " << F.getName() << "\n";
-        continue;
-      }
-
-      unsigned FuncMaxRelation = generateTriplets(*MF, RelOS);
-      MaxRelation = std::max(MaxRelation, FuncMaxRelation);
-    }
-
-    RelOS.flush();
-
-    // Write metadata header followed by relationships
-    OS << "MAX_RELATION=" << MaxRelation << '\n';
-    OS << Relationships;
-  }
-
-  /// Generate triplets for a single machine function
-  /// Returns the maximum relation ID used in this function
-  unsigned generateTriplets(const MachineFunction &MF, raw_ostream &OS) const {
-    unsigned MaxRelation = MIRNextRelation;
-    unsigned PrevOpcode = 0;
-    bool HasPrevOpcode = false;
-
-    if (!Vocab) {
-      WithColor::error(errs(), ToolName)
-          << "MIR Vocabulary must be initialized for triplet generation.\n";
-      return MaxRelation;
-    }
-
-    for (const MachineBasicBlock &MBB : MF) {
-      for (const MachineInstr &MI : MBB) {
-        // Skip debug instructions
-        if (MI.isDebugInstr())
-          continue;
-
-        // Get opcode entity ID
-        unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
-
-        // Add "Next" relationship with previous instruction
-        if (HasPrevOpcode) {
-          OS << PrevOpcode << '\t' << OpcodeID << '\t' << MIRNextRelation
-             << '\n';
-          LLVM_DEBUG(dbgs()
-                     << Vocab->getStringKey(PrevOpcode) << '\t'
-                     << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
-        }
-
-        // Add "Arg" relationships for operands
-        unsigned ArgIndex = 0;
-        for (const MachineOperand &MO : MI.operands()) {
-          auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
-          unsigned RelationID = MIRArgRelation + ArgIndex;
-          OS << OpcodeID << '\t' << OperandID << '\t' << RelationID << '\n';
-          LLVM_DEBUG({
-            std::string OperandStr = Vocab->getStringKey(OperandID);
-            dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
-                   << '\t' << "Arg" << ArgIndex << '\n';
-          });
-
-          ++ArgIndex;
-        }
-
-        // Update MaxRelation if there were operands
-        if (ArgIndex > 0)
-          MaxRelation = std::max(MaxRelation, MIRArgRelation + ArgIndex - 1);
-
-        PrevOpcode = OpcodeID;
-        HasPrevOpcode = true;
-      }
-    }
-
-    return MaxRelation;
-  }
-
-  /// Generate entity mappings with vocabulary
-  void generateEntityMappings(raw_ostream &OS) const {
-    if (!Vocab) {
-      WithColor::error(errs(), ToolName)
-          << "Vocabulary must be initialized for entity mappings.\n";
-      return;
-    }
-
-    const unsigned EntityCount = Vocab->getCanonicalSize();
-    OS << EntityCount << "\n";
-    for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
-      OS << Vocab->getStringKey(EntityID) << '\t' << EntityID << '\n';
-  }
-
-  /// Generate embeddings for all machine functions in the module
-  void generateEmbeddings(const Module &M, raw_ostream &OS) const {
-    if (!Vocab) {
-      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
-      return;
-    }
-
-    for (const Function &F : M) {
-      if (F.isDeclaration())
-        continue;
-
-      MachineFunction *MF = MMI.getMachineFunction(F);
-      if (!MF) {
-        WithColor::warning(errs(), ToolName)
-            << "No MachineFunction for " << F.getName() << "\n";
-        continue;
-      }
-
-      generateEmbeddings(*MF, OS);
-    }
-  }
-
-  /// Generate embeddings for a specific machine function
-  void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const {
-    if (!Vocab) {
-      WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
-      return;
-    }
-
-    auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
-    if (!Emb) {
-      WithColor::error(errs(), ToolName)
-          << "Failed to create embedder for " << MF.getName() << "\n";
-      return;
-    }
-
-    OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
-
-    // Generate embeddings based on the specified level
-    switch (Level) {
-    case FunctionLevel: {
-      OS << "Function vector: ";
-      Emb->getMFunctionVector().print(OS);
-      break;
-    }
-    case BasicBlockLevel: {
-      OS << "Basic block vectors:\n";
-      for (const MachineBasicBlock &MBB : MF) {
-        OS << "MBB " << MBB.getName() << ": ";
-        Emb->getMBBVector(MBB).print(OS);
-      }
-      break;
-    }
-    case InstructionLevel: {
-      OS << "Instruction vectors:\n";
-      for (const MachineBasicBlock &MBB : MF) {
-        for (const MachineInstr &MI : MBB) {
-          OS << MI << " -> ";
-          Emb->getMInstVector(MI).print(OS);
-        }
-      }
-      break;
-    }
-    }
-  }
-
-  const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
-};
-
-} // namespace mir2vec
-
-} // namespace llvm
-
-int main(int argc, char **argv) {
-  using namespace llvm;
-  using namespace llvm::ir2vec;
-  using namespace llvm::mir2vec;
-
-  InitLLVM X(argc, argv);
-  // Show Common, IR2Vec and MIR2Vec option categories
-  cl::HideUnrelatedOptions(ArrayRef<const cl::OptionCategory *>{
-      &CommonCategory, &ir2vec::IR2VecCategory, &mir2vec::MIR2VecCategory});
-  cl::ParseCommandLineOptions(
-      argc, argv,
-      "IR2Vec/MIR2Vec - Embedding Generation Tool\n"
-      "Generates embeddings for a given LLVM IR or MIR and "
-      "supports triplet generation for vocabulary "
-      "training and embedding generation.\n\n"
-      "See https://llvm.org/docs/CommandGuide/llvm-ir2vec.html for more "
-      "information.\n");
-
-  std::error_code EC;
-  raw_fd_ostream OS(OutputFilename, EC);
-  if (EC) {
-    WithColor::error(errs(), ToolName)
-        << "opening output file: " << EC.message() << "\n";
-    return 1;
-  }
-
-  if (IRMode == IRKind::LLVMIR) {
-    if (EntitiesSubCmd) {
-      // Just dump entity mappings without processing any IR
-      IR2VecTool::generateEntityMappings(OS);
-      return 0;
-    }
-
-    // Parse the input LLVM IR file or stdin
-    SMDiagnostic Err;
-    LLVMContext Context;
-    std::unique_ptr<Module> M = parseIRFile(InputFilename, Err, Context);
-    if (!M) {
-      Err.print(ToolName, errs());
-      return 1;
-    }
-
-    if (Error Err = processModule(*M, OS)) {
-      handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) {
-        WithColor::error(errs(), ToolName) << EIB.message() << "\n";
-      });
-      return 1;
-    }
-    return 0;
-  }
-  if (IRMode == IRKind::MIR) {
-    // Initialize targets for Machine IR processing
-    InitializeAllTargets();
-    InitializeAllTargetMCs();
-    InitializeAllAsmParsers();
-    InitializeAllAsmPrinters();
-    static codegen::RegisterCodeGenFlags CGF;
-
-    // Parse MIR input file
-    SMDiagnostic Err;
-    LLVMContext Context;
-    std::unique_ptr<TargetMachine> TM;
-
-    auto MIR = createMIRParserFromFile(InputFilename, Err, Context);
-    if (!MIR) {
-      Err.print(ToolName, errs());
-      return 1;
-    }
-
-    auto SetDataLayout = [&](StringRef DataLayoutTargetTriple,
-                             StringRef OldDLStr) -> std::optional<std::string> {
-      std::string IRTargetTriple = DataLayoutTargetTriple.str();
-      Triple TheTriple = Triple(IRTargetTriple);
-      if (TheTriple.getTriple().empty())
-        TheTriple.setTriple(sys::getDefaultTargetTriple());
-      auto TMOrErr = codegen::createTargetMachineForTriple(TheTriple.str());
-      if (!TMOrErr) {
-        Err.print(ToolName, errs());
-        exit(1);
-      }
-      TM = std::move(*TMOrErr);
-      return TM->createDataLayout().getStringRepresentation();
-    };
-
-    std::unique_ptr<Module> M = MIR->parseIRModule(SetDataLayout);
-    if (!M) {
-      Err.print(ToolName, errs());
-      return 1;
-    }
-
-    // Parse machine functions
-    auto MMI = std::make_unique<MachineModuleInfo>(TM.get());
-    if (!MMI || MIR->parseMachineFunctions(*M, *MMI)) {
-      Err.print(ToolName, errs());
-      return 1;
-    }
-
-    // Create MIR2Vec tool
-    MIR2VecTool Tool(*MMI);
-
-    // Initialize vocabulary. For triplet/entity generation, only layout is
-    // needed For embedding generation, the full vocabulary is needed.
-    //
-    // Note: Unlike IR2Vec, MIR2Vec vocabulary initialization requires
-    // target-specific information for generating the vocabulary layout. So, we
-    // always initialize the vocabulary in this case.
-    if (TripletsSubCmd || EntitiesSubCmd) {
-      if (!Tool.initializeVocabularyForLayout(*M)) {
-        WithColor::error(errs(), ToolName)
-            << "Failed to initialize MIR2Vec vocabulary for layout.\n";
-        return 1;
-      }
-    } else {
-      if (!Tool.initializeVocabulary(*M)) {
-        WithColor::error(errs(), ToolName)
-            << "Failed to initialize MIR2Vec vocabulary.\n";
-        return 1;
-      }
-    }
-    assert(Tool.getVocabulary() &&
-           "MIR2Vec vocabulary should be initialized at this point");
-    LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"
-                      << "Vocabulary dimension: "
-                      << Tool.getVocabulary()->getDimension() << "\n"
-                      << "Vocabulary size: "
-                      << Tool.getVocabulary()->getCanonicalSize() << "\n");
-
-    // Handle subcommands
-    if (TripletsSubCmd) {
-      Tool.generateTriplets(*M, OS);
-    } else if (EntitiesSubCmd) {
-      Tool.generateEntityMappings(OS);
-    } else if (EmbeddingsSubCmd) {
-      if (!FunctionName.empty()) {
-        // Process single function
-        Function *F = M->getFunction(FunctionName);
-        if (!F) {
-          WithColor::error(errs(), ToolName)
-              << "Function '" << FunctionName << "' not found\n";
-          return 1;
-        }
-
-        MachineFunction *MF = MMI->getMachineFunction(*F);
-        if (!MF) {
-          WithColor::error(errs(), ToolName)
-              << "No MachineFunction for " << FunctionName << "\n";
-          return 1;
-        }
-
-        Tool.generateEmbeddings(*MF, OS);
-      } else {
-        // Process all functions
-        Tool.generateEmbeddings(*M, OS);
-      }
-    } else {
-      WithColor::error(errs(), ToolName)
-          << "Please specify a subcommand: triplets, entities, or embeddings\n";
-      return 1;
-    }
-
-    return 0;
-  }
-
-  return 0;
-}
diff --git a/llvm/tools/llvm-ir2vec/src/IR2VecTool.cpp b/llvm/tools/llvm-ir2vec/src/IR2VecTool.cpp
new file mode 100644
index 0000000000000..fdc130be333d2
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/src/IR2VecTool.cpp
@@ -0,0 +1,302 @@
+//===- IR2VecTool.cpp - IR2Vec Tool Implementation ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IR2VecTool.h"
+#include "llvm/Demangle/Demangle.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include <cxxabi.h>
+
+#define DEBUG_TYPE "ir2vec"
+
+using namespace llvm;
+using namespace llvm::ir2vec;
+
+namespace llvm {
+extern cl::opt<EmbeddingLevel> Level;
+} // namespace llvm
+
+namespace llvm {
+namespace ir2vec {
+
+/// Registers PassInstrumentationAnalysis and IR2VecVocabAnalysis with the
+/// module analysis manager, runs the vocabulary analysis on the module and
+/// caches a pointer to its result.
+///
+/// \returns true when the cached vocabulary reports itself as valid.
+bool IR2VecTool::initializeVocabulary() {
+  MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+  MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
+
+  Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
+  if (!Vocab)
+    return false;
+  return Vocab->isValid();
+}
+
+/// Collects the (head, tail, relation) triplets for every non-debug
+/// instruction of \p F, visiting the basic blocks in order.
+///
+/// Per instruction this records a "Next" edge from the previously visited
+/// opcode, a "Type" edge to the instruction's type and one "Arg<i>" edge per
+/// operand. A declaration yields no triplets and a MaxRelation of 0; otherwise
+/// MaxRelation is the larger of NextRelation and the highest Arg relation ID
+/// that was emitted.
+TripletResult IR2VecTool::getTriplets(const Function &F) const {
+  TripletResult Result;
+  Result.MaxRelation = 0;
+  if (F.isDeclaration())
+    return Result;
+
+  unsigned HighestRelation = NextRelation;
+  unsigned LastOpcode = 0;
+  bool SeenInstruction = false;
+
+  for (const BasicBlock &BB : F) {
+    for (const auto &I : BB.instructionsWithoutDebug()) {
+      const unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
+      const unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
+
+      // Chain this instruction to the one visited just before it. The chain is
+      // not reset when a new basic block starts.
+      if (SeenInstruction) {
+        Result.Triplets.push_back({LastOpcode, Opcode, NextRelation});
+        LLVM_DEBUG(dbgs()
+                   << Vocabulary::getVocabKeyForOpcode(LastOpcode + 1) << '\t'
+                   << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                   << "Next\n");
+      }
+
+      // Relate the opcode to the type of the value it produces.
+      Result.Triplets.push_back({Opcode, TypeID, TypeRelation});
+      LLVM_DEBUG(
+          dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                 << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
+                 << '\t' << "Type\n");
+
+      // One relation per operand position: ArgRelation, ArgRelation + 1, ...
+      unsigned NumArgs = 0;
+      for (const Use &U : I.operands()) {
+        const unsigned OperandID = Vocabulary::getIndex(*U.get());
+        const unsigned Relation = ArgRelation + NumArgs;
+        Result.Triplets.push_back({Opcode, OperandID, Relation});
+
+        LLVM_DEBUG({
+          StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
+              Vocabulary::getOperandKind(U.get()));
+          dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
+                 << OperandStr << '\t' << "Arg" << NumArgs << '\n';
+        });
+
+        ++NumArgs;
+      }
+
+      // The last operand carries the highest relation ID of this instruction.
+      if (NumArgs > 0)
+        HighestRelation = std::max(HighestRelation, ArgRelation + NumArgs - 1);
+
+      LastOpcode = Opcode;
+      SeenInstruction = true;
+    }
+  }
+
+  Result.MaxRelation = HighestRelation;
+  return Result;
+}
+
+/// Collects the triplets of every function in the module, in module order.
+///
+/// MaxRelation of the result is the maximum over all per-function results and
+/// is never smaller than NextRelation.
+TripletResult IR2VecTool::getTriplets() const {
+  TripletResult ModuleResult;
+  ModuleResult.MaxRelation = NextRelation;
+
+  for (const Function &F : M) {
+    const TripletResult PerFunction = getTriplets(F);
+    if (ModuleResult.MaxRelation < PerFunction.MaxRelation)
+      ModuleResult.MaxRelation = PerFunction.MaxRelation;
+    for (const auto &T : PerFunction.Triplets)
+      ModuleResult.Triplets.push_back(T);
+  }
+
+  return ModuleResult;
+}
+
+/// Writes the module's triplets to \p OS: a "MAX_RELATION=<n>" header line
+/// followed by one tab-separated "head tail relation" line per triplet.
+void IR2VecTool::generateTriplets(raw_ostream &OS) const {
+  const TripletResult Collected = getTriplets();
+
+  OS << "MAX_RELATION=" << Collected.MaxRelation << '\n';
+  for (const auto &Entry : Collected.Triplets)
+    OS << Entry.Head << '\t' << Entry.Tail << '\t' << Entry.Relation << '\n';
+}
+
+/// Returns the vocabulary's string keys, indexed by entity ID, for all IDs
+/// in [0, Vocabulary::getCanonicalSize()).
+EntityMap IR2VecTool::getEntityMappings() {
+  const auto NumEntities = Vocabulary::getCanonicalSize();
+
+  EntityMap Keys;
+  Keys.reserve(NumEntities);
+  for (unsigned ID = 0; ID < NumEntities; ++ID)
+    Keys.push_back(Vocabulary::getStringKey(ID).str());
+
+  return Keys;
+}
+
+/// Writes the entity mappings to \p OS: the number of entities on the first
+/// line, then one tab-separated "key id" line per entity in ID order.
+void IR2VecTool::generateEntityMappings(raw_ostream &OS) {
+  const EntityMap Keys = getEntityMappings();
+
+  OS << Keys.size() << "\n";
+  unsigned NextID = 0;
+  for (const auto &Key : Keys) {
+    OS << Key << '\t' << NextID << '\n';
+    ++NextID;
+  }
+}
+
+/// Computes the function-level embedding of \p F.
+///
+/// \returns {demangled name, {actual name, function vector}}. A
+/// value-initialized pair (empty names) is returned when \p F is a declaration
+/// or no embedder could be created; callers test the first string for
+/// emptiness to detect that case.
+std::pair<std::string, std::pair<std::string, Embedding>>
+IR2VecTool::getFunctionEmbedding(const Function &F) const {
+  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+  if (F.isDeclaration())
+    return {};
+
+  auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+  if (!Emb)
+    return {};
+
+  auto Vector = Emb->getFunctionVector();
+  auto Demangled = getDemagledName(&F);
+  auto Actual = getActualName(&F);
+
+  return {std::move(Demangled), {std::move(Actual), std::move(Vector)}};
+}
+
+/// Builds a map from demangled function name to its (actual name, function
+/// vector) entry for every defined function in the module.
+///
+/// Functions for which no embedding could be produced are left out. Entries
+/// are added with try_emplace, so an existing key is not overwritten.
+FuncVecMap IR2VecTool::getFunctionEmbeddings() const {
+  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+  FuncVecMap Embeddings;
+
+  for (const Function &F : M) {
+    if (F.isDeclaration())
+      continue;
+
+    auto Entry = getFunctionEmbedding(F);
+    // An empty demangled name is the failure marker of getFunctionEmbedding.
+    if (Entry.first.empty())
+      continue;
+
+    Embeddings.try_emplace(std::move(Entry.first),
+                           std::move(Entry.second.first),
+                           std::move(Entry.second.second));
+  }
+
+  return Embeddings;
+}
+
+/// Prints the embeddings of \p F to \p OS at the granularity selected by the
+/// global Level option.
+///
+/// A declaration only produces a "skipping" note. For a definition a
+/// "Function: <name>" header is printed first, followed by the function
+/// vector, the per-basic-block vectors or the per-instruction vectors.
+void IR2VecTool::generateEmbeddings(const Function &F, raw_ostream &OS) const {
+  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+  if (F.isDeclaration()) {
+    OS << "Function " << F.getName() << " is a declaration, skipping.\n";
+    return;
+  }
+
+  OS << "Function: " << F.getName() << "\n";
+
+  // Failures are reported on the output stream itself, not on errs().
+  const auto ReportFailure = [&OS, &F]() {
+    OS << "Error: Failed to create embedder for function " << F.getName()
+       << '\n';
+  };
+
+  // Shared printer for the basic-block and instruction lists. An empty list is
+  // reported as a failure.
+  const auto PrintEntries = [&](const auto &Entries) {
+    if (Entries.empty()) {
+      ReportFailure();
+      return;
+    }
+    for (const auto &[Label, Vector] : Entries) {
+      OS << Label;
+      Vector.print(OS);
+      OS << '\n';
+    }
+  };
+
+  switch (Level) {
+  case EmbeddingLevel::FunctionLevel: {
+    auto FuncEmb = getFunctionEmbedding(F);
+    if (FuncEmb.first.empty())
+      ReportFailure();
+    else
+      FuncEmb.second.second.print(OS);
+    break;
+  }
+  case EmbeddingLevel::BasicBlockLevel:
+    PrintEntries(getBBEmbeddings(F));
+    break;
+  case EmbeddingLevel::InstructionLevel:
+    PrintEntries(getInstEmbeddings(F));
+    break;
+  }
+}
+
+/// Prints the embeddings of every function in the module to \p OS by
+/// delegating to the per-function overload, in module order.
+void IR2VecTool::generateEmbeddings(raw_ostream &OS) const {
+  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+  for (const Function &Fn : M) {
+    generateEmbeddings(Fn, OS);
+  }
+}
+
+/// Returns one (basic block name, embedding) entry per basic block of \p F,
+/// in block order. The list is empty when \p F is a declaration or when no
+/// embedder could be created.
+BBVecList IR2VecTool::getBBEmbeddings(const Function &F) const {
+  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+  BBVecList Vectors;
+  if (F.isDeclaration())
+    return Vectors;
+
+  if (auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab)) {
+    for (const BasicBlock &BB : F)
+      Vectors.push_back({BB.getName().str(), Emb->getBBVector(BB)});
+  }
+
+  return Vectors;
+}
+
+/// Returns the basic block embeddings of every defined function in the
+/// module, concatenated in module order.
+BBVecList IR2VecTool::getBBEmbeddings() const {
+  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+  BBVecList Result;
+
+  for (const Function &F : M) {
+    if (F.isDeclaration())
+      continue;
+
+    // The per-function list is a temporary, so hand its entries over instead
+    // of copying every name and embedding a second time.
+    BBVecList FuncBBVecs = getBBEmbeddings(F);
+    for (auto &Entry : FuncBBVecs)
+      Result.push_back(std::move(Entry));
+  }
+
+  return Result;
+}
+
+/// Returns one (printed instruction, embedding) entry per instruction of
+/// \p F, in iteration order of instructions(F). The list is empty when \p F is
+/// a declaration or when no embedder could be created.
+InstVecList IR2VecTool::getInstEmbeddings(const Function &F) const {
+  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+  InstVecList Result;
+
+  if (F.isDeclaration())
+    return Result;
+
+  auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
+  if (!Emb)
+    return Result;
+
+  for (const Instruction &I : instructions(F)) {
+    // Render the instruction into a string that serves as the entry's label.
+    std::string InstStr = [&I]() {
+      std::string Text;
+      raw_string_ostream RSO(Text);
+      I.print(RSO);
+      RSO.flush();
+      return Text;
+    }();
+
+    // The label is not used again, so move it into the list instead of
+    // copying it.
+    Result.push_back({std::move(InstStr), Emb->getInstVector(I)});
+  }
+
+  return Result;
+}
+
+/// Returns the instruction embeddings of every defined function in the
+/// module, concatenated in module order.
+InstVecList IR2VecTool::getInstEmbeddings() const {
+  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+  InstVecList Result;
+
+  for (const Function &F : M) {
+    if (F.isDeclaration())
+      continue;
+
+    // The per-function list is a temporary, so hand its entries over instead
+    // of copying every printed instruction and embedding a second time.
+    InstVecList FuncInstVecs = getInstEmbeddings(F);
+    for (auto &Entry : FuncInstVecs)
+      Result.push_back(std::move(Entry));
+  }
+
+  return Result;
+}
+
+} // namespace ir2vec
+} // namespace llvm
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/src/MIR2VecTool.cpp b/llvm/tools/llvm-ir2vec/src/MIR2VecTool.cpp
new file mode 100644
index 0000000000000..9151af05bb59c
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/src/MIR2VecTool.cpp
@@ -0,0 +1,489 @@
+//===- MIR2VecTool.cpp - MIR2Vec Tool Implementation ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file implements the MIR2Vec tool for Machine IR embedding generation.
+///
+//===----------------------------------------------------------------------===//
+
+#include "MIR2VecTool.h"
+#include "llvm/CodeGen/CommandFlags.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/IR/Function.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/WithColor.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TargetParser/Host.h"
+#include "llvm/TargetParser/Triple.h"
+
+#define DEBUG_TYPE "mir2vec"
+
+namespace llvm {
+// Embedding granularity used by MIR2VecTool::generateEmbeddings below.
+// Note: Level is a global command-line option that is only declared here; its
+// definition is expected to live elsewhere in the tool (with CommonCategory).
+extern cl::opt<llvm::EmbeddingLevel> Level;
+
+namespace mir2vec {
+
+static const char *ToolName = "llvm-ir2vec";
+
+// ========================================================================
+// MIR Context Setup Functions
+// ========================================================================
+
+/// Registers all targets, target MCs, asm parsers and asm printers. The
+/// registration runs on the first call only; later calls return immediately.
+void initializeTargets() {
+  static bool Initialized = false;
+  if (Initialized)
+    return;
+
+  InitializeAllTargets();
+  InitializeAllTargetMCs();
+  InitializeAllAsmParsers();
+  InitializeAllAsmPrinters();
+  Initialized = true;
+}
+
+/// Creates a MIR parser for \p Filename and parses the embedded IR module
+/// into \p Ctx.
+///
+/// On success Ctx.Parser, Ctx.TM and Ctx.M are populated. The target machine
+/// is created inside the data-layout callback handed to the parser, from the
+/// triple the parser passes in; an empty triple falls back to the default
+/// target triple of the host.
+///
+/// \returns Error::success(), or a string error when the parser, the target
+/// machine or the IR module could not be created. Unlike a call to exit(),
+/// this leaves the decision on how to react to the caller.
+static Error parseMIRFile(const std::string &Filename, MIRContext &Ctx) {
+  SMDiagnostic Err;
+
+  Ctx.Parser = createMIRParserFromFile(Filename, Err, Ctx.Context);
+  if (!Ctx.Parser) {
+    std::string ErrMsg;
+    raw_string_ostream OS(ErrMsg);
+    Err.print(ToolName, OS);
+    return createStringError(errc::invalid_argument, OS.str());
+  }
+
+  // The callback can only return an optional data layout string, so a failure
+  // to create the target machine is recorded here and reported once
+  // parseIRModule has returned.
+  bool TMCreationFailed = false;
+  std::string TMErrMsg;
+
+  auto SetDataLayout = [&Ctx, &TMCreationFailed,
+                        &TMErrMsg](StringRef DataLayoutTargetTriple,
+                                   StringRef OldDLStr)
+      -> std::optional<std::string> {
+    std::string IRTargetTriple = DataLayoutTargetTriple.str();
+    Triple TheTriple = Triple(IRTargetTriple);
+    if (TheTriple.getTriple().empty())
+      TheTriple.setTriple(sys::getDefaultTargetTriple());
+
+    auto TMOrErr = codegen::createTargetMachineForTriple(TheTriple.str());
+    if (!TMOrErr) {
+      TMCreationFailed = true;
+      TMErrMsg = toString(TMOrErr.takeError());
+      // No layout override; the error is surfaced by the enclosing function.
+      return std::nullopt;
+    }
+
+    Ctx.TM = std::move(*TMOrErr);
+    return Ctx.TM->createDataLayout().getStringRepresentation();
+  };
+
+  Ctx.M = Ctx.Parser->parseIRModule(SetDataLayout);
+
+  if (TMCreationFailed)
+    return createStringError(errc::invalid_argument,
+                             "Failed to create target machine: " + TMErrMsg);
+
+  if (!Ctx.M) {
+    return createStringError(errc::invalid_argument,
+                            "Failed to parse IR module from MIR file");
+  }
+
+  return Error::success();
+}
+
+/// Creates the MachineModuleInfo for \p Ctx and parses the machine functions
+/// of the already parsed module into it.
+///
+/// Expects Ctx.Parser, Ctx.M and Ctx.TM to have been set up beforehand (see
+/// parseMIRFile); a string error is returned when any of them is missing.
+///
+/// \returns Error::success(), or a string error when the context is
+/// incomplete or the parser reports a failure.
+static Error parseMachineFunctions(MIRContext &Ctx) {
+  // All three are dereferenced below (the target machine by the
+  // MachineModuleInfo constructor, presumably), so reject an incomplete
+  // context up front instead of crashing.
+  if (!Ctx.Parser || !Ctx.M || !Ctx.TM) {
+    return createStringError(
+        errc::invalid_argument,
+        "MIR file must be parsed before its machine functions");
+  }
+
+  // std::make_unique never returns null, so no check is needed afterwards.
+  Ctx.MMI = std::make_unique<MachineModuleInfo>(Ctx.TM.get());
+
+  // parseMachineFunctions() returns true on failure.
+  if (Ctx.Parser->parseMachineFunctions(*Ctx.M, *Ctx.MMI)) {
+    return createStringError(errc::invalid_argument,
+                            "Failed to parse machine functions");
+  }
+
+  return Error::success();
+}
+
+/// Prepares \p Ctx for MIR processing of \p InputFile: registers the targets,
+/// parses the MIR file and then parses its machine functions.
+///
+/// \returns the first error encountered, or Error::success().
+Error setupMIRContext(const std::string &InputFile, MIRContext &Ctx) {
+  initializeTargets();
+
+  if (Error ParseErr = parseMIRFile(InputFile, Ctx))
+    return ParseErr;
+
+  return parseMachineFunctions(Ctx);
+}
+
+// ========================================================================
+// MIR2VecTool Implementation
+// ========================================================================
+
+/// Loads the MIR2Vec vocabulary for \p M through MIR2VecVocabProvider and
+/// stores it in Vocab.
+///
+/// \returns true on success; on failure the error is printed to errs() and
+/// false is returned.
+bool MIR2VecTool::initializeVocabulary(const Module &M) {
+  MIR2VecVocabProvider Provider(MMI);
+  auto VocabOrErr = Provider.getVocabulary(M);
+
+  if (VocabOrErr) {
+    Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
+    return true;
+  }
+
+  WithColor::error(errs(), ToolName)
+      << "Failed to load MIR2Vec vocabulary - "
+      << toString(VocabOrErr.takeError()) << "\n";
+  return false;
+}
+
+/// Initialize the vocabulary with layout information only.
+///
+/// The result is a dummy vocabulary that has the correct layout but carries no
+/// trained embeddings, which is enough for generating training data and
+/// entity mappings.
+///
+/// Note: The layout depends on target-specific information (instruction info,
+/// register info), which is taken from the first function of \p M that has a
+/// MachineFunction.
+///
+/// FIXME: Use --target option to get target info directly, avoiding the need
+/// to parse machine functions for pre-training operations.
+///
+/// \returns true on success; false (after printing an error to errs()) when
+/// no machine function exists or the dummy vocabulary cannot be created.
+bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
+  // Locate the first defined function that has a MachineFunction attached.
+  MachineFunction *FirstMF = nullptr;
+  for (const Function &F : M) {
+    if (F.isDeclaration())
+      continue;
+
+    FirstMF = MMI.getMachineFunction(F);
+    if (FirstMF)
+      break;
+  }
+
+  if (!FirstMF) {
+    WithColor::error(errs(), ToolName)
+        << "No machine functions found to initialize vocabulary\n";
+    return false;
+  }
+
+  const TargetInstrInfo &TII = *FirstMF->getSubtarget().getInstrInfo();
+  const TargetRegisterInfo &TRI = *FirstMF->getSubtarget().getRegisterInfo();
+  const MachineRegisterInfo &MRI = FirstMF->getRegInfo();
+
+  auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
+  if (VocabOrErr) {
+    Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
+    return true;
+  }
+
+  WithColor::error(errs(), ToolName)
+      << "Failed to create dummy vocabulary - "
+      << toString(VocabOrErr.takeError()) << "\n";
+  return false;
+}
+
+// ========================================================================
+// Data structure methods
+// ========================================================================
+
+/// Appends the triplets of \p MF to \p Triplets and returns the highest
+/// relation ID used.
+///
+/// Per non-debug instruction this records a "Next" edge from the previously
+/// visited opcode and one "Arg<i>" edge per operand. When the vocabulary is
+/// not initialized an error is printed, nothing is appended and
+/// MIRNextRelation is returned.
+unsigned MIR2VecTool::generateTripletsForMF(
+    const MachineFunction &MF, std::vector<Triplet> &Triplets) const {
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName)
+        << "MIR Vocabulary must be initialized for triplet generation.\n";
+    return MIRNextRelation;
+  }
+
+  unsigned HighestRelation = MIRNextRelation;
+  unsigned LastOpcode = 0;
+  bool SeenInstruction = false;
+
+  for (const MachineBasicBlock &MBB : MF) {
+    for (const MachineInstr &MI : MBB) {
+      // Debug instructions do not take part in the triplets.
+      if (MI.isDebugInstr())
+        continue;
+
+      // Entity ID of this instruction's opcode.
+      const unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
+
+      // "Next" edge from the previously visited instruction; the chain is not
+      // reset at basic block boundaries.
+      if (SeenInstruction) {
+        Triplets.push_back({LastOpcode, OpcodeID, MIRNextRelation});
+        LLVM_DEBUG(dbgs() << Vocab->getStringKey(LastOpcode) << '\t'
+                          << Vocab->getStringKey(OpcodeID) << '\t'
+                          << "Next\n");
+      }
+
+      // One "Arg" edge per operand position: MIRArgRelation,
+      // MIRArgRelation + 1, ...
+      unsigned NumArgs = 0;
+      for (const MachineOperand &MO : MI.operands()) {
+        auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
+        const unsigned Relation = MIRArgRelation + NumArgs;
+        Triplets.push_back({OpcodeID, OperandID, Relation});
+        LLVM_DEBUG({
+          std::string OperandStr = Vocab->getStringKey(OperandID);
+          dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr << '\t'
+                 << "Arg" << NumArgs << '\n';
+        });
+
+        ++NumArgs;
+      }
+
+      // The last operand carries the highest relation ID of this instruction.
+      if (NumArgs > 0)
+        HighestRelation =
+            std::max(HighestRelation, MIRArgRelation + NumArgs - 1);
+
+      LastOpcode = OpcodeID;
+      SeenInstruction = true;
+    }
+  }
+
+  return HighestRelation;
+}
+
+/// Collects the triplets of every defined function of \p M that has a
+/// MachineFunction. Functions without one are skipped with a warning.
+///
+/// MaxRelation of the result is the maximum over all processed functions and
+/// is never smaller than MIRNextRelation.
+TripletResult MIR2VecTool::getTriplets(const Module &M) const {
+  TripletResult Collected;
+  Collected.MaxRelation = MIRNextRelation;
+
+  for (const Function &F : M) {
+    if (F.isDeclaration())
+      continue;
+
+    if (MachineFunction *MF = MMI.getMachineFunction(F)) {
+      const unsigned FuncMax = generateTripletsForMF(*MF, Collected.Triplets);
+      Collected.MaxRelation = std::max(Collected.MaxRelation, FuncMax);
+      continue;
+    }
+
+    WithColor::warning(errs(), ToolName)
+        << "No MachineFunction for " << F.getName() << "\n";
+  }
+
+  return Collected;
+}
+
+/// Collects the triplets of the single machine function \p MF together with
+/// the highest relation ID used in it.
+TripletResult MIR2VecTool::getTriplets(const MachineFunction &MF) const {
+  TripletResult Collected;
+  const unsigned FuncMax = generateTripletsForMF(MF, Collected.Triplets);
+  Collected.MaxRelation = FuncMax;
+  return Collected;
+}
+
+/// Returns the vocabulary's string keys, indexed by entity ID, for all IDs
+/// in [0, Vocab->getCanonicalSize()). When the vocabulary is not initialized
+/// an error is printed and the result is empty.
+EntityMap MIR2VecTool::getEntityMappings() const {
+  EntityMap Keys;
+
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName)
+        << "Vocabulary must be initialized for entity mappings.\n";
+    return Keys;
+  }
+
+  const unsigned NumEntities = Vocab->getCanonicalSize();
+  Keys.reserve(NumEntities);
+  for (unsigned ID = 0; ID < NumEntities; ++ID)
+    Keys.push_back(Vocab->getStringKey(ID));
+
+  return Keys;
+}
+
+/// Builds a map from demangled function name to its (actual name, machine
+/// function vector) entry for every defined function of \p M.
+///
+/// Functions without a MachineFunction are skipped with a warning, functions
+/// whose embedder cannot be created with an error. Entries are stored with
+/// operator[], so a later function with the same demangled name replaces an
+/// earlier one. The map is empty when the vocabulary is not initialized.
+FuncVecMap MIR2VecTool::getFunctionEmbeddings(const Module &M) const {
+  FuncVecMap Embeddings;
+
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+    return Embeddings;
+  }
+
+  for (const Function &F : M) {
+    if (F.isDeclaration())
+      continue;
+
+    MachineFunction *MF = MMI.getMachineFunction(F);
+    if (MF == nullptr) {
+      WithColor::warning(errs(), ToolName)
+          << "No MachineFunction for " << F.getName() << "\n";
+      continue;
+    }
+
+    auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, *MF, *Vocab);
+    if (!Emb) {
+      WithColor::error(errs(), ToolName)
+          << "Failed to create embedder for " << MF->getName() << "\n";
+      continue;
+    }
+
+    auto Demangled = getDemagledName(&F);
+    auto Actual = getActualName(&F);
+    Embeddings[Demangled] = {Actual, Emb->getMFunctionVector()};
+  }
+
+  return Embeddings;
+}
+
+/// Computes the function-level embedding of \p MF.
+///
+/// \returns {machine function name, function vector}. When the vocabulary is
+/// not initialized or no embedder can be created, an error is printed and an
+/// empty name with a default-constructed Embedding is returned.
+std::pair<std::string, Embedding>
+MIR2VecTool::getFunctionEmbedding(MachineFunction &MF) const {
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+    return {"", Embedding()};
+  }
+
+  if (auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab))
+    return {MF.getName().str(), Emb->getMFunctionVector()};
+
+  WithColor::error(errs(), ToolName)
+      << "Failed to create embedder for " << MF.getName() << "\n";
+  return {"", Embedding()};
+}
+
+/// Returns one (block name, embedding) entry per machine basic block of
+/// \p MF, in block order. Blocks without a name are labelled "bb.<number>".
+/// The list is empty, and an error is printed, when the vocabulary is not
+/// initialized or no embedder can be created.
+BBVecList MIR2VecTool::getMBBEmbeddings(MachineFunction &MF) const {
+  BBVecList Vectors;
+
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+    return Vectors;
+  }
+
+  auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+  if (!Emb) {
+    WithColor::error(errs(), ToolName)
+        << "Failed to create embedder for " << MF.getName() << "\n";
+    return Vectors;
+  }
+
+  for (const MachineBasicBlock &MBB : MF) {
+    std::string Label = MBB.getName().str();
+    // Fall back to the block number when the block has no name.
+    if (Label.empty())
+      Label = "bb." + std::to_string(MBB.getNumber());
+    Vectors.push_back({Label, Emb->getMBBVector(MBB)});
+  }
+
+  return Vectors;
+}
+
+/// Returns one (printed instruction, embedding) entry per machine instruction
+/// of \p MF, block by block. The list is empty, and an error is printed, when
+/// the vocabulary is not initialized or no embedder can be created.
+InstVecList MIR2VecTool::getMInstEmbeddings(MachineFunction &MF) const {
+  InstVecList Vectors;
+
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+    return Vectors;
+  }
+
+  auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+  if (!Emb) {
+    WithColor::error(errs(), ToolName)
+        << "Failed to create embedder for " << MF.getName() << "\n";
+    return Vectors;
+  }
+
+  for (const MachineBasicBlock &MBB : MF) {
+    for (const MachineInstr &MI : MBB) {
+      // Render the instruction into a string that labels its embedding.
+      std::string Printed;
+      raw_string_ostream Stream(Printed);
+      Stream << MI;
+      Vectors.push_back({Stream.str(), Emb->getMInstVector(MI)});
+    }
+  }
+
+  return Vectors;
+}
+
+// ========================================================================
+// Stream output methods
+// ========================================================================
+
+/// Writes the triplets of module \p M to \p OS: a "MAX_RELATION=<n>" header
+/// line followed by one tab-separated "head tail relation" line per triplet.
+void MIR2VecTool::generateTriplets(const Module &M, raw_ostream &OS) const {
+  const TripletResult Collected = getTriplets(M);
+
+  // Metadata header first, then the relationships.
+  OS << "MAX_RELATION=" << Collected.MaxRelation << '\n';
+  for (const auto &T : Collected.Triplets)
+    OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+}
+
+/// Writes the triplets of the single machine function \p MF to \p OS, in the
+/// same format as the module overload: a "MAX_RELATION=<n>" header line
+/// followed by one tab-separated "head tail relation" line per triplet.
+void MIR2VecTool::generateTriplets(const MachineFunction &MF,
+                                   raw_ostream &OS) const {
+  const TripletResult Collected = getTriplets(MF);
+
+  // Metadata header first, then the relationships.
+  OS << "MAX_RELATION=" << Collected.MaxRelation << '\n';
+  for (const auto &T : Collected.Triplets)
+    OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
+}
+
+/// Writes the entity mappings to \p OS: the number of entities on the first
+/// line, then one tab-separated "key id" line per entity in ID order.
+void MIR2VecTool::generateEntityMappings(raw_ostream &OS) const {
+  const EntityMap Keys = getEntityMappings();
+
+  OS << Keys.size() << "\n";
+  size_t NextID = 0;
+  for (const auto &Key : Keys) {
+    OS << Key << '\t' << NextID << '\n';
+    ++NextID;
+  }
+}
+
+/// Prints the embeddings of every defined function of \p M that has a
+/// MachineFunction to \p OS, delegating to the per-machine-function overload.
+/// Functions without a MachineFunction are skipped with a warning; nothing is
+/// printed when the vocabulary is not initialized.
+void MIR2VecTool::generateEmbeddings(const Module &M, raw_ostream &OS) const {
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+    return;
+  }
+
+  for (const Function &F : M) {
+    if (F.isDeclaration())
+      continue;
+
+    if (MachineFunction *MF = MMI.getMachineFunction(F)) {
+      generateEmbeddings(*MF, OS);
+      continue;
+    }
+
+    WithColor::warning(errs(), ToolName)
+        << "No MachineFunction for " << F.getName() << "\n";
+  }
+}
+
+/// Prints the embeddings of \p MF to \p OS at the granularity selected by the
+/// global Level option: the function vector, one vector per machine basic
+/// block, or one vector per machine instruction.
+///
+/// Nothing is printed to \p OS when the vocabulary is not initialized or the
+/// embedder cannot be created; an error goes to errs() instead.
+void MIR2VecTool::generateEmbeddings(MachineFunction &MF,
+                                     raw_ostream &OS) const {
+  if (!Vocab) {
+    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
+    return;
+  }
+
+  auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+  if (!Emb) {
+    WithColor::error(errs(), ToolName)
+        << "Failed to create embedder for " << MF.getName() << "\n";
+    return;
+  }
+
+  OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+
+  switch (Level) {
+  case FunctionLevel:
+    OS << "Function vector: ";
+    Emb->getMFunctionVector().print(OS);
+    break;
+  case BasicBlockLevel:
+    OS << "Basic block vectors:\n";
+    for (const MachineBasicBlock &Block : MF) {
+      OS << "MBB " << Block.getName() << ": ";
+      Emb->getMBBVector(Block).print(OS);
+    }
+    break;
+  case InstructionLevel:
+    OS << "Instruction vectors:\n";
+    for (const MachineBasicBlock &Block : MF) {
+      for (const MachineInstr &Inst : Block) {
+        OS << Inst << " -> ";
+        Emb->getMInstVector(Inst).print(OS);
+      }
+    }
+    break;
+  }
+}
+
+} // namespace mir2vec
+} // namespace llvm
\ No newline at end of file
diff --git a/llvm/tools/llvm-ir2vec/src/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/src/llvm-ir2vec.cpp
new file mode 100644
index 0000000000000..652f3ff43abdd
--- /dev/null
+++ b/llvm/tools/llvm-ir2vec/src/llvm-ir2vec.cpp
@@ -0,0 +1,359 @@
+//===- llvm-ir2vec.cpp - IR2Vec/MIR2Vec Embedding Generation Tool --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file implements the IR2Vec and MIR2Vec embedding generation tool.
+///
+/// This tool supports two modes:
+/// - LLVM IR mode (-mode=llvm): Process LLVM IR
+/// - Machine IR mode (-mode=mir): Process Machine IR
+///
+/// Available subcommands:
+///
+/// 1. Triplet Generation (triplets):
+///    Generates numeric triplets (head, tail, relation) for vocabulary
+///    training. Output format: MAX_RELATION=N header followed by
+///    head\ttail\trelation lines. Relations: 0=Type, 1=Next, 2+=Arg0,Arg1,...
+///
+///    For LLVM IR:
+///      llvm-ir2vec triplets input.bc -o train2id.txt
+///
+///    For Machine IR:
+///      llvm-ir2vec triplets -mode=mir input.mir -o train2id.txt
+///
+/// 2. Entity Mappings (entities):
+///    Generates entity mappings for vocabulary training.
+///    Output format: <total_entities> header followed by entity\tid lines.
+///
+///    For LLVM IR:
+///      llvm-ir2vec entities input.bc -o entity2id.txt
+///
+///    For Machine IR:
+///      llvm-ir2vec entities -mode=mir input.mir -o entity2id.txt
+///
+/// 3. Embedding Generation (embeddings):
+///    Generates IR2Vec/MIR2Vec embeddings using a trained vocabulary.
+///
+///    For LLVM IR:
+///      llvm-ir2vec embeddings --ir2vec-vocab-path=vocab.json
+///        --ir2vec-kind=<kind> --level=<level> input.bc -o embeddings.txt
+///      Kind: --ir2vec-kind=symbolic (default), --ir2vec-kind=flow-aware
+///
+///    For Machine IR:
+///      llvm-ir2vec embeddings -mode=mir --mir2vec-vocab-path=vocab.json
+///        --level=<level> input.mir -o embeddings.txt
+///
+///    Levels: --level=inst (instructions), --level=bb (basic blocks),
+///    --level=func (functions) (See IR2Vec.cpp/MIR2Vec.cpp for more embedding
+///    generation options)
+///
+//===----------------------------------------------------------------------===//
+
+#include "IR2VecTool.h"
+#include "MIR2VecTool.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Analysis/IR2Vec.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "llvm/CodeGen/CommandFlags.h"
+#include "llvm/CodeGen/MIR2Vec.h"
+#include "llvm/CodeGen/MIRParser/MIRParser.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/WithColor.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/TargetParser/Host.h"
+
+#define DEBUG_TYPE "ir2vec"
+
+
+namespace llvm {
+
+static const char *ToolName = "llvm-ir2vec"; // Prefix for diagnostics.
+
+// Category under which every option declared in this file is listed.
+static cl::OptionCategory CommonCategory("Common Options",
+                                         "Options applicable to both IR2Vec "
+                                         "and MIR2Vec modes");
+
+enum IRKind { // Kind of input selected with -mode.
+  LLVMIR = 0, ///< LLVM IR
+  MIR         ///< Machine IR
+};
+
+static cl::opt<IRKind>
+    IRMode("mode", cl::desc("Tool operation mode:"),
+           cl::values(clEnumValN(LLVMIR, "llvm", "Process LLVM IR"),
+                      clEnumValN(MIR, "mir", "Process Machine IR")),
+           cl::init(LLVMIR), cl::cat(CommonCategory));
+
+// Subcommands; main() and the processModule() helpers test these objects to
+static cl::SubCommand // decide which action to run.
+    TripletsSubCmd("triplets", "Generate triplets for vocabulary training");
+static cl::SubCommand
+    EntitiesSubCmd("entities",
+                   "Generate entity mappings for vocabulary training");
+static cl::SubCommand
+    EmbeddingsSubCmd("embeddings",
+                     "Generate embeddings using trained vocabulary");
+
+// Input/output file names; both default to "-". Input is under all subcommands.
+static cl::opt<std::string> InputFilename(
+    cl::Positional, cl::desc("<input bitcode/MIR file or '-' for stdin>"),
+    cl::init("-"), cl::sub(TripletsSubCmd), cl::sub(EntitiesSubCmd),
+    cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
+
+static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
+                                           cl::value_desc("filename"),
+                                           cl::init("-"),
+                                           cl::cat(CommonCategory));
+
+// Options registered only under the embeddings subcommand (cl::sub).
+static cl::opt<std::string>
+    FunctionName("function", cl::desc("Process specific function only"),
+                 cl::value_desc("name"), cl::Optional, cl::init(""),
+                 cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
+
+cl::opt<EmbeddingLevel> // NOTE(review): non-static, presumably shared -- confirm
+    Level("level", cl::desc("Embedding generation level:"),
+          cl::values(clEnumValN(EmbeddingLevel::InstructionLevel, "inst",
+                                "Generate instruction-level embeddings"),
+                     clEnumValN(EmbeddingLevel::BasicBlockLevel, "bb",
+                                "Generate basic block-level embeddings"),
+                     clEnumValN(EmbeddingLevel::FunctionLevel, "func",
+                                "Generate function-level embeddings")),
+          cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
+          cl::cat(CommonCategory));
+
+namespace ir2vec {
+/// Runs the active IR2Vec subcommand on \p M and writes the output to \p OS.
+/// Fails if the vocabulary cannot be initialized or -function is not in \p M.
+Error processModule(Module &M, raw_ostream &OS) {
+  IR2VecTool Tool(M);
+  if (!EmbeddingsSubCmd) {
+    Tool.generateTriplets(OS);
+    return Error::success();
+  }
+  // A missing or bad --ir2vec-vocab-path is a user error: report it rather
+  // than assert, since asserts are compiled out of release builds.
+  if (!Tool.initializeVocabulary())
+    return createStringError(errc::invalid_argument,
+                             "Failed to initialize IR2Vec vocabulary");
+
+  if (FunctionName.empty()) {
+    Tool.generateEmbeddings(OS);
+    return Error::success();
+  }
+  const Function *F = M.getFunction(FunctionName);
+  if (!F)
+    return createStringError(errc::invalid_argument, "Function '%s' not found",
+                             FunctionName.c_str());
+  Tool.generateEmbeddings(*F, OS);
+  return Error::success();
+}
+} // namespace ir2vec
+
+namespace mir2vec {
+
+/// Writes MIR2Vec triplets for the module held by \p Ctx to \p OS.
+/// Fails if initializeVocabularyForLayout reports failure for that module.
+Error processModuleForTriplets(MIRContext &Ctx, raw_ostream &OS) {
+  MIR2VecTool Tool(*Ctx.MMI);
+  if (!Tool.initializeVocabularyForLayout(*Ctx.M))
+    return createStringError(
+        errc::invalid_argument,
+        "Failed to initialize MIR2Vec vocabulary for layout");
+
+  assert(Tool.getVocabulary() &&
+         "MIR2Vec vocabulary should be initialized at this point");
+
+  LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"
+                    << "Vocabulary dimension: "
+                    << Tool.getVocabulary()->getDimension() << "\n"
+                    << "Vocabulary size: "
+                    << Tool.getVocabulary()->getCanonicalSize() << "\n");
+
+  Tool.generateTriplets(*Ctx.M, OS);
+  return Error::success();
+}
+
+/// Writes the MIR2Vec entity mappings to \p OS using the context in \p Ctx.
+/// Fails if initializeVocabularyForLayout reports failure for the module.
+Error processModuleForEntities(MIRContext &Ctx, raw_ostream &OS) {
+  MIR2VecTool Tool(*Ctx.MMI);
+  if (!Tool.initializeVocabularyForLayout(*Ctx.M))
+    return createStringError(
+        errc::invalid_argument,
+        "Failed to initialize MIR2Vec vocabulary for layout");
+
+  assert(Tool.getVocabulary() &&
+         "MIR2Vec vocabulary should be initialized at this point");
+
+  LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"
+                    << "Vocabulary dimension: "
+                    << Tool.getVocabulary()->getDimension() << "\n"
+                    << "Vocabulary size: "
+                    << Tool.getVocabulary()->getCanonicalSize() << "\n");
+
+  Tool.generateEntityMappings(OS);
+  return Error::success();
+}
+
+/// Generates MIR2Vec embeddings for the module in \p Ctx and writes them to
+/// \p OS. When -function is given only that function is processed; otherwise
+/// the whole module is passed to the tool.
+/// Returns an error if the vocabulary cannot be initialized, the named
+/// function is not in the module, or it has no MachineFunction.
+Error processModuleForEmbeddings(MIRContext &Ctx, raw_ostream &OS) {
+  MIR2VecTool Tool(*Ctx.MMI);
+  if (!Tool.initializeVocabulary(*Ctx.M))
+    return createStringError(errc::invalid_argument,
+                             "Failed to initialize MIR2Vec vocabulary");
+
+  assert(Tool.getVocabulary() &&
+         "MIR2Vec vocabulary should be initialized at this point");
+
+  LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"
+                    << "Vocabulary dimension: "
+                    << Tool.getVocabulary()->getDimension() << "\n"
+                    << "Vocabulary size: "
+                    << Tool.getVocabulary()->getCanonicalSize() << "\n");
+
+  // No -function filter: hand the whole module to the tool.
+  if (FunctionName.empty()) {
+    Tool.generateEmbeddings(*Ctx.M, OS);
+    return Error::success();
+  }
+
+  // Look up the requested IR function by name.
+  Function *RequestedFn = Ctx.M->getFunction(FunctionName);
+  if (!RequestedFn)
+    return createStringError(errc::invalid_argument, "Function '%s' not found",
+                             FunctionName.c_str());
+
+  // The IR function exists; it still needs a MachineFunction in the MMI.
+  MachineFunction *RequestedMF = Ctx.MMI->getMachineFunction(*RequestedFn);
+  if (!RequestedMF)
+    return createStringError(errc::invalid_argument,
+                             "No MachineFunction for '%s'",
+                             FunctionName.c_str());
+
+  Tool.generateEmbeddings(*RequestedMF, OS);
+  return Error::success();
+}
+
+/// Entry point for MIR mode: fills a MIRContext from \p InputFile through
+/// setupMIRContext, then runs whichever subcommand was given, writing to
+/// \p OS. Fails if the setup fails or no subcommand was selected.
+Error processModule(const std::string &InputFile, raw_ostream &OS) {
+  MIRContext Ctx;
+  if (auto SetupErr = setupMIRContext(InputFile, Ctx))
+    return SetupErr;
+
+  // Subcommands are tested in this fixed order; the first active one runs.
+  if (TripletsSubCmd)
+    return processModuleForTriplets(Ctx, OS);
+  if (EntitiesSubCmd)
+    return processModuleForEntities(Ctx, OS);
+  if (EmbeddingsSubCmd)
+    return processModuleForEmbeddings(Ctx, OS);
+
+  return createStringError(
+      errc::invalid_argument,
+      "Please specify a subcommand: triplets, entities, or embeddings");
+}
+
+} // namespace mir2vec
+
+} // namespace llvm
+
+/// Tool entry point: parses options, opens the stream named by -o, and runs
+/// the LLVM IR or MIR pipeline per -mode. Returns 1 on a reported failure.
+int main(int argc, char **argv) {
+  using namespace llvm;
+  using namespace llvm::ir2vec;
+  using namespace llvm::mir2vec;
+
+  InitLLVM X(argc, argv);
+  // Show Common, IR2Vec and MIR2Vec option categories
+  cl::HideUnrelatedOptions(ArrayRef<const cl::OptionCategory *>{
+      &CommonCategory, &ir2vec::IR2VecCategory, &mir2vec::MIR2VecCategory});
+  cl::ParseCommandLineOptions(
+      argc, argv,
+      "IR2Vec/MIR2Vec - Embedding Generation Tool\n"
+      "Generates embeddings for a given LLVM IR or MIR and "
+      "supports triplet generation for vocabulary "
+      "training and embedding generation.\n\n"
+      "See https://llvm.org/docs/CommandGuide/llvm-ir2vec.html for more "
+      "information.\n");
+
+  std::error_code EC;
+  raw_fd_ostream OS(OutputFilename, EC);
+  if (EC) {
+    WithColor::error(errs(), ToolName)
+        << "opening output file: " << EC.message() << "\n";
+    return 1;
+  }
+
+  // Shared failure path: print each error with the tool prefix, exit code 1.
+  auto ReportFailure = [](Error E) {
+    handleAllErrors(std::move(E), [](const ErrorInfoBase &EIB) {
+      WithColor::error(errs(), ToolName) << EIB.message() << "\n";
+    });
+    return 1;
+  };
+
+  if (IRMode == IRKind::LLVMIR) {
+    // Entity mappings are dumped without reading any input IR.
+    if (EntitiesSubCmd) {
+      IR2VecTool::generateEntityMappings(OS);
+      return 0;
+    }
+
+    // Parse the input LLVM IR file or stdin
+    SMDiagnostic Diag;
+    LLVMContext Context;
+    std::unique_ptr<Module> M = parseIRFile(InputFilename, Diag, Context);
+    if (!M) {
+      Diag.print(ToolName, errs());
+      return 1;
+    }
+
+    if (Error E = processModule(*M, OS))
+      return ReportFailure(std::move(E));
+    return 0;
+  }
+  if (IRMode == IRKind::MIR) {
+    // Register codegen flags
+    static codegen::RegisterCodeGenFlags CGF;
+    if (Error E = mir2vec::processModule(InputFilename, OS))
+      return ReportFailure(std::move(E));
+    return 0;
+  }
+
+  return 0;
+}



More information about the llvm-commits mailing list