[llvm] [MIR2Vec] Add embedder for machine instructions (PR #162161)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 9 15:20:50 PDT 2025


https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/162161

>From 19bdf147ecf26c985f94d024c17d4818c9db775d Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Mon, 6 Oct 2025 21:15:14 +0000
Subject: [PATCH] MIR2Vec embedding

---
 llvm/include/llvm/CodeGen/MIR2Vec.h           | 113 ++++++-
 llvm/include/llvm/CodeGen/Passes.h            |   4 +
 llvm/include/llvm/InitializePasses.h          |   1 +
 llvm/lib/CodeGen/CodeGen.cpp                  |   1 +
 llvm/lib/CodeGen/MIR2Vec.cpp                  | 155 ++++++++-
 .../Inputs/mir2vec_dummy_3D_vocab.json        |  22 ++
 llvm/test/CodeGen/MIR2Vec/if-else.mir         | 144 +++++++++
 .../MIR2Vec/mir2vec-basic-symbolic.mir        |  76 +++++
 llvm/tools/llc/llc.cpp                        |  15 +
 llvm/unittests/CodeGen/MIR2VecTest.cpp        | 295 ++++++++++++++++--
 10 files changed, 797 insertions(+), 29 deletions(-)
 create mode 100644 llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json
 create mode 100644 llvm/test/CodeGen/MIR2Vec/if-else.mir
 create mode 100644 llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir

diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
index 7b1b5d9aee15d..f6b0571f5dac6 100644
--- a/llvm/include/llvm/CodeGen/MIR2Vec.h
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -52,11 +52,21 @@ class LLVMContext;
 class MIR2VecVocabLegacyAnalysis;
 class TargetInstrInfo;
 
+enum class MIR2VecKind { Symbolic };
+
 namespace mir2vec {
+
+// Forward declarations
+class MIREmbedder;
+class SymbolicMIREmbedder;
+
 extern llvm::cl::OptionCategory MIR2VecCategory;
 extern cl::opt<float> OpcWeight;
 
 using Embedding = ir2vec::Embedding;
+using MachineInstEmbeddingsMap = DenseMap<const MachineInstr *, Embedding>;
+using MachineBlockEmbeddingsMap =
+    DenseMap<const MachineBasicBlock *, Embedding>;
 
 /// Class for storing and accessing the MIR2Vec vocabulary.
 /// The MIRVocabulary class manages seed embeddings for LLVM Machine IR
@@ -107,19 +117,91 @@ class MIRVocabulary {
 
   const_iterator end() const { return Storage.end(); }
 
-  /// Total number of entries in the vocabulary
-  size_t getCanonicalSize() const { return Storage.size(); }
-
   MIRVocabulary() = delete;
 
   /// Factory method to create MIRVocabulary from vocabulary map
   static Expected<MIRVocabulary> create(VocabMap &&Entries,
                                         const TargetInstrInfo &TII);
 
+  /// Create a test-only vocabulary with distinct Dim-sized dummy embeddings.
+  static Expected<MIRVocabulary>
+  createDummyVocabForTest(const TargetInstrInfo &TII, unsigned Dim = 1);
+
+  /// Total number of entries in the vocabulary (size of Storage).
+  size_t getCanonicalSize() const { return Storage.size(); }
+
 private:
   MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
 };
 
+/// Base class for MIR2Vec embedders of a single machine function.
+class MIREmbedder {
+protected:
+  const MachineFunction &MF;
+  const MIRVocabulary &Vocab;
+
+  /// Dimension of the embeddings; captured from the vocabulary.
+  const unsigned Dimension;
+
+  /// Weight for opcode embeddings; initialized from -mir2vec-opc-weight.
+  const float OpcWeight;
+
+  MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
+      : MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()),
+        OpcWeight(mir2vec::OpcWeight) {}
+
+  /// Compute the embedding of the machine function MF (reachable blocks).
+  Embedding computeEmbeddings() const;
+
+  /// Compute the embedding of \p MBB by summing its instruction embeddings.
+  Embedding computeEmbeddings(const MachineBasicBlock &MBB) const;
+
+  /// Compute the embedding of a single machine instruction \p MI.
+  /// Specific to the kind of embeddings being computed.
+  virtual Embedding computeEmbeddings(const MachineInstr &MI) const = 0;
+
+public:
+  virtual ~MIREmbedder() = default;
+
+  /// Factory method to create an embedder of the specified kind for \p MF.
+  /// Returns nullptr if the requested kind is not supported.
+  static std::unique_ptr<MIREmbedder> create(MIR2VecKind Mode,
+                                             const MachineFunction &MF,
+                                             const MIRVocabulary &Vocab);
+
+  /// Computes and returns the embedding for a given machine instruction MI in
+  /// the machine function MF. The result is not cached.
+  Embedding getMInstVector(const MachineInstr &MI) const {
+    return computeEmbeddings(MI);
+  }
+
+  /// Computes and returns the embedding for a given machine basic block in the
+  /// machine function MF. The result is not cached.
+  Embedding getMBBVector(const MachineBasicBlock &MBB) const {
+    return computeEmbeddings(MBB);
+  }
+
+  /// Computes and returns the embedding for the current machine function.
+  Embedding getMFunctionVector() const {
+    // The function vector is (re)computed on every call rather than cached,
+    // so it always reflects the current contents of MF.
+    return computeEmbeddings();
+  }
+};
+
+/// Embedder producing symbolic embeddings: each machine instruction is
+/// represented by the embedding that the MIR vocabulary assigns to its
+/// opcode; block and function vectors are sums of these.
+class SymbolicMIREmbedder : public MIREmbedder {
+private:
+  Embedding computeEmbeddings(const MachineInstr &MI) const override;
+
+public:
+  SymbolicMIREmbedder(const MachineFunction &F, const MIRVocabulary &Vocab);
+  static std::unique_ptr<SymbolicMIREmbedder>
+  create(const MachineFunction &MF, const MIRVocabulary &Vocab);
+};
+
 } // namespace mir2vec
 
 /// Pass to analyze and populate MIR2Vec vocabulary from a module
