[llvm] [MIR2Vec][llvm-ir2vec] Add MIR2Vec support to llvm-ir2vec tool (PR #164025)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 22 11:00:59 PDT 2025


https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/164025

>From fdfb77c27817dd4d6f6a28d0fead11556c083a1d Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Fri, 17 Oct 2025 22:36:25 +0000
Subject: [PATCH] [llvm-ir2vec] MIR2Vec support

---
 llvm/docs/CommandGuide/llvm-ir2vec.rst        |  95 ++++--
 llvm/include/llvm/CodeGen/MIR2Vec.h           |  57 +++-
 llvm/lib/CodeGen/MIR2Vec.cpp                  |  91 +++---
 .../tools/llvm-ir2vec/embeddings-symbolic.mir |  92 ++++++
 .../test/tools/llvm-ir2vec/error-handling.mir |  41 +++
 llvm/tools/llvm-ir2vec/CMakeLists.txt         |  14 +
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp        | 288 +++++++++++++++---
 7 files changed, 562 insertions(+), 116 deletions(-)
 create mode 100644 llvm/test/tools/llvm-ir2vec/embeddings-symbolic.mir
 create mode 100644 llvm/test/tools/llvm-ir2vec/error-handling.mir
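
Reviewer note on the new test (not part of the patch): the expected values in
embeddings-symbolic.mir look consistent with the function vector being the
element-wise sum of the instruction vectors. That relationship is inferred from
the CHECK lines, not stated in the patch, and with a single basic block the same
sum would also explain the MBB vector. The standalone C++ sketch below only
reproduces the arithmetic; `sumVectors`, `matchesPrinted`, and
`AddFunctionInsts` are hypothetical names, and nothing here calls into LLVM.

```cpp
// Sanity check for the CHECK lines in embeddings-symbolic.mir.
// Assumption (inferred from the expected values, not stated in the patch):
// with the dummy 3D vocabulary, a function vector is the element-wise sum
// of its instruction vectors.
#include <array>
#include <cmath>
#include <cstddef>

using Vec3 = std::array<double, 3>;

// Element-wise sum of a fixed list of 3D vectors.
template <std::size_t N>
Vec3 sumVectors(const std::array<Vec3, N> &Vecs) {
  Vec3 Sum{0.0, 0.0, 0.0};
  for (const Vec3 &V : Vecs)
    for (std::size_t I = 0; I < 3; ++I)
      Sum[I] += V[I];
  return Sum;
}

// Compare two vectors at the two-decimal precision shown in the CHECK lines.
bool matchesPrinted(const Vec3 &A, const Vec3 &B) {
  for (std::size_t I = 0; I < 3; ++I)
    if (std::fabs(A[I] - B[I]) > 0.005)
      return false;
  return true;
}

// Instruction vectors for add_function, copied from CHECK-INST-LEVEL.
const std::array<Vec3, 6> AddFunctionInsts = {{
    {6.00, 6.10, 6.20}, // %1:gr32 = COPY $esi
    {6.00, 6.10, 6.20}, // %0:gr32 = COPY $edi
    {3.70, 3.80, 3.90}, // %2:gr32 = nsw ADD32rr
    {3.70, 3.80, 3.90}, // %3:gr32 = ADD32rr
    {6.00, 6.10, 6.20}, // $eax = COPY %3
    {1.10, 1.20, 1.30}, // RET 0, $eax
}};
```

Comparing `sumVectors(AddFunctionInsts)` against the CHECK-FUNC-LEVEL vector
{26.50, 27.10, 27.70} with `matchesPrinted` is one way to confirm the test
expectations are internally consistent.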

diff --git a/llvm/docs/CommandGuide/llvm-ir2vec.rst b/llvm/docs/CommandGuide/llvm-ir2vec.rst
index fc590a6180316..55fe75d2084b1 100644
--- a/llvm/docs/CommandGuide/llvm-ir2vec.rst
+++ b/llvm/docs/CommandGuide/llvm-ir2vec.rst
@@ -1,5 +1,5 @@
-llvm-ir2vec - IR2Vec Embedding Generation Tool
-==============================================
+llvm-ir2vec - IR2Vec and MIR2Vec Embedding Generation Tool
+===========================================================
 
 .. program:: llvm-ir2vec
 
@@ -11,9 +11,9 @@ SYNOPSIS
 DESCRIPTION
 -----------
 
-:program:`llvm-ir2vec` is a standalone command-line tool for IR2Vec. It
-generates IR2Vec embeddings for LLVM IR and supports triplet generation 
-for vocabulary training. 
+:program:`llvm-ir2vec` is a standalone command-line tool for IR2Vec and MIR2Vec.
+It generates embeddings for both LLVM IR and Machine IR (MIR) and supports
+triplet generation for vocabulary training.
 
 The tool provides three main subcommands:
 
@@ -23,23 +23,33 @@ The tool provides three main subcommands:
 2. **entities**: Generates entity mapping files (entity2id.txt) for vocabulary 
    training.
 
-3. **embeddings**: Generates IR2Vec embeddings using a trained vocabulary
+3. **embeddings**: Generates IR2Vec or MIR2Vec embeddings using a trained vocabulary
    at different granularity levels (instruction, basic block, or function).
 
+The tool supports two operation modes:
+
+* **LLVM IR mode** (``--mode=llvm``): Process LLVM IR files (.ll/.bc) and generate
+  IR2Vec embeddings.
+* **Machine IR mode** (``--mode=mir``): Process Machine IR (.mir) files and generate
+  MIR2Vec embeddings.
+
 The tool is designed to facilitate machine learning applications that work with
-LLVM IR by converting the IR into numerical representations that can be used by
-ML models. The `triplets` subcommand generates numeric IDs directly instead of string 
-triplets, streamlining the training data preparation workflow.
+LLVM IR or Machine IR by converting them into numerical representations that can
+be used by ML models. The `triplets` subcommand generates numeric IDs directly
+instead of string triplets, streamlining the training data preparation workflow.
 
 .. note::
 
-   For information about using IR2Vec programmatically within LLVM passes and 
-   the C++ API, see the `IR2Vec Embeddings <https://llvm.org/docs/MLGO.html#ir2vec-embeddings>`_ 
+   For information about using IR2Vec and MIR2Vec programmatically within LLVM
+   passes and the C++ API, see the `IR2Vec Embeddings <https://llvm.org/docs/MLGO.html#ir2vec-embeddings>`_
    section in the MLGO documentation.
 
 OPERATION MODES
 ---------------
 
+The tool operates in two modes: **LLVM IR mode** and **Machine IR mode**. The mode
+is selected using the ``--mode`` option (default: ``llvm``).
+
 Triplet Generation and Entity Mapping Modes are used for preparing
 vocabulary and training data for knowledge graph embeddings. The Embedding Mode
 is used for generating embeddings from LLVM IR using a pre-trained vocabulary.
@@ -89,18 +99,31 @@ Embedding Generation
 ~~~~~~~~~~~~~~~~~~~~
 
 With the `embeddings` subcommand, :program:`llvm-ir2vec` uses a pre-trained vocabulary to
-generate numerical embeddings for LLVM IR at different levels of granularity.
+generate numerical embeddings for LLVM IR or Machine IR at different levels of granularity.
+
+Example Usage for LLVM IR:
+
+.. code-block:: bash
+
+   llvm-ir2vec embeddings --mode=llvm --ir2vec-vocab-path=vocab.json --ir2vec-kind=symbolic --level=func input.bc -o embeddings.txt
 
-Example Usage:
+Example Usage for Machine IR:
 
 .. code-block:: bash
 
-   llvm-ir2vec embeddings --ir2vec-vocab-path=vocab.json --ir2vec-kind=symbolic --level=func input.bc -o embeddings.txt
+   llvm-ir2vec embeddings --mode=mir --mir2vec-vocab-path=vocab.json --level=func input.mir -o embeddings.txt
 
 OPTIONS
 -------
 
-Global options:
+Common options (applicable to both LLVM IR and Machine IR modes):
+
+.. option:: --mode=<mode>
+
+   Specify the operation mode. Valid values are:
+
+   * ``llvm`` - Process LLVM IR files, textual or bitcode (default)
+   * ``mir`` - Process Machine IR (.mir) files
 
 .. option:: -o <filename>
 
@@ -116,8 +139,8 @@ Subcommand-specific options:
 
 .. option:: <input-file>
 
-   The input LLVM IR or bitcode file to process. This positional argument is
-   required for the `embeddings` subcommand.
+   The input LLVM IR/bitcode file (.ll/.bc) or Machine IR file (.mir) to process.
+   This positional argument is required for the `embeddings` subcommand.
 
 .. option:: --level=<level>
 
@@ -131,6 +154,8 @@ Subcommand-specific options:
 
    Process only the specified function instead of all functions in the module.
 
+**IR2Vec-specific options** (for ``--mode=llvm``):
+
 .. option:: --ir2vec-kind=<kind>
 
    Specify the kind of IR2Vec embeddings to generate. Valid values are:
@@ -143,8 +168,8 @@ Subcommand-specific options:
 
 .. option:: --ir2vec-vocab-path=<path>
 
-   Specify the path to the vocabulary file (required for embedding generation).
-   The vocabulary file should be in JSON format and contain the trained
+   Specify the path to the IR2Vec vocabulary file (required for LLVM IR embedding
+   generation). The vocabulary file should be in JSON format and contain the trained
    vocabulary for embedding generation. See `llvm/lib/Analysis/models`
    for pre-trained vocabulary files.
 
@@ -163,6 +188,35 @@ Subcommand-specific options:
    Specify the weight for argument embeddings (default: 0.2). This controls
    the relative importance of operand information in the final embedding.
 
+**MIR2Vec-specific options** (for ``--mode=mir``):
+
+.. option:: --mir2vec-vocab-path=<path>
+
+   Specify the path to the MIR2Vec vocabulary file (required for Machine IR
+   embedding generation). The vocabulary file should be in JSON format and
+   contain the trained vocabulary for embedding generation.
+
+.. option:: --mir2vec-kind=<kind>
+
+   Specify the kind of MIR2Vec embeddings to generate. Valid values are:
+
+   * ``symbolic`` - Generate symbolic embeddings (default)
+
+.. option:: --mir2vec-opc-weight=<weight>
+
+   Specify the weight for machine opcode embeddings (default: 1.0). This controls
+   the relative importance of machine instruction opcodes in the final embedding.
+
+.. option:: --mir2vec-common-operand-weight=<weight>
+
+   Specify the weight for common operand embeddings (default: 1.0). This controls
+   the relative importance of common operand types in the final embedding.
+
+.. option:: --mir2vec-reg-operand-weight=<weight>
+
+   Specify the weight for register operand embeddings (default: 1.0). This controls
+   the relative importance of register operands in the final embedding.
+
 
 **triplets** subcommand:
 
@@ -240,3 +294,6 @@ SEE ALSO
 
 For more information about the IR2Vec algorithm and approach, see:
 `IR2Vec: LLVM IR Based Scalable Program Embeddings <https://doi.org/10.1145/3418463>`_.
+
+For more information about the MIR2Vec algorithm and approach, see:
+`RL4ReAl: Reinforcement Learning for Register Allocation <https://doi.org/10.1145/3578360.3580273>`_.
diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
index 4bcbad7b53082..bc7d0e522a3ae 100644
--- a/llvm/include/llvm/CodeGen/MIR2Vec.h
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -7,9 +7,20 @@
 //===----------------------------------------------------------------------===//
 ///
 /// \file
-/// This file defines the MIR2Vec vocabulary
-/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIREmbedder
-/// interface for generating Machine IR embeddings, and related utilities.
+/// This file defines the MIR2Vec framework for generating Machine IR
+/// embeddings.
+///
+/// Architecture Overview:
+/// ----------------------
+/// 1. MIR2VecVocabProvider - Core vocabulary loading logic (no PM dependency)
+///    - Can be used standalone or wrapped by the pass manager
+///    - Requires MachineModuleInfo with parsed machine functions
+///
+/// 2. MIR2VecVocabLegacyAnalysis - Pass manager wrapper (ImmutablePass)
+///    - Integrated and used by llc -print-mir2vec
+///
+/// 3. MIREmbedder - Generates embeddings from vocabulary
+///    - SymbolicMIREmbedder: MIR2Vec embedding implementation
 ///
 /// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
 /// LLVM Machine IR as embeddings which can be used as input to machine learning
@@ -306,26 +317,58 @@ class SymbolicMIREmbedder : public MIREmbedder {
 
 } // namespace mir2vec
 
+/// MIR2Vec vocabulary provider used by pass managers and standalone tools.
+/// This class encapsulates the core vocabulary loading logic and can be used
+/// independently of the pass manager infrastructure. For pass-based usage,
+/// see MIR2VecVocabLegacyAnalysis.
+///
+/// Note: This provider pattern makes new PM migration straightforward when
+/// needed. A new PM analysis wrapper can be added that delegates to this
+/// provider, similar to how MIR2VecVocabLegacyAnalysis currently wraps it.
+class MIR2VecVocabProvider {
+  using VocabMap = std::map<std::string, mir2vec::Embedding>;
+
+public:
+  MIR2VecVocabProvider(const MachineModuleInfo &MMI) : MMI(MMI) {}
+
+  Expected<mir2vec::MIRVocabulary> getVocabulary(const Module &M);
+
+private:
+  Error readVocabulary(VocabMap &OpcVocab, VocabMap &CommonOperandVocab,
+                       VocabMap &PhyRegVocabMap, VocabMap &VirtRegVocabMap);
+  const MachineModuleInfo &MMI;
+};
+
 /// Pass to analyze and populate MIR2Vec vocabulary from a module
 class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
   using VocabVector = std::vector<mir2vec::Embedding>;
   using VocabMap = std::map<std::string, mir2vec::Embedding>;
-  std::optional<mir2vec::MIRVocabulary> Vocab;
 
   StringRef getPassName() const override;
-  Error readVocabulary(VocabMap &OpcVocab, VocabMap &CommonOperandVocab,
-                       VocabMap &PhyRegVocabMap, VocabMap &VirtRegVocabMap);
 
 protected:
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<MachineModuleInfoWrapperPass>();
     AU.setPreservesAll();
   }
+  std::unique_ptr<MIR2VecVocabProvider> Provider;
 
 public:
   static char ID;
   MIR2VecVocabLegacyAnalysis() : ImmutablePass(ID) {}
-  Expected<mir2vec::MIRVocabulary> getMIR2VecVocabulary(const Module &M);
+
+  Expected<mir2vec::MIRVocabulary> getMIR2VecVocabulary(const Module &M) {
+    MachineModuleInfo &MMI =
+        getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
+    if (!Provider)
+      Provider = std::make_unique<MIR2VecVocabProvider>(MMI);
+    return Provider->getVocabulary(M);
+  }
+
+  MIR2VecVocabProvider &getProvider() {
+    assert(Provider && "Provider not initialized");
+    return *Provider;
+  }
 };
 
 /// This pass prints the embeddings in the MIR2Vec vocabulary
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index 75ca06a79eaa5..00b37e7032f61 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -417,24 +417,39 @@ Expected<MIRVocabulary> MIRVocabulary::createDummyVocabForTest(
 }
 
 //===----------------------------------------------------------------------===//
-// MIR2VecVocabLegacyAnalysis Implementation
+// MIR2VecVocabProvider and MIR2VecVocabLegacyAnalysis
 //===----------------------------------------------------------------------===//
 
-char MIR2VecVocabLegacyAnalysis::ID = 0;
-INITIALIZE_PASS_BEGIN(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis",
-                      "MIR2Vec Vocabulary Analysis", false, true)
-INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
-INITIALIZE_PASS_END(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis",
-                    "MIR2Vec Vocabulary Analysis", false, true)
+Expected<mir2vec::MIRVocabulary>
+MIR2VecVocabProvider::getVocabulary(const Module &M) {
+  VocabMap OpcVocab, CommonOperandVocab, PhyRegVocabMap, VirtRegVocabMap;
 
-StringRef MIR2VecVocabLegacyAnalysis::getPassName() const {
-  return "MIR2Vec Vocabulary Analysis";
+  if (Error Err = readVocabulary(OpcVocab, CommonOperandVocab, PhyRegVocabMap,
+                                 VirtRegVocabMap))
+    return std::move(Err);
+
+  for (const auto &F : M) {
+    if (F.isDeclaration())
+      continue;
+
+    if (auto *MF = MMI.getMachineFunction(F)) {
+      auto &Subtarget = MF->getSubtarget();
+      if (const auto *TII = Subtarget.getInstrInfo())
+        if (const auto *TRI = Subtarget.getRegisterInfo())
+          return mir2vec::MIRVocabulary::create(
+              std::move(OpcVocab), std::move(CommonOperandVocab),
+              std::move(PhyRegVocabMap), std::move(VirtRegVocabMap), *TII, *TRI,
+              MF->getRegInfo());
+    }
+  }
+  return createStringError(errc::invalid_argument,
+                           "No machine functions found in module");
 }
 
-Error MIR2VecVocabLegacyAnalysis::readVocabulary(VocabMap &OpcodeVocab,
-                                                 VocabMap &CommonOperandVocab,
-                                                 VocabMap &PhyRegVocabMap,
-                                                 VocabMap &VirtRegVocabMap) {
+Error MIR2VecVocabProvider::readVocabulary(VocabMap &OpcodeVocab,
+                                           VocabMap &CommonOperandVocab,
+                                           VocabMap &PhyRegVocabMap,
+                                           VocabMap &VirtRegVocabMap) {
   if (VocabFile.empty())
     return createStringError(
         errc::invalid_argument,
@@ -483,49 +498,15 @@ Error MIR2VecVocabLegacyAnalysis::readVocabulary(VocabMap &OpcodeVocab,
   return Error::success();
 }
 
-Expected<mir2vec::MIRVocabulary>
-MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
-  if (Vocab.has_value())
-    return std::move(Vocab.value());
-
-  VocabMap OpcMap, CommonOperandMap, PhyRegMap, VirtRegMap;
-  if (Error Err =
-          readVocabulary(OpcMap, CommonOperandMap, PhyRegMap, VirtRegMap))
-    return std::move(Err);
-
-  // Get machine module info to access machine functions and target info
-  MachineModuleInfo &MMI = getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
-
-  // Find first available machine function to get target instruction info
-  for (const auto &F : M) {
-    if (F.isDeclaration())
-      continue;
-
-    if (auto *MF = MMI.getMachineFunction(F)) {
-      auto &Subtarget = MF->getSubtarget();
-      const TargetInstrInfo *TII = Subtarget.getInstrInfo();
-      if (!TII) {
-        return createStringError(errc::invalid_argument,
-                                 "No TargetInstrInfo available; cannot create "
-                                 "MIR2Vec vocabulary");
-      }
-
-      const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
-      if (!TRI) {
-        return createStringError(errc::invalid_argument,
-                                 "No TargetRegisterInfo available; cannot "
-                                 "create MIR2Vec vocabulary");
-      }
-
-      return mir2vec::MIRVocabulary::create(
-          std::move(OpcMap), std::move(CommonOperandMap), std::move(PhyRegMap),
-          std::move(VirtRegMap), *TII, *TRI, MF->getRegInfo());
-    }
-  }
+char MIR2VecVocabLegacyAnalysis::ID = 0;
+INITIALIZE_PASS_BEGIN(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis",
+                      "MIR2Vec Vocabulary Analysis", false, true)
+INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
+INITIALIZE_PASS_END(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis",
+                    "MIR2Vec Vocabulary Analysis", false, true)
 
-  // No machine functions available - return error
-  return createStringError(errc::invalid_argument,
-                           "No machine functions found in module");
+StringRef MIR2VecVocabLegacyAnalysis::getPassName() const {
+  return "MIR2Vec Vocabulary Analysis";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/tools/llvm-ir2vec/embeddings-symbolic.mir b/llvm/test/tools/llvm-ir2vec/embeddings-symbolic.mir
new file mode 100644
index 0000000000000..e5f78bfd2090e
--- /dev/null
+++ b/llvm/test/tools/llvm-ir2vec/embeddings-symbolic.mir
@@ -0,0 +1,92 @@
+# REQUIRES: x86_64-linux
+# RUN: llvm-ir2vec embeddings --mode=mir --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json %s | FileCheck %s -check-prefix=CHECK-DEFAULT
+# RUN: llvm-ir2vec embeddings --mode=mir --level=func --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL
+# RUN: llvm-ir2vec embeddings --mode=mir --level=func --function=add_function --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL-ADD
+# RUN: not llvm-ir2vec embeddings --mode=mir --level=func --function=missing_function --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json %s 2>&1 | FileCheck %s -check-prefix=CHECK-FUNC-MISSING
+# RUN: llvm-ir2vec embeddings --mode=mir --level=bb --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json %s | FileCheck %s -check-prefix=CHECK-BB-LEVEL
+# RUN: llvm-ir2vec embeddings --mode=mir --level=inst --function=add_function --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json %s | FileCheck %s -check-prefix=CHECK-INST-LEVEL
+
+--- |
+  target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+  target triple = "x86_64-unknown-linux-gnu"
+  
+  define dso_local noundef i32 @add_function(i32 noundef %a, i32 noundef %b) {
+  entry:
+    %sum = add nsw i32 %a, %b
+    %result = mul nsw i32 %sum, 2
+    ret i32 %result
+  }
+  
+  define dso_local void @simple_function() {
+  entry:
+    ret void
+  }
+...
+---
+name:            add_function
+alignment:       16
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: gr32 }
+  - { id: 1, class: gr32 }
+  - { id: 2, class: gr32 }
+  - { id: 3, class: gr32 }
+liveins:
+  - { reg: '$edi', virtual-reg: '%0' }
+  - { reg: '$esi', virtual-reg: '%1' }
+body:             |
+  bb.0.entry:
+    liveins: $edi, $esi
+  
+    %1:gr32 = COPY $esi
+    %0:gr32 = COPY $edi
+    %2:gr32 = nsw ADD32rr %0, %1, implicit-def dead $eflags
+    %3:gr32 = ADD32rr %2, %2, implicit-def dead $eflags
+    $eax = COPY %3
+    RET 0, $eax
+
+---
+name:            simple_function
+alignment:       16
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    RET 0
+
+# CHECK-DEFAULT: MIR2Vec embeddings for machine function add_function:
+# CHECK-DEFAULT-NEXT: Function vector:  [ 26.50  27.10  27.70 ]
+# CHECK-DEFAULT: MIR2Vec embeddings for machine function simple_function:
+# CHECK-DEFAULT-NEXT: Function vector:  [ 1.10  1.20  1.30 ]
+
+# CHECK-FUNC-LEVEL: MIR2Vec embeddings for machine function add_function:
+# CHECK-FUNC-LEVEL-NEXT: Function vector:  [ 26.50  27.10  27.70 ]
+# CHECK-FUNC-LEVEL: MIR2Vec embeddings for machine function simple_function:
+# CHECK-FUNC-LEVEL-NEXT: Function vector:  [ 1.10  1.20  1.30 ]
+
+# CHECK-FUNC-LEVEL-ADD: MIR2Vec embeddings for machine function add_function:
+# CHECK-FUNC-LEVEL-ADD-NEXT: Function vector:  [ 26.50  27.10  27.70 ]
+# CHECK-FUNC-LEVEL-ADD-NOT: simple_function
+
+# CHECK-FUNC-MISSING: Error: Function 'missing_function' not found
+
+# CHECK-BB-LEVEL: MIR2Vec embeddings for machine function add_function:
+# CHECK-BB-LEVEL-NEXT: Basic block vectors:
+# CHECK-BB-LEVEL-NEXT: MBB entry:  [ 26.50  27.10  27.70 ]
+# CHECK-BB-LEVEL: MIR2Vec embeddings for machine function simple_function:
+# CHECK-BB-LEVEL-NEXT: Basic block vectors:
+# CHECK-BB-LEVEL-NEXT: MBB entry:  [ 1.10  1.20  1.30 ]
+
+# CHECK-INST-LEVEL: MIR2Vec embeddings for machine function add_function:
+# CHECK-INST-LEVEL-NEXT: Instruction vectors:
+# CHECK-INST-LEVEL: %1:gr32 = COPY $esi
+# CHECK-INST-LEVEL-NEXT:  ->  [ 6.00  6.10  6.20 ]
+# CHECK-INST-LEVEL-NEXT: %0:gr32 = COPY $edi
+# CHECK-INST-LEVEL-NEXT:  ->  [ 6.00  6.10  6.20 ]
+# CHECK-INST-LEVEL: %2:gr32 = nsw ADD32rr
+# CHECK-INST-LEVEL:  ->  [ 3.70  3.80  3.90 ]
+# CHECK-INST-LEVEL: %3:gr32 = ADD32rr
+# CHECK-INST-LEVEL:  ->  [ 3.70  3.80  3.90 ]
+# CHECK-INST-LEVEL: $eax = COPY %3:gr32
+# CHECK-INST-LEVEL-NEXT:  ->  [ 6.00  6.10  6.20 ]
+# CHECK-INST-LEVEL: RET 0, $eax
+# CHECK-INST-LEVEL-NEXT:  ->  [ 1.10  1.20  1.30 ]
diff --git a/llvm/test/tools/llvm-ir2vec/error-handling.mir b/llvm/test/tools/llvm-ir2vec/error-handling.mir
new file mode 100644
index 0000000000000..154078c18d647
--- /dev/null
+++ b/llvm/test/tools/llvm-ir2vec/error-handling.mir
@@ -0,0 +1,41 @@
+# REQUIRES: x86_64-linux
+# Test error handling and input validation for llvm-ir2vec tool in MIR mode
+
+# RUN: not llvm-ir2vec embeddings --mode=mir %s 2>&1 | FileCheck %s -check-prefix=CHECK-NO-VOCAB
+# RUN: not llvm-ir2vec embeddings --mode=mir --mir2vec-vocab-path=%S/nonexistent-vocab.json %s 2>&1 | FileCheck %s -check-prefix=CHECK-VOCAB-NOT-FOUND
+# RUN: not llvm-ir2vec embeddings --mode=mir --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s -check-prefix=CHECK-INVALID-VOCAB
+# RUN: not llvm-ir2vec embeddings --mode=mir --function=nonexistent_function --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json %s 2>&1 | FileCheck %s -check-prefix=CHECK-FUNC-NOT-FOUND
+
+--- |
+  target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+  target triple = "x86_64-unknown-linux-gnu"
+  
+  define dso_local noundef i32 @test_function(i32 noundef %a) {
+  entry:
+    ret i32 %a
+  }
+...
+---
+name:            test_function
+alignment:       16
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: gr32 }
+liveins:
+  - { reg: '$edi', virtual-reg: '%0' }
+body:             |
+  bb.0.entry:
+    liveins: $edi
+  
+    %0:gr32 = COPY $edi
+    $eax = COPY %0
+    RET 0, $eax
+
+# CHECK-NO-VOCAB: Error: Failed to load MIR2Vec vocabulary - MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
+
+# CHECK-VOCAB-NOT-FOUND: Error: Failed to load MIR2Vec vocabulary
+# CHECK-VOCAB-NOT-FOUND: No such file or directory
+
+# CHECK-INVALID-VOCAB: Error: Failed to load MIR2Vec vocabulary - Missing 'Opcodes' section in vocabulary file
+
+# CHECK-FUNC-NOT-FOUND: Error: Function 'nonexistent_function' not found
diff --git a/llvm/tools/llvm-ir2vec/CMakeLists.txt b/llvm/tools/llvm-ir2vec/CMakeLists.txt
index a4cf9690e86b5..e680144452136 100644
--- a/llvm/tools/llvm-ir2vec/CMakeLists.txt
+++ b/llvm/tools/llvm-ir2vec/CMakeLists.txt
@@ -1,10 +1,24 @@
 set(LLVM_LINK_COMPONENTS
+  # Core LLVM components for IR processing
   Analysis
   Core
   IRReader
   Support
+
+  # Machine IR components (for --mode=mir)
+  CodeGen
+  MIRParser
+
+  # Target initialization (required for MIR parsing)
+  AllTargetsAsmParsers
+  AllTargetsCodeGens
+  AllTargetsDescs
+  AllTargetsInfos
   )
 
 add_llvm_tool(llvm-ir2vec
   llvm-ir2vec.cpp
+
+  DEPENDS
+  intrinsics_gen
   )
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 1031932116c1e..c41cf20539c0d 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -1,4 +1,4 @@
-//===- llvm-ir2vec.cpp - IR2Vec Embedding Generation Tool -----------------===//
+//===- llvm-ir2vec.cpp - IR2Vec/MIR2Vec Embedding Generation Tool ---------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -7,9 +7,13 @@
 //===----------------------------------------------------------------------===//
 ///
 /// \file
-/// This file implements the IR2Vec embedding generation tool.
+/// This file implements the IR2Vec and MIR2Vec embedding generation tool.
 ///
-/// This tool provides three main subcommands:
+/// This tool supports two modes:
+/// - LLVM IR mode (--mode=llvm): Process LLVM IR
+/// - Machine IR mode (--mode=mir): Process Machine IR
+///
+/// Available subcommands:
 ///
 /// 1. Triplet Generation (triplets):
 ///    Generates numeric triplets (head, tail, relation) for vocabulary
@@ -23,16 +27,24 @@
 ///    Usage: llvm-ir2vec entities input.bc -o entity2id.txt
 ///
 /// 3. Embedding Generation (embeddings):
-///    Generates IR2Vec embeddings using a trained vocabulary.
-///    Usage: llvm-ir2vec embeddings --ir2vec-vocab-path=vocab.json
-///    --ir2vec-kind=<kind> --level=<level> input.bc -o embeddings.txt
-///    Kind: --ir2vec-kind=symbolic (default), --ir2vec-kind=flow-aware
+///    Generates IR2Vec/MIR2Vec embeddings using a trained vocabulary.
+///
+///    For LLVM IR:
+///      llvm-ir2vec embeddings --ir2vec-vocab-path=vocab.json
+///        --ir2vec-kind=<kind> --level=<level> input.bc -o embeddings.txt
+///      Kind: --ir2vec-kind=symbolic (default), --ir2vec-kind=flow-aware
+///
+///    For Machine IR:
+///      llvm-ir2vec embeddings --mode=mir --mir2vec-vocab-path=vocab.json
+///        --level=<level> input.mir -o embeddings.txt
+///
 ///    Levels: --level=inst (instructions), --level=bb (basic blocks),
-///    --level=func (functions) (See IR2Vec.cpp for more embedding generation
-///    options)
+///    --level=func (functions) (See IR2Vec.cpp/MIR2Vec.cpp for more embedding
+///    generation options)
 ///
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Analysis/IR2Vec.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Function.h"
@@ -50,10 +62,36 @@
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include "llvm/CodeGen/CommandFlags.h"
+#include "llvm/CodeGen/MIR2Vec.h"
+#include "llvm/CodeGen/MIRParser/MIRParser.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/WithColor.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/TargetParser/Host.h"
+
 #define DEBUG_TYPE "ir2vec"
 
 namespace llvm {
-namespace ir2vec {
+
+// Common option category for options shared between IR2Vec and MIR2Vec
+static cl::OptionCategory CommonCategory("Common Options",
+                                         "Options applicable to both IR2Vec "
+                                         "and MIR2Vec modes");
+
+enum IRKind {
+  LLVMIR = 0, ///< LLVM IR
+  MIR         ///< Machine IR
+};
+
+static cl::opt<IRKind>
+    IRMode("mode", cl::desc("Tool operation mode:"),
+           cl::values(clEnumValN(LLVMIR, "llvm", "Process LLVM IR"),
+                      clEnumValN(MIR, "mir", "Process Machine IR")),
+           cl::init(LLVMIR), cl::cat(CommonCategory));
 
 // Subcommands
 static cl::SubCommand
@@ -70,18 +108,18 @@ static cl::opt<std::string>
     InputFilename(cl::Positional,
                   cl::desc("<input bitcode file or '-' for stdin>"),
                   cl::init("-"), cl::sub(TripletsSubCmd),
-                  cl::sub(EmbeddingsSubCmd), cl::cat(ir2vec::IR2VecCategory));
+                  cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
 
 static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
                                            cl::value_desc("filename"),
                                            cl::init("-"),
-                                           cl::cat(ir2vec::IR2VecCategory));
+                                           cl::cat(CommonCategory));
 
 // Embedding-specific options
 static cl::opt<std::string>
     FunctionName("function", cl::desc("Process specific function only"),
                  cl::value_desc("name"), cl::Optional, cl::init(""),
-                 cl::sub(EmbeddingsSubCmd), cl::cat(ir2vec::IR2VecCategory));
+                 cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
 
 enum EmbeddingLevel {
   InstructionLevel, // Generate instruction-level embeddings
@@ -98,9 +136,9 @@ static cl::opt<EmbeddingLevel>
                      clEnumValN(FunctionLevel, "func",
                                 "Generate function-level embeddings")),
           cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
-          cl::cat(ir2vec::IR2VecCategory));
+          cl::cat(CommonCategory));
 
-namespace {
+namespace ir2vec {
 
 /// Relation types for triplet generation
 enum RelationType {
@@ -300,20 +338,116 @@ Error processModule(Module &M, raw_ostream &OS) {
   }
   return Error::success();
 }
-} // namespace
 } // namespace ir2vec
+
+namespace mir2vec {
+
+/// Helper class for MIR2Vec embedding generation
+class MIR2VecTool {
+private:
+  MachineModuleInfo &MMI;
+  std::unique_ptr<MIRVocabulary> Vocab;
+
+public:
+  explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
+
+  /// Initialize the MIR2Vec vocabulary
+  bool initializeVocabulary(const Module &M) {
+    MIR2VecVocabProvider Provider(MMI);
+    auto VocabOrErr = Provider.getVocabulary(M);
+    if (!VocabOrErr) {
+      errs() << "Error: Failed to load MIR2Vec vocabulary - "
+             << toString(VocabOrErr.takeError()) << "\n";
+      return false;
+    }
+    Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
+    return true;
+  }
+
+  /// Generate embeddings for all machine functions in the module
+  void generateEmbeddings(const Module &M, raw_ostream &OS) const {
+    if (!Vocab) {
+      errs() << "Error: Vocabulary not initialized.\n";
+      return;
+    }
+
+    for (const Function &F : M) {
+      if (F.isDeclaration())
+        continue;
+
+      MachineFunction *MF = MMI.getMachineFunction(F);
+      if (!MF) {
+        errs() << "Warning: No MachineFunction for " << F.getName() << "\n";
+        continue;
+      }
+
+      generateEmbeddings(*MF, OS);
+    }
+  }
+
+  /// Generate embeddings for a specific machine function
+  void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const {
+    if (!Vocab) {
+      errs() << "Error: Vocabulary not initialized.\n";
+      return;
+    }
+
+    auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
+    if (!Emb) {
+      errs() << "Error: Failed to create embedder for " << MF.getName() << "\n";
+      return;
+    }
+
+    OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+
+    // Generate embeddings based on the specified level
+    switch (Level) {
+    case FunctionLevel: {
+      OS << "Function vector: ";
+      Emb->getMFunctionVector().print(OS);
+      break;
+    }
+    case BasicBlockLevel: {
+      OS << "Basic block vectors:\n";
+      for (const MachineBasicBlock &MBB : MF) {
+        OS << "MBB " << MBB.getName() << ": ";
+        Emb->getMBBVector(MBB).print(OS);
+      }
+      break;
+    }
+    case InstructionLevel: {
+      OS << "Instruction vectors:\n";
+      for (const MachineBasicBlock &MBB : MF) {
+        for (const MachineInstr &MI : MBB) {
+          OS << MI << " -> ";
+          Emb->getMInstVector(MI).print(OS);
+        }
+      }
+      break;
+    }
+    }
+  }
+
+  const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
+};
+
+} // namespace mir2vec
+
 } // namespace llvm
 
 int main(int argc, char **argv) {
   using namespace llvm;
   using namespace llvm::ir2vec;
+  using namespace llvm::mir2vec;
 
   InitLLVM X(argc, argv);
-  cl::HideUnrelatedOptions(ir2vec::IR2VecCategory);
+  // Show Common, IR2Vec and MIR2Vec option categories
+  cl::HideUnrelatedOptions(ArrayRef<const cl::OptionCategory *>{
+      &CommonCategory, &ir2vec::IR2VecCategory, &mir2vec::MIR2VecCategory});
   cl::ParseCommandLineOptions(
       argc, argv,
-      "IR2Vec - Embedding Generation Tool\n"
-      "Generates embeddings for a given LLVM IR and "
+      "IR2Vec/MIR2Vec - Embedding Generation Tool\n"
+      "Generates embeddings for a given LLVM IR or MIR and "
       "supports triplet generation for vocabulary "
       "training and embedding generation.\n\n"
       "See https://llvm.org/docs/CommandGuide/llvm-ir2vec.html for more "
@@ -326,26 +460,110 @@ int main(int argc, char **argv) {
     return 1;
   }
 
-  if (EntitiesSubCmd) {
-    // Just dump entity mappings without processing any IR
-    IR2VecTool::generateEntityMappings(OS);
+  if (IRMode == IRKind::LLVMIR) {
+    if (EntitiesSubCmd) {
+      // Just dump entity mappings without processing any IR
+      IR2VecTool::generateEntityMappings(OS);
+      return 0;
+    }
+
+    // Parse the input LLVM IR file or stdin
+    SMDiagnostic Err;
+    LLVMContext Context;
+    std::unique_ptr<Module> M = parseIRFile(InputFilename, Err, Context);
+    if (!M) {
+      Err.print(argv[0], errs());
+      return 1;
+    }
+
+    if (Error Err = processModule(*M, OS)) {
+      handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) {
+        errs() << "Error: " << EIB.message() << "\n";
+      });
+      return 1;
+    }
     return 0;
   }
+  if (IRMode == IRKind::MIR) {
+    // Initialize targets for Machine IR processing
+    InitializeAllTargets();
+    InitializeAllTargetMCs();
+    InitializeAllAsmParsers();
+    InitializeAllAsmPrinters();
+    static codegen::RegisterCodeGenFlags CGF;
+
+    // Parse MIR input file
+    SMDiagnostic Err;
+    LLVMContext Context;
+    std::unique_ptr<TargetMachine> TM;
+
+    auto MIR = createMIRParserFromFile(InputFilename, Err, Context);
+    if (!MIR) {
+      Err.print(argv[0], WithColor::error(errs(), argv[0]));
+      return 1;
+    }
 
-  // Parse the input LLVM IR file or stdin
-  SMDiagnostic Err;
-  LLVMContext Context;
-  std::unique_ptr<Module> M = parseIRFile(InputFilename, Err, Context);
-  if (!M) {
-    Err.print(argv[0], errs());
-    return 1;
-  }
+    auto SetDataLayout = [&](StringRef DataLayoutTargetTriple,
+                             StringRef OldDLStr) -> std::optional<std::string> {
+      std::string IRTargetTriple = DataLayoutTargetTriple.str();
+      Triple TheTriple = Triple(IRTargetTriple);
+      if (TheTriple.getTriple().empty())
+        TheTriple.setTriple(sys::getDefaultTargetTriple());
+      auto TMOrErr = codegen::createTargetMachineForTriple(TheTriple.str());
+      if (!TMOrErr) {
+        errs() << "Error: " << toString(TMOrErr.takeError()) << "\n";
+        exit(1);
+      }
+      TM = std::move(*TMOrErr);
+      return TM->createDataLayout().getStringRepresentation();
+    };
+
+    std::unique_ptr<Module> M = MIR->parseIRModule(SetDataLayout);
+    if (!M) {
+      Err.print(argv[0], WithColor::error(errs(), argv[0]));
+      return 1;
+    }
 
-  if (Error Err = processModule(*M, OS)) {
-    handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) {
-      errs() << "Error: " << EIB.message() << "\n";
-    });
-    return 1;
+    // Parse machine functions
+    auto MMI = std::make_unique<MachineModuleInfo>(TM.get());
+    if (MIR->parseMachineFunctions(*M, *MMI)) {
+      Err.print(argv[0], WithColor::error(errs(), argv[0]));
+      return 1;
+    }
+
+    // Create MIR2Vec tool and initialize vocabulary
+    MIR2VecTool Tool(*MMI);
+    if (!Tool.initializeVocabulary(*M))
+      return 1;
+
+    LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"
+                      << "Vocabulary dimension: "
+                      << Tool.getVocabulary()->getDimension() << "\n"
+                      << "Vocabulary size: "
+                      << Tool.getVocabulary()->getCanonicalSize() << "\n");
+
+    // Generate embeddings for the requested function, or for all functions
+    if (!FunctionName.empty()) {
+      // Process single function
+      Function *F = M->getFunction(FunctionName);
+      if (!F) {
+        errs() << "Error: Function '" << FunctionName << "' not found\n";
+        return 1;
+      }
+
+      MachineFunction *MF = MMI->getMachineFunction(*F);
+      if (!MF) {
+        errs() << "Error: No MachineFunction for " << FunctionName << "\n";
+        return 1;
+      }
+
+      Tool.generateEmbeddings(*MF, OS);
+    } else {
+      // Process all functions
+      Tool.generateEmbeddings(*M, OS);
+    }
+
+    return 0;
   }
 
   return 0;