@@ -166,6 +248,31 @@ class MIR2VecVocabPrinterLegacyPass : public MachineFunctionPass {
   }
 };
 
+/// Machine function pass that prints the MIR2Vec embeddings of each machine
+/// function, its basic blocks, and its instructions to an output stream.
+class MIR2VecPrinterLegacyPass : public MachineFunctionPass {
+  raw_ostream &OS;
+
+public:
+  static char ID;
+  explicit MIR2VecPrinterLegacyPass(raw_ostream &OS)
+      : MachineFunctionPass(ID), OS(OS) {}
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<MIR2VecVocabLegacyAnalysis>();
+    AU.setPreservesAll();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  StringRef getPassName() const override {
+    return "MIR2Vec Embedder Printer Pass";
+  }
+};
+
+/// Create a machine pass that prints MIR2Vec embeddings to \p OS.
+LLVM_ABI MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
+
 } // namespace llvm
 
 #endif // LLVM_CODEGEN_MIR2VEC_H
\ No newline at end of file
diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
index 272b4acf950c5..7fae550d8d170 100644
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -93,6 +93,10 @@ createMachineFunctionPrinterPass(raw_ostream &OS,
 LLVM_ABI MachineFunctionPass *
 createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS);
 
+/// MIR2VecPrinter pass - This pass prints out the MIR2Vec embeddings for
+/// machine functions, basic blocks and instructions.
+LLVM_ABI MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
+
 /// StackFramePrinter pass - This pass prints out the machine function's
 /// stack frame to the given stream as a debugging tool.
 LLVM_ABI MachineFunctionPass *createStackFrameLayoutAnalysisPass();
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
index cd774e7888e64..d507ba267d791 100644
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -222,6 +222,7 @@ LLVM_ABI void
 initializeMachineSanitizerBinaryMetadataLegacyPass(PassRegistry &);
 LLVM_ABI void initializeMIR2VecVocabLegacyAnalysisPass(PassRegistry &);
 LLVM_ABI void initializeMIR2VecVocabPrinterLegacyPassPass(PassRegistry &);
+LLVM_ABI void initializeMIR2VecPrinterLegacyPassPass(PassRegistry &);
 LLVM_ABI void initializeMachineSchedulerLegacyPass(PassRegistry &);
 LLVM_ABI void initializeMachineSinkingLegacyPass(PassRegistry &);
 LLVM_ABI void initializeMachineTraceMetricsWrapperPassPass(PassRegistry &);
diff --git a/llvm/lib/CodeGen/CodeGen.cpp b/llvm/lib/CodeGen/CodeGen.cpp
index c438eaeb29d1e..9795a0b707fd3 100644
--- a/llvm/lib/CodeGen/CodeGen.cpp
+++ b/llvm/lib/CodeGen/CodeGen.cpp
@@ -98,6 +98,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) {
   initializeMachineUniformityAnalysisPassPass(Registry);
   initializeMIR2VecVocabLegacyAnalysisPass(Registry);
   initializeMIR2VecVocabPrinterLegacyPassPass(Registry);
+  initializeMIR2VecPrinterLegacyPassPass(Registry);
   initializeMachineUniformityInfoPrinterPassPass(Registry);
   initializeMachineVerifierLegacyPassPass(Registry);
   initializeObjCARCContractLegacyPassPass(Registry);
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index e85976547a2c2..2df14a75bf623 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/CodeGen/MIR2Vec.h"
+#include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/CodeGen/TargetInstrInfo.h"
 #include "llvm/IR/Module.h"
@@ -41,11 +42,18 @@ static cl::opt<std::string>
 cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
                          cl::desc("Weight for machine opcode embeddings"),
                          cl::cat(MIR2VecCategory));
+cl::opt<MIR2VecKind> MIR2VecEmbeddingKind(
+    "mir2vec-kind", cl::Optional,
+    cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic",
+                          "Generate symbolic embeddings for MIR")),
+    cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"),
+    cl::cat(MIR2VecCategory));
+
 } // namespace mir2vec
 } // namespace llvm
 
 //===----------------------------------------------------------------------===//
-// Vocabulary Implementation
+// Vocabulary
 //===----------------------------------------------------------------------===//
 
 MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
@@ -191,6 +199,30 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() {
                     << " unique base opcodes\n");
 }
 
+Expected<MIRVocabulary>
+MIRVocabulary::createDummyVocabForTest(const TargetInstrInfo &TII,
+                                       unsigned Dim) {
+  assert(Dim > 0 && "Dimension must be greater than zero");
+
+  float DummyVal = 0.1f;
+
+  // Use an empty temporary vocabulary only to enumerate canonical opcodes.
+  MIRVocabulary TempVocab({}, TII);
+  TempVocab.buildCanonicalOpcodeMapping();
+
+  // Map every canonical base-opcode name to a Dim-sized dummy embedding.
+  VocabMap DummyVocabMap;
+  for (const auto &COpcodeName : TempVocab.UniqueBaseOpcodeNames) {
+    // All components equal DummyVal; each next opcode gets 0.1 more.
+    Embedding DummyEmbedding(Dim, DummyVal);
+    DummyVocabMap[COpcodeName] = DummyEmbedding;
+    DummyVal += 0.1f;
+  }
+
+  // Build the final vocabulary from the dummy entries via create().
+  return MIRVocabulary::create(std::move(DummyVocabMap), TII);
+}
+
 //===----------------------------------------------------------------------===//
 // MIR2VecVocabLegacyAnalysis Implementation
 //===----------------------------------------------------------------------===//
@@ -261,7 +293,73 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
 }
 
 //===----------------------------------------------------------------------===//
-// Printer Passes Implementation
+// MIREmbedder and its subclasses
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode,
+                                                 const MachineFunction &MF,
+                                                 const MIRVocabulary &Vocab) {
+  switch (Mode) {
+  case MIR2VecKind::Symbolic:
+    return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
+  }
+  return nullptr; // Reached only for a Mode not handled by the switch.
+}
+
+Embedding MIREmbedder::computeEmbeddings(const MachineBasicBlock &MBB) const {
+  Embedding MBBVector(Dimension, 0);
+
+  // TII is only null-checked here; if missing, report and return zeros.
+  const auto &Subtarget = MF.getSubtarget();
+  const auto *TII = Subtarget.getInstrInfo();
+  if (!TII) {
+    MF.getFunction().getContext().emitError(
+        "MIR2Vec: No TargetInstrInfo available; cannot compute embeddings");
+    return MBBVector;
+  }
+
+  // Sum the embeddings of the instructions in the basic block
+  for (const auto &MI : MBB) {
+    // Debug instructions do not contribute to the block embedding
+    if (MI.isDebugInstr())
+      continue;
+    MBBVector += computeEmbeddings(MI);
+  }
+
+  return MBBVector;
+}
+
+Embedding MIREmbedder::computeEmbeddings() const {
+  Embedding MFuncVector(Dimension, 0);
+
+  // Sum block vectors over blocks reachable from the entry (depth_first).
+  for (const auto *MBB : depth_first(&MF))
+    MFuncVector += computeEmbeddings(*MBB);
+  return MFuncVector;
+}
+
+SymbolicMIREmbedder::SymbolicMIREmbedder(const MachineFunction &MF,
+                                         const MIRVocabulary &Vocab)
+    : MIREmbedder(MF, Vocab) {}
+
+std::unique_ptr<SymbolicMIREmbedder>
+SymbolicMIREmbedder::create(const MachineFunction &MF,
+                            const MIRVocabulary &Vocab) {
+  return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
+}
+
+Embedding SymbolicMIREmbedder::computeEmbeddings(const MachineInstr &MI) const {
+  // Debug instructions are embedded as the zero vector
+  if (MI.isDebugInstr())
+    return Embedding(Dimension, 0);
+
+  // TODO: Add operand/argument contributions; only the opcode is used now.
+
+  return Vocab[MI.getOpcode()];
+}
+
+//===----------------------------------------------------------------------===//
+// Printer Passes
 //===----------------------------------------------------------------------===//
 
 char MIR2VecVocabPrinterLegacyPass::ID = 0;
@@ -300,3 +398,64 @@ MachineFunctionPass *
 llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) {
   return new MIR2VecVocabPrinterLegacyPass(OS);
 }
+
+char MIR2VecPrinterLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(MIR2VecPrinterLegacyPass, "print-mir2vec",
+                      "MIR2Vec Embedder Printer Pass", false, true)
+INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
+INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
+INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, "print-mir2vec",
+                    "MIR2Vec Embedder Printer Pass", false, true)
+
+// Print the function, basic block and instruction embeddings of MF to OS.
+// MF is never modified, so the pass always returns false.
+bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
+  auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
+  auto VocabOrErr =
+      Analysis.getMIR2VecVocabulary(*MF.getFunction().getParent());
+  if (!VocabOrErr) {
+    // Report the failure instead of asserting: with assertions disabled,
+    // dereferencing an Expected that holds an error is undefined behavior.
+    OS << "MIR2Vec Printer: Failed to get vocabulary - "
+       << toString(VocabOrErr.takeError()) << "\n";
+    return false;
+  }
+  auto &MIRVocab = *VocabOrErr;
+
+  auto Emb = mir2vec::MIREmbedder::create(MIR2VecEmbeddingKind, MF, MIRVocab);
+  if (!Emb) {
+    OS << "Error creating MIR2Vec embeddings for function " << MF.getName()
+       << "\n";
+    return false;
+  }
+
+  OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+  OS << "Machine Function vector: ";
+  Emb->getMFunctionVector().print(OS);
+
+  OS << "Machine basic block vectors:\n";
+  for (const MachineBasicBlock &MBB : MF) {
+    OS << "Machine basic block: " << MBB.getFullName() << ":\n";
+    Emb->getMBBVector(MBB).print(OS);
+  }
+
+  OS << "Machine instruction vectors:\n";
+  for (const MachineBasicBlock &MBB : MF) {
+    for (const MachineInstr &MI : MBB) {
+      // Skip debug instructions as they are not
+      // embedded
+      if (MI.isDebugInstr())
+        continue;
+
+      OS << "Machine instruction: ";
+      MI.print(OS);
+      Emb->getMInstVector(MI).print(OS);
+    }
+  }
+
+  return false;
+}
+
+MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass(raw_ostream &OS) {
+  return new MIR2VecPrinterLegacyPass(OS);
+}
diff --git a/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json b/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json
new file mode 100644
index 0000000000000..5de715bf80917
--- /dev/null
+++ b/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json
@@ -0,0 +1,22 @@
+{
+  "entities": {
+    "KILL": [0.1, 0.2, 0.3],
+    "MOV": [0.4, 0.5, 0.6],
+    "LEA": [0.7, 0.8, 0.9],
+    "RET": [1.0, 1.1, 1.2],
+    "ADD": [1.3, 1.4, 1.5],
+    "SUB": [1.6, 1.7, 1.8],
+    "IMUL": [1.9, 2.0, 2.1],
+    "AND": [2.2, 2.3, 2.4],
+    "OR": [2.5, 2.6, 2.7],
+    "XOR": [2.8, 2.9, 3.0],
+    "CMP": [3.1, 3.2, 3.3],
+    "TEST": [3.4, 3.5, 3.6],
+    "JMP": [3.7, 3.8, 3.9],
+    "CALL": [4.0, 4.1, 4.2],
+    "PUSH": [4.3, 4.4, 4.5],
+    "POP": [4.6, 4.7, 4.8],
+    "NOP": [4.9, 5.0, 5.1],
+    "COPY": [5.2, 5.3, 5.4]
+  }
+}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/MIR2Vec/if-else.mir b/llvm/test/CodeGen/MIR2Vec/if-else.mir
new file mode 100644
index 0000000000000..2accf476f7c4d
--- /dev/null
+++ b/llvm/test/CodeGen/MIR2Vec/if-else.mir
@@ -0,0 +1,144 @@
+# REQUIRES: x86_64-linux
+# RUN: llc -run-pass=none -print-mir2vec -mir2vec-vocab-path=%S/Inputs/mir2vec_dummy_3D_vocab.json %s -o /dev/null 2>&1 | FileCheck %s
+
+--- |
+  target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+  
+  define dso_local i32 @abc(i32 noundef %a, i32 noundef %b) {
+  entry:
+    %retval = alloca i32, align 4
+    %a.addr = alloca i32, align 4
+    %b.addr = alloca i32, align 4
+    store i32 %a, ptr %a.addr, align 4
+    store i32 %b, ptr %b.addr, align 4
+    %0 = load i32, ptr %a.addr, align 4
+    %1 = load i32, ptr %b.addr, align 4
+    %cmp = icmp sgt i32 %0, %1
+    br i1 %cmp, label %if.then, label %if.else
+  
+  if.then:                                          ; preds = %entry
+    %2 = load i32, ptr %b.addr, align 4
+    store i32 %2, ptr %retval, align 4
+    br label %return
+  
+  if.else:                                          ; preds = %entry
+    %3 = load i32, ptr %a.addr, align 4
+    store i32 %3, ptr %retval, align 4
+    br label %return
+  
+  return:                                           ; preds = %if.else, %if.then
+    %4 = load i32, ptr %retval, align 4
+    ret i32 %4
+  }
+...
+---
+name:            abc
+alignment:       16
+exposesReturnsTwice: false
+legalized:       false
+regBankSelected: false
+selected:        false
+failedISel:      false
+tracksRegLiveness: true
+hasWinCFI:       false
+noPhis:          false
+isSSA:           true
+noVRegs:         false
+hasFakeUses:     false
+callsEHReturn:   false
+callsUnwindInit: false
+hasEHContTarget: false
+hasEHScopes:     false
+hasEHFunclets:   false
+isOutlined:      false
+debugInstrRef:   true
+failsVerification: false
+tracksDebugUserValues: false
+registers:
+  - { id: 0, class: gr32, preferred-register: '', flags: [  ] }
+  - { id: 1, class: gr32, preferred-register: '', flags: [  ] }
+  - { id: 2, class: gr32, preferred-register: '', flags: [  ] }
+  - { id: 3, class: gr32, preferred-register: '', flags: [  ] }
+  - { id: 4, class: gr32, preferred-register: '', flags: [  ] }
+  - { id: 5, class: gr32, preferred-register: '', flags: [  ] }
+liveins:
+  - { reg: '$edi', virtual-reg: '%0' }
+  - { reg: '$esi', virtual-reg: '%1' }
+frameInfo:
+  isFrameAddressTaken: false
+  isReturnAddressTaken: false
+  hasStackMap:     false
+  hasPatchPoint:   false
+  stackSize:       0
+  offsetAdjustment: 0
+  maxAlignment:    4
+  adjustsStack:    false
+  hasCalls:        false
+  stackProtector:  ''
+  functionContext: ''
+  maxCallFrameSize: 4294967295
+  cvBytesOfCalleeSavedRegisters: 0
+  hasOpaqueSPAdjustment: false
+  hasVAStart:      false
+  hasMustTailInVarArgFunc: false
+  hasTailCall:     false
+  isCalleeSavedInfoValid: false
+  localFrameSize:  0
+fixedStack:      []
+stack:
+  - { id: 0, name: retval, type: default, offset: 0, size: 4, alignment: 4, 
+      stack-id: default, callee-saved-register: '', callee-saved-restored: true, 
+      debug-info-variable: '', debug-info-expression: '', debug-info-location: '' }
+  - { id: 1, name: a.addr, type: default, offset: 0, size: 4, alignment: 4, 
+      stack-id: default, callee-saved-register: '', callee-saved-restored: true, 
+      debug-info-variable: '', debug-info-expression: '', debug-info-location: '' }
+  - { id: 2, name: b.addr, type: default, offset: 0, size: 4, alignment: 4, 
+      stack-id: default, callee-saved-register: '', callee-saved-restored: true, 
+      debug-info-variable: '', debug-info-expression: '', debug-info-location: '' }
+entry_values:    []
+callSites:       []
+debugValueSubstitutions: []
+constants:       []
+machineFunctionInfo:
+  amxProgModel:    None
+body:             |
+  bb.0.entry:
+    successors: %bb.1(0x40000000), %bb.2(0x40000000)
+    liveins: $edi, $esi
+  
+    %1:gr32 = COPY $esi
+    %0:gr32 = COPY $edi
+    MOV32mr %stack.1.a.addr, 1, $noreg, 0, $noreg, %0 :: (store (s32) into %ir.a.addr)
+    MOV32mr %stack.2.b.addr, 1, $noreg, 0, $noreg, %1 :: (store (s32) into %ir.b.addr)
+    %2:gr32 = SUB32rr %0, %1, implicit-def $eflags
+    JCC_1 %bb.2, 14, implicit $eflags
+    JMP_1 %bb.1
+  
+  bb.1.if.then:
+    successors: %bb.3(0x80000000)
+  
+    %4:gr32 = MOV32rm %stack.2.b.addr, 1, $noreg, 0, $noreg :: (dereferenceable load (s32) from %ir.b.addr)
+    MOV32mr %stack.0.retval, 1, $noreg, 0, $noreg, killed %4 :: (store (s32) into %ir.retval)
+    JMP_1 %bb.3
+  
+  bb.2.if.else:
+    successors: %bb.3(0x80000000)
+  
+    %3:gr32 = MOV32rm %stack.1.a.addr, 1, $noreg, 0, $noreg :: (dereferenceable load (s32) from %ir.a.addr)
+    MOV32mr %stack.0.retval, 1, $noreg, 0, $noreg, killed %3 :: (store (s32) into %ir.retval)
+  
+  bb.3.return:
+    %5:gr32 = MOV32rm %stack.0.retval, 1, $noreg, 0, $noreg :: (dereferenceable load (s32) from %ir.retval)
+    $eax = COPY %5
+    RET 0, $eax
+...
+
+# CHECK: Machine basic block vectors:
+# CHECK-NEXT: Machine basic block: abc:entry:
+# CHECK-NEXT:  [ 16.50  17.10  17.70 ]
+# CHECK-NEXT: Machine basic block: abc:if.then:
+# CHECK-NEXT:  [ 4.50  4.80  5.10 ]
+# CHECK-NEXT: Machine basic block: abc:if.else:
+# CHECK-NEXT:  [ 0.80  1.00  1.20 ]
+# CHECK-NEXT: Machine basic block: abc:return:
+# CHECK-NEXT:  [ 6.60  6.90  7.20 ]
\ No newline at end of file
diff --git a/llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir b/llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir
new file mode 100644
index 0000000000000..44240affb2206
--- /dev/null
+++ b/llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir
@@ -0,0 +1,76 @@
+# REQUIRES: x86_64-linux
+# RUN: llc -run-pass=none -print-mir2vec -mir2vec-vocab-path=%S/Inputs/mir2vec_dummy_3D_vocab.json %s -o /dev/null 2>&1 | FileCheck %s
+
+--- |
+  target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+  
+  define dso_local noundef i32 @add_function(i32 noundef %a, i32 noundef %b) {
+  entry:
+    %sum = add nsw i32 %a, %b
+    %result = mul nsw i32 %sum, 2
+    ret i32 %result
+  }
+  
+  define dso_local void @simple_function() {
+  entry:
+    ret void
+  }
+...
+---
+name:            add_function
+alignment:       16
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: gr32 }
+  - { id: 1, class: gr32 }
+  - { id: 2, class: gr32 }
+  - { id: 3, class: gr32 }
+liveins:
+  - { reg: '$edi', virtual-reg: '%0' }
+  - { reg: '$esi', virtual-reg: '%1' }
+body:             |
+  bb.0.entry:
+    liveins: $edi, $esi
+  
+    %1:gr32 = COPY $esi
+    %0:gr32 = COPY $edi
+    %2:gr32 = nsw ADD32rr %0, %1, implicit-def dead $eflags
+    %3:gr32 = ADD32rr %2, %2, implicit-def dead $eflags
+    $eax = COPY %3
+    RET 0, $eax
+
+---
+name:            simple_function
+alignment:       16
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    RET 0
+
+# CHECK: MIR2Vec embeddings for machine function add_function:
+# CHECK: Function vector:  [ 19.20  19.80  20.40 ]
+# CHECK-NEXT: Machine basic block vectors:
+# CHECK-NEXT: Machine basic block: add_function:entry:
+# CHECK-NEXT:  [ 19.20  19.80  20.40 ]
+# CHECK-NEXT: Machine instruction vectors:
+# CHECK-NEXT: Machine instruction: %1:gr32 = COPY $esi
+# CHECK-NEXT:  [ 5.20  5.30  5.40 ]
+# CHECK-NEXT: Machine instruction: %0:gr32 = COPY $edi
+# CHECK-NEXT:  [ 5.20  5.30  5.40 ]
+# CHECK-NEXT: Machine instruction: %2:gr32 = nsw ADD32rr %0:gr32(tied-def 0), %1:gr32, implicit-def dead $eflags
+# CHECK-NEXT:  [ 1.30  1.40  1.50 ]
+# CHECK-NEXT: Machine instruction: %3:gr32 = ADD32rr %2:gr32(tied-def 0), %2:gr32, implicit-def dead $eflags
+# CHECK-NEXT:  [ 1.30  1.40  1.50 ]
+# CHECK-NEXT: Machine instruction: $eax = COPY %3:gr32
+# CHECK-NEXT:  [ 5.20  5.30  5.40 ]
+# CHECK-NEXT: Machine instruction: RET 0, $eax
+# CHECK-NEXT:  [ 1.00  1.10  1.20 ]
+
+# CHECK: MIR2Vec embeddings for machine function simple_function:
+# CHECK-NEXT:Function vector:  [ 1.00  1.10  1.20 ]
+# CHECK-NEXT: Machine basic block vectors:
+# CHECK-NEXT: Machine basic block: simple_function:entry:
+# CHECK-NEXT:  [ 1.00  1.10  1.20 ]
+# CHECK-NEXT: Machine instruction vectors:
+# CHECK-NEXT: Machine instruction: RET 0
+# CHECK-NEXT:  [ 1.00  1.10  1.20 ]
\ No newline at end of file
diff --git a/llvm/tools/llc/llc.cpp b/llvm/tools/llc/llc.cpp
index f04b256e2e6c9..f4441ccb896b1 100644
--- a/llvm/tools/llc/llc.cpp
+++ b/llvm/tools/llc/llc.cpp
@@ -172,6 +172,11 @@ static cl::opt<bool>
                       cl::desc("Print MIR2Vec vocabulary contents"),
                       cl::init(false));
 
+static cl::opt<bool>
+    PrintMIR2Vec("print-mir2vec", cl::Hidden,
+                 cl::desc("Print MIR2Vec embeddings for functions"),
+                 cl::init(false));
+
 static cl::list<std::string> IncludeDirs("I", cl::desc("include search path"));
 
 static cl::opt<bool> RemarksWithHotness(
@@ -776,6 +781,11 @@ static int compileModule(char **argv, LLVMContext &Context) {
         PM.add(createMIR2VecVocabPrinterLegacyPass(errs()));
       }
 
+      // Add MIR2Vec printer if requested
+      if (PrintMIR2Vec) {
+        PM.add(createMIR2VecPrinterLegacyPass(errs()));
+      }
+
       PM.add(createFreeMachineFunctionPass());
     } else {
       if (Target->addPassesToEmitFile(PM, *OS, DwoOut ? &DwoOut->os() : nullptr,
@@ -789,6 +799,11 @@ static int compileModule(char **argv, LLVMContext &Context) {
       if (PrintMIR2VecVocab) {
         PM.add(createMIR2VecVocabPrinterLegacyPass(errs()));
       }
+
+      // Add MIR2Vec printer if requested
+      if (PrintMIR2Vec) {
+        PM.add(createMIR2VecPrinterLegacyPass(errs()));
+      }
     }
 
     Target->getObjFileLowering()->Initialize(MMIWP->getMMI().getContext(),
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index 11222b4d02fa3..8cd9d5ac9f6be 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -82,6 +82,9 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
       return;
     }
 
+    // Set the data layout to match the target machine
+    M->setDataLayout(TM->createDataLayout());
+
     // Create a dummy function to get subtarget info
     FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false);
     Function *F =
@@ -96,16 +99,27 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
   }
 
   void TearDown() override { TII = nullptr; }
-};
 
-// Function to find an opcode by name
-static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) {
-  for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
-    if (TII->getName(Opcode) == Name)
-      return Opcode;
+  // Return the first opcode (>= 1) named exactly Name, or -1 if none matches.
+  int findOpcodeByName(StringRef Name) {
+    for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
+      if (TII->getName(Opcode) == Name)
+        return Opcode;
+    }
+    return -1; // Not found
   }
-  return -1; // Not found
-}
+
+  // Map each opcode name to a Dimension-sized embedding filled with its value.
+  Expected<MIRVocabulary>
+  createTestVocab(std::initializer_list<std::pair<const char *, float>> Opcodes,
+                  unsigned Dimension = 2) {
+    assert(TII && "TargetInstrInfo not initialized");
+    VocabMap VMap;
+    for (const auto &[Name, Val] : Opcodes)
+      VMap[Name] = Embedding(Dimension, Val);
+    return MIRVocabulary::create(std::move(VMap), *TII);
+  }
+};
 
 TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
   // Test that same base opcodes get same canonical indices
@@ -118,10 +132,8 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
 
   // Create a MIRVocabulary instance to test the mapping
   // Use a minimal MIRVocabulary to trigger canonical mapping construction
-  VocabMap VMap;
   Embedding Val = Embedding(64, 1.0f);
-  VMap["ADD"] = Val;
-  auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+  auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, 64);
   ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
       << "Failed to create vocabulary: "
       << toString(TestVocabOrErr.takeError());
@@ -156,16 +168,16 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
             6880u); // X86 has >6880 unique base opcodes
 
   // Check that the embeddings for opcodes not in the vocab are zero vectors
-  int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr");
+  int Add32rrOpcode = findOpcodeByName("ADD32rr");
   ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found";
   EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val));
 
-  int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr");
+  int Sub32rrOpcode = findOpcodeByName("SUB32rr");
   ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found";
   EXPECT_TRUE(
       TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
 
-  int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr");
+  int Mov32rrOpcode = findOpcodeByName("MOV32rr");
   ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found";
   EXPECT_TRUE(
       TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
@@ -178,9 +190,7 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
 
   // Create a MIRVocabulary instance to test deterministic mapping
   // Use a minimal MIRVocabulary to trigger canonical mapping construction
-  VocabMap VMap;
-  VMap["ADD"] = Embedding(64, 1.0f);
-  auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+  auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, 64);
   ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
       << "Failed to create vocabulary: "
       << toString(TestVocabOrErr.takeError());
@@ -189,8 +199,6 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
   unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
   unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
   unsigned Index3 = TestVocab.getCanonicalIndexForBaseName(BaseName);
-
-  EXPECT_EQ(Index1, Index2);
   EXPECT_EQ(Index2, Index3);
 
   // Test across multiple runs
@@ -202,11 +210,7 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
 
 // Test MIRVocabulary construction
 TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
-  VocabMap VMap;
-  VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
-  VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
-
-  auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+  auto VocabOrErr = createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, 128);
   ASSERT_TRUE(static_cast<bool>(VocabOrErr))
       << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
   auto &Vocab = *VocabOrErr;
@@ -243,4 +247,247 @@ TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
   }
 }
 
+// Fixture for embedding tests: adds a MachineFunction to the vocab fixture.
+class MIR2VecEmbeddingTestFixture : public MIR2VecVocabTestFixture {
+protected:
+  std::unique_ptr<MachineModuleInfo> MMI;
+  MachineFunction *MF = nullptr;
+
+  void SetUp() override {
+    MIR2VecVocabTestFixture::SetUp();
+    if (IsSkipped() || HasFatalFailure())
+      return; // Base fixture bailed out; there is nothing to build on.
+
+    // Create a dummy void() function to host the MachineFunction.
+    FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false);
+    Function *F =
+        Function::Create(FT, Function::ExternalLinkage, "test", M.get());
+    MMI = std::make_unique<MachineModuleInfo>(TM.get());
+    MF = &MMI->getOrCreateMachineFunction(*F);
+  }
+
+  void TearDown() override { MIR2VecVocabTestFixture::TearDown(); }
+
+  // Append Opcode to MBB without operands (embeddings are opcode-based).
+  MachineInstr *createMachineInstr(MachineBasicBlock &MBB, unsigned Opcode) {
+    return BuildMI(MBB, MBB.end(), DebugLoc(), TII->get(Opcode));
+  }
+
+  // As above, but looks the opcode up by name; returns nullptr if not found.
+  MachineInstr *createMachineInstr(MachineBasicBlock &MBB,
+                                   const char *OpcodeName) {
+    int Opcode = findOpcodeByName(OpcodeName);
+    if (Opcode == -1)
+      return nullptr;
+    return createMachineInstr(MBB, Opcode);
+  }
+
+  // Append one instruction per name. A failed ASSERT only returns from this
+  // helper, so callers should wrap the call in ASSERT_NO_FATAL_FAILURE.
+  void createMachineInstrs(MachineBasicBlock &MBB,
+                           std::initializer_list<const char *> Opcodes) {
+    for (const char *OpcodeName : Opcodes)
+      ASSERT_NE(createMachineInstr(MBB, OpcodeName), nullptr)
+          << "Opcode not found: " << OpcodeName;
+  }
+};
+
+// MIREmbedder::create should produce an embedder for the Symbolic kind.
+TEST_F(MIR2VecEmbeddingTestFixture, CreateSymbolicEmbedder) {
+  auto VocabExp = MIRVocabulary::createDummyVocabForTest(*TII, 1);
+  ASSERT_TRUE(static_cast<bool>(VocabExp))
+      << "Failed to create vocabulary: " << toString(VocabExp.takeError());
+
+  auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, *MF, *VocabExp);
+  EXPECT_NE(Emb, nullptr);
+}
+
+// MIREmbedder::create must not yield an embedder for an unknown MIR2VecKind.
+TEST_F(MIR2VecEmbeddingTestFixture, CreateInvalidMode) {
+  auto VocabExp = MIRVocabulary::createDummyVocabForTest(*TII, 1);
+  ASSERT_TRUE(static_cast<bool>(VocabExp))
+      << "Failed to create vocabulary: " << toString(VocabExp.takeError());
+  auto Emb = MIREmbedder::create(static_cast<MIR2VecKind>(-1), *MF, *VocabExp);
+  EXPECT_FALSE(static_cast<bool>(Emb));
+}
+
+// SymbolicMIREmbedder on a single block of non-pseudo X86 instructions.
+TEST_F(MIR2VecEmbeddingTestFixture, TestSymbolicEmbedder) {
+  // Four-dimensional vocabulary; each entry is a constant vector.
+  auto VocabExp = createTestVocab(
+      {
+          {"NOOP", 1.0f}, // [1.0, 1.0, 1.0, 1.0]
+          {"RET", 2.0f},  // [2.0, 2.0, 2.0, 2.0]
+          {"TRAP", 3.0f}  // [3.0, 3.0, 3.0, 3.0]
+      },
+      4);
+  ASSERT_TRUE(static_cast<bool>(VocabExp))
+      << "Failed to create vocabulary: " << toString(VocabExp.takeError());
+  auto &Vocab = *VocabExp;
+
+  // A single block in the fixture's function.
+  MachineBasicBlock *Block = MF->CreateMachineBasicBlock();
+  MF->push_back(Block);
+
+  // Real X86 opcodes that are expected to exist and not be pseudo.
+  MachineInstr *Noop = createMachineInstr(*Block, "NOOP");
+  ASSERT_TRUE(Noop != nullptr);
+
+  MachineInstr *Ret = createMachineInstr(*Block, "RET64");
+  ASSERT_TRUE(Ret != nullptr);
+
+  MachineInstr *Trap = createMachineInstr(*Block, "TRAP");
+  ASSERT_TRUE(Trap != nullptr);
+
+  // Guard the premise: none of the three is a pseudo instruction.
+  ASSERT_FALSE(Noop->isPseudo()) << "NOOP is marked as pseudo instruction";
+  ASSERT_FALSE(Ret->isPseudo()) << "RET is marked as pseudo instruction";
+  ASSERT_FALSE(Trap->isPseudo()) << "TRAP is marked as pseudo instruction";
+
+  // Build the embedder over the fixture's function.
+  auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+  ASSERT_TRUE(Embedder != nullptr);
+
+  // OpcWeight is the global, command-line controlled opcode weight.
+  const float W = mir2vec::OpcWeight;
+
+  // Each instruction vector is its vocabulary entry scaled by W.
+  auto NoopVec = Embedder->getMInstVector(*Noop);
+  auto RetVec = Embedder->getMInstVector(*Ret);
+  auto TrapVec = Embedder->getMInstVector(*Trap);
+  EXPECT_TRUE(NoopVec.approximatelyEquals(Embedding(4, 1.0f * W)));
+  EXPECT_TRUE(RetVec.approximatelyEquals(Embedding(4, 2.0f * W)));
+  EXPECT_TRUE(TrapVec.approximatelyEquals(Embedding(4, 3.0f * W)));
+
+  // The block vector is the sum of the instruction vectors:
+  //   NOOP + RET + TRAP = (1 + 2 + 3) * W = 6 * W
+  // in every component.
+  Embedding ExpectedBlockVec(4, 6.0f * W);
+  auto BlockVec = Embedder->getMBBVector(*Block);
+  EXPECT_TRUE(BlockVec.approximatelyEquals(ExpectedBlockVec));
+
+  // With only one block, the function vector must equal the block vector.
+  auto FuncVec = Embedder->getMFunctionVector();
+  EXPECT_TRUE(FuncVec.approximatelyEquals(ExpectedBlockVec));
+}
+
+// Test embedder with multiple basic blocks, before and after linking them.
+TEST_F(MIR2VecEmbeddingTestFixture, MultipleBasicBlocks) {
+  // Create a test vocabulary (2-D): NOOP -> [1, 1], TRAP -> [2, 2].
+  auto VocabOrErr = createTestVocab({{"NOOP", 1.0f}, {"TRAP", 2.0f}});
+  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+  auto &Vocab = *VocabOrErr;
+
+  // Two blocks in the fixture's MF, initially without a CFG edge.
+  MachineBasicBlock *MBB1 = MF->CreateMachineBasicBlock();
+  MachineBasicBlock *MBB2 = MF->CreateMachineBasicBlock();
+  MF->push_back(MBB1);
+  MF->push_back(MBB2);
+
+  // The helper's ASSERTs only return from the helper; stop the test too.
+  ASSERT_NO_FATAL_FAILURE(createMachineInstrs(*MBB1, {"NOOP", "NOOP"}));
+  ASSERT_NE(createMachineInstr(*MBB2, "TRAP"), nullptr) << "TRAP not found";
+
+  // Create embedder
+  auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+  ASSERT_TRUE(Embedder != nullptr);
+
+  // Test basic block embeddings
+  auto MBB1Vector = Embedder->getMBBVector(*MBB1);
+  auto MBB2Vector = Embedder->getMBBVector(*MBB2);
+
+  float ExpectedWeight = mir2vec::OpcWeight;
+  // BB1: NOOP + NOOP = 2 * ([1, 1] * weight)
+  Embedding ExpectedMBB1Vector(2, 2.0f * ExpectedWeight);
+  EXPECT_TRUE(MBB1Vector.approximatelyEquals(ExpectedMBB1Vector));
+
+  // BB2: TRAP = [2, 2] * weight
+  Embedding ExpectedMBB2Vector(2, 2.0f * ExpectedWeight);
+  EXPECT_TRUE(MBB2Vector.approximatelyEquals(ExpectedMBB2Vector));
+
+  // BB2 has no predecessor yet, so it is unreachable from the entry block
+  // and the function embedding is just BB1's embedding.
+  auto MFuncVector = Embedder->getMFunctionVector();
+  EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedMBB1Vector));
+
+  // Add the CFG edge BB1 -> BB2 so both blocks are reachable; the function
+  // embedding becomes BB1 + BB2 = [2+2, 2+2] * weight = [4, 4] * weight.
+  MBB1->addSuccessor(MBB2);
+  auto NewMFuncVector = Embedder->getMFunctionVector(); // Recompute embeddings
+  Embedding ExpectedFuncVector(2, 4.0f * ExpectedWeight);
+  EXPECT_TRUE(NewMFuncVector.approximatelyEquals(ExpectedFuncVector));
+}
+
+// An empty basic block, and a function made only of it, embed to zero.
+TEST_F(MIR2VecEmbeddingTestFixture, EmptyBasicBlock) {
+  auto VocabExp = MIRVocabulary::createDummyVocabForTest(*TII, 2);
+  ASSERT_TRUE(static_cast<bool>(VocabExp))
+      << "Failed to create vocabulary: " << toString(VocabExp.takeError());
+
+  // A block with no instructions at all.
+  MachineBasicBlock *EmptyBlock = MF->CreateMachineBasicBlock();
+  MF->push_back(EmptyBlock);
+
+  // Build the embedder over the fixture's function.
+  auto Embedder = SymbolicMIREmbedder::create(*MF, *VocabExp);
+  ASSERT_TRUE(Embedder != nullptr);
+
+  Embedding Zero(2, 0.0f);
+
+  // Nothing to sum, so the block vector is zero...
+  auto BlockVec = Embedder->getMBBVector(*EmptyBlock);
+  EXPECT_TRUE(BlockVec.approximatelyEquals(Zero));
+
+  // ...and so is the vector of the function containing only this block.
+  auto FuncVec = Embedder->getMFunctionVector();
+  EXPECT_TRUE(FuncVec.approximatelyEquals(Zero));
+}
+
+// Opcodes missing from the vocabulary must contribute a zero vector.
+TEST_F(MIR2VecEmbeddingTestFixture, UnknownOpcodes) {
+  // The vocabulary contains ADD only; SUB is intentionally left out so that
+  // SUB32rr exercises the unknown-opcode path.
+  auto VocabExp = createTestVocab({{"ADD", 1.0f}});
+  ASSERT_TRUE(static_cast<bool>(VocabExp))
+      << "Failed to create vocabulary: " << toString(VocabExp.takeError());
+  auto &Vocab = *VocabExp;
+
+  // Resolve both opcodes first, so that a missing one aborts the test
+  // before anything is built.
+  const int AddOpc = findOpcodeByName("ADD32rr");
+  const int SubOpc = findOpcodeByName("SUB32rr");
+
+  ASSERT_NE(AddOpc, -1) << "ADD32rr opcode not found";
+  ASSERT_NE(SubOpc, -1) << "SUB32rr opcode not found";
+
+  // One block holding ADD32rr followed by SUB32rr.
+  MachineBasicBlock *Block = MF->CreateMachineBasicBlock();
+  MF->push_back(Block);
+
+  // Operands are omitted; only the opcodes matter for these embeddings.
+  MachineInstr *Add = createMachineInstr(*Block, AddOpc);
+  MachineInstr *Sub = createMachineInstr(*Block, SubOpc);
+
+  // Build the embedder over the fixture's function.
+  auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+  ASSERT_TRUE(Embedder != nullptr);
+
+  // OpcWeight is the global, command-line controlled opcode weight.
+  const float W = mir2vec::OpcWeight;
+
+  // ADD picks up its vocabulary entry, scaled by the opcode weight.
+  auto AddVec = Embedder->getMInstVector(*Add);
+  EXPECT_TRUE(AddVec.approximatelyEquals(Embedding(2, 1.0f * W)));
+
+  // SUB is not in the vocabulary, so its vector is all zeros.
+  auto SubVec = Embedder->getMInstVector(*Sub);
+  EXPECT_TRUE(SubVec.approximatelyEquals(Embedding(2, 0.0f)));
+
+  // The block vector is the sum of its instruction vectors:
+  //   ADD + SUB = [1, 1] * W + [0, 0] = [1, 1] * W
+  const auto &BlockVec = Embedder->getMBBVector(*Block);
+  Embedding ExpectedBlockVec(2, 1.0f * W);
+  EXPECT_TRUE(BlockVec.approximatelyEquals(ExpectedBlockVec));
+}
 } // namespace



More information about the llvm-commits mailing list