[llvm] [MIR2Vec] Add embedder for machine instructions (PR #162161)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 7 16:46:21 PDT 2025


https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/162161

>From c1747fb1a5594af839d44dde6324605a848c1114 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Mon, 6 Oct 2025 21:15:14 +0000
Subject: [PATCH] MIR2Vec embedding

---
 llvm/include/llvm/CodeGen/MIR2Vec.h           | 108 ++++++
 llvm/include/llvm/CodeGen/Passes.h            |   4 +
 llvm/include/llvm/InitializePasses.h          |   1 +
 llvm/lib/CodeGen/CodeGen.cpp                  |   1 +
 llvm/lib/CodeGen/MIR2Vec.cpp                  | 195 ++++++++++-
 .../Inputs/mir2vec_dummy_3D_vocab.json        |  22 ++
 llvm/test/CodeGen/MIR2Vec/if-else.mir         | 144 ++++++++
 .../MIR2Vec/mir2vec-basic-symbolic.mir        |  76 ++++
 llvm/tools/llc/llc.cpp                        |  15 +
 llvm/unittests/CodeGen/MIR2VecTest.cpp        | 324 ++++++++++++++++--
 10 files changed, 863 insertions(+), 27 deletions(-)
 create mode 100644 llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json
 create mode 100644 llvm/test/CodeGen/MIR2Vec/if-else.mir
 create mode 100644 llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir

diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
index ea68b4594a2ad..ebafe4ccddff3 100644
--- a/llvm/include/llvm/CodeGen/MIR2Vec.h
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -51,11 +51,21 @@ class LLVMContext;
 class MIR2VecVocabLegacyAnalysis;
 class TargetInstrInfo;
 
+enum class MIR2VecKind { Symbolic };
+
 namespace mir2vec {
+
+// Forward declarations
+class MIREmbedder;
+class SymbolicMIREmbedder;
+
 extern llvm::cl::OptionCategory MIR2VecCategory;
 extern cl::opt<float> OpcWeight;
 
 using Embedding = ir2vec::Embedding;
+using MachineInstEmbeddingsMap = DenseMap<const MachineInstr *, Embedding>;
+using MachineBlockEmbeddingsMap =
+    DenseMap<const MachineBasicBlock *, Embedding>;
 
 /// Class for storing and accessing the MIR2Vec vocabulary.
 /// The MIRVocabulary class manages seed embeddings for LLVM Machine IR
@@ -132,6 +142,79 @@ class MIRVocabulary {
     assert(isValid() && "Invalid vocabulary");
     return Storage.size();
   }
+
+  /// Create a dummy vocabulary for testing purposes.
+  static MIRVocabulary createDummyVocabForTest(const TargetInstrInfo &TII,
+                                               unsigned Dim = 1);
+};
+
+/// Base class for MIR embedders
+class MIREmbedder {
+protected:
+  const MachineFunction &MF;
+  const MIRVocabulary &Vocab;
+
+  /// Dimension of the embeddings; captured from the vocabulary
+  const unsigned Dimension;
+
+  /// Weight for opcode embeddings
+  const float OpcWeight;
+
+  // Utility maps - these are used to store the vector representations of
+  // instructions, basic blocks and functions.
+  mutable Embedding MFuncVector;
+  mutable MachineBlockEmbeddingsMap MBBVecMap;
+  mutable MachineInstEmbeddingsMap MInstVecMap;
+
+  MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab);
+
+  /// Computes embeddings for all the instructions and basic blocks in MF and
+  /// accumulates the block embeddings into the function embedding.
+  void computeEmbeddings() const;
+
+  /// Computes (and caches) the embedding for the given basic block and its
+  /// instructions; specific to the kind of embeddings being computed.
+  virtual void computeEmbeddings(const MachineBasicBlock &MBB) const = 0;
+
+public:
+  virtual ~MIREmbedder() = default;
+
+  /// Factory method to create an Embedder object of the specified kind
+  /// Returns nullptr if the requested kind is not supported.
+  static std::unique_ptr<MIREmbedder> create(MIR2VecKind Mode,
+                                             const MachineFunction &MF,
+                                             const MIRVocabulary &Vocab);
+
+  /// Returns a map containing machine instructions and the corresponding
+  /// embeddings for the machine function MF if it has been computed. If not, it
+  /// computes the embeddings for MF and returns the map.
+  const MachineInstEmbeddingsMap &getMInstVecMap() const;
+
+  /// Returns a map containing machine basic block and the corresponding
+  /// embeddings for the machine function MF if it has been computed. If not, it
+  /// computes the embeddings for MF and returns the map.
+  const MachineBlockEmbeddingsMap &getMBBVecMap() const;
+
+  /// Returns the embedding for a given machine basic block in the machine
+  /// function MF if it has been computed. If not, it computes the embedding for
+  /// MBB and returns it.
+  const Embedding &getMBBVector(const MachineBasicBlock &MBB) const;
+
+  /// Computes and returns the embedding for the current machine function.
+  const Embedding &getMFunctionVector() const;
+};
+
+/// Class for computing Symbolic embeddings
+/// Symbolic embeddings are constructed based on the entity-level
+/// representations obtained from the MIR Vocabulary.
+class SymbolicMIREmbedder : public MIREmbedder {
+private:
+  void computeEmbeddings(const MachineBasicBlock &MBB) const override;
+
+public:
+  SymbolicMIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab);
+  static std::unique_ptr<SymbolicMIREmbedder>
+  create(const MachineFunction &MF, const MIRVocabulary &Vocab);
 };
 
 } // namespace mir2vec
@@ -181,6 +264,28 @@ class MIR2VecVocabPrinterLegacyPass : public MachineFunctionPass {
   }
 };
 
+/// This pass prints the MIR2Vec embeddings for machine functions, basic blocks,
+/// and instructions.
+class MIR2VecPrinterLegacyPass : public MachineFunctionPass {
+  raw_ostream &OS;
+
+public:
+  static char ID;
+  explicit MIR2VecPrinterLegacyPass(raw_ostream &OS)
+      : MachineFunctionPass(ID), OS(OS) {}
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<MIR2VecVocabLegacyAnalysis>();
+    AU.setPreservesAll();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  StringRef getPassName() const override {
+    return "MIR2Vec Embedder Printer Pass";
+  }
+};
+
 } // namespace llvm
 
 #endif // LLVM_CODEGEN_MIR2VEC_H
\ No newline at end of file
diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
index 272b4acf950c5..7fae550d8d170 100644
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -93,6 +93,10 @@ createMachineFunctionPrinterPass(raw_ostream &OS,
 LLVM_ABI MachineFunctionPass *
 createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS);
 
+/// MIR2VecPrinter pass - This pass prints out the MIR2Vec embeddings of
+/// machine functions, basic blocks and instructions to the given stream.
+LLVM_ABI MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
+
 /// StackFramePrinter pass - This pass prints out the machine function's
 /// stack frame to the given stream as a debugging tool.
 LLVM_ABI MachineFunctionPass *createStackFrameLayoutAnalysisPass();
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
index cd774e7888e64..d507ba267d791 100644
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -222,6 +222,7 @@ LLVM_ABI void
 initializeMachineSanitizerBinaryMetadataLegacyPass(PassRegistry &);
 LLVM_ABI void initializeMIR2VecVocabLegacyAnalysisPass(PassRegistry &);
 LLVM_ABI void initializeMIR2VecVocabPrinterLegacyPassPass(PassRegistry &);
+LLVM_ABI void initializeMIR2VecPrinterLegacyPassPass(PassRegistry &);
 LLVM_ABI void initializeMachineSchedulerLegacyPass(PassRegistry &);
 LLVM_ABI void initializeMachineSinkingLegacyPass(PassRegistry &);
 LLVM_ABI void initializeMachineTraceMetricsWrapperPassPass(PassRegistry &);
diff --git a/llvm/lib/CodeGen/CodeGen.cpp b/llvm/lib/CodeGen/CodeGen.cpp
index c438eaeb29d1e..9795a0b707fd3 100644
--- a/llvm/lib/CodeGen/CodeGen.cpp
+++ b/llvm/lib/CodeGen/CodeGen.cpp
@@ -98,6 +98,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) {
   initializeMachineUniformityAnalysisPassPass(Registry);
   initializeMIR2VecVocabLegacyAnalysisPass(Registry);
   initializeMIR2VecVocabPrinterLegacyPassPass(Registry);
+  initializeMIR2VecPrinterLegacyPassPass(Registry);
   initializeMachineUniformityInfoPrinterPassPass(Registry);
   initializeMachineVerifierLegacyPassPass(Registry);
   initializeObjCARCContractLegacyPassPass(Registry);
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index 87565c0c77115..ffd076331c7d2 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -41,11 +41,18 @@ static cl::opt<std::string>
 cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
                          cl::desc("Weight for machine opcode embeddings"),
                          cl::cat(MIR2VecCategory));
+cl::opt<MIR2VecKind> MIR2VecEmbeddingKind(
+    "mir2vec-kind", cl::Optional,
+    cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic",
+                          "Generate symbolic embeddings for MIR")),
+    cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"),
+    cl::cat(MIR2VecCategory));
+
 } // namespace mir2vec
 } // namespace llvm
 
 //===----------------------------------------------------------------------===//
-// Vocabulary Implementation
+// Vocabulary
 //===----------------------------------------------------------------------===//
 
 MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
@@ -190,6 +197,29 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() {
                     << " unique base opcodes\n");
 }
 
+MIRVocabulary MIRVocabulary::createDummyVocabForTest(const TargetInstrInfo &TII,
+                                                     unsigned Dim) {
+  assert(Dim > 0 && "Dimension must be greater than zero");
+
+  float DummyVal = 0.1f;
+
+  // Create a temporary vocabulary instance to build canonical mapping
+  MIRVocabulary TempVocab({}, &TII);
+  TempVocab.buildCanonicalOpcodeMapping();
+
+  // Create dummy embeddings for all canonical opcode names
+  VocabMap DummyVocabMap;
+  for (const auto &COpcodeName : TempVocab.UniqueBaseOpcodeNames) {
+    // Create a dummy embedding filled with DummyVal (moved into the map)
+    Embedding DummyEmbedding(Dim, DummyVal);
+    DummyVocabMap[COpcodeName] = std::move(DummyEmbedding);
+    DummyVal += 0.1f;
+  }
+
+  // Create and return vocabulary with dummy embeddings
+  return MIRVocabulary(std::move(DummyVocabMap), &TII);
+}
+
 //===----------------------------------------------------------------------===//
 // MIR2VecVocabLegacyAnalysis Implementation
 //===----------------------------------------------------------------------===//
@@ -267,7 +297,112 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
 }
 
 //===----------------------------------------------------------------------===//
-// Printer Passes Implementation
+// MIREmbedder and its subclasses
+//===----------------------------------------------------------------------===//
+
+MIREmbedder::MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
+    : MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()),
+      OpcWeight(::OpcWeight), MFuncVector(Embedding(Dimension)) {}
+
+std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode,
+                                                 const MachineFunction &MF,
+                                                 const MIRVocabulary &Vocab) {
+  switch (Mode) {
+  case MIR2VecKind::Symbolic:
+    return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
+  }
+  return nullptr;
+}
+
+const MachineInstEmbeddingsMap &MIREmbedder::getMInstVecMap() const {
+  // Instruction embeddings are filled in together with their block's
+  // embedding, and getMBBVector() may have embedded only some blocks, so the
+  // map is complete only once every block of MF has an embedding.
+  if (MBBVecMap.size() != MF.size())
+    computeEmbeddings();
+  return MInstVecMap;
+}
+
+const MachineBlockEmbeddingsMap &MIREmbedder::getMBBVecMap() const {
+  // getMBBVector() may have embedded only some blocks; an emptiness check
+  // would then hand out a partial map.
+  if (MBBVecMap.size() != MF.size())
+    computeEmbeddings();
+  return MBBVecMap;
+}
+
+const Embedding &MIREmbedder::getMBBVector(const MachineBasicBlock &BB) const {
+  auto It = MBBVecMap.find(&BB);
+  if (It != MBBVecMap.end())
+    return It->second;
+  computeEmbeddings(BB);
+  return MBBVecMap[&BB];
+}
+
+const Embedding &MIREmbedder::getMFunctionVector() const {
+  // The function vector is always (re)computed from scratch; this also
+  // refreshes every block and instruction embedding.
+  computeEmbeddings();
+  return MFuncVector;
+}
+
+void MIREmbedder::computeEmbeddings() const {
+  // Reset function vector to zero before recomputing
+  MFuncVector = Embedding(Dimension, 0.0);
+
+  // Consider all machine basic blocks in the function
+  for (const auto &MBB : MF) {
+    computeEmbeddings(MBB);
+    MFuncVector += MBBVecMap[&MBB];
+  }
+}
+
+SymbolicMIREmbedder::SymbolicMIREmbedder(const MachineFunction &MF,
+                                         const MIRVocabulary &Vocab)
+    : MIREmbedder(MF, Vocab) {}
+
+std::unique_ptr<SymbolicMIREmbedder>
+SymbolicMIREmbedder::create(const MachineFunction &MF,
+                            const MIRVocabulary &Vocab) {
+  return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
+}
+
+void SymbolicMIREmbedder::computeEmbeddings(
+    const MachineBasicBlock &MBB) const {
+  Embedding MBBVector(Dimension, 0);
+
+  // Bail out early if the subtarget provides no instruction info
+  const auto &Subtarget = MF.getSubtarget();
+  const auto *TII = Subtarget.getInstrInfo();
+  if (!TII) {
+    MF.getFunction().getContext().emitError(
+        "MIR2Vec: No TargetInstrInfo available; cannot compute embeddings");
+    // Still record a correctly sized (all-zero) block vector: callers such as
+    // computeEmbeddings() would otherwise read a default-constructed entry.
+    MBBVecMap[&MBB] = std::move(MBBVector);
+    return;
+  }
+
+  // Process each machine instruction in the basic block
+  for (const auto &MI : MBB) {
+    // Skip debug instructions and other metadata
+    if (MI.isDebugInstr())
+      continue;
+
+    // TODO: Add operand/argument contributions
+
+    // Store the instruction embedding and add it to the block embedding
+    const Embedding &InstVector = Vocab[MI.getOpcode()];
+    MInstVecMap[&MI] = InstVector;
+    MBBVector += InstVector;
+  }
+
+  // Store the basic block embedding
+  MBBVecMap[&MBB] = std::move(MBBVector);
+}
+
+//===----------------------------------------------------------------------===//
+// Printer Passes
 //===----------------------------------------------------------------------===//
 
 char MIR2VecVocabPrinterLegacyPass::ID = 0;
@@ -304,3 +439,67 @@ MachineFunctionPass *
 llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) {
   return new MIR2VecVocabPrinterLegacyPass(OS);
 }
+
+char MIR2VecPrinterLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(MIR2VecPrinterLegacyPass, "print-mir2vec",
+                      "MIR2Vec Embedder Printer Pass", false, true)
+INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
+INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
+INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, "print-mir2vec",
+                    "MIR2Vec Embedder Printer Pass", false, true)
+
+bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
+  auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
+  auto MIRVocab = Analysis.getMIR2VecVocabulary(*MF.getFunction().getParent());
+
+  if (!MIRVocab.isValid()) {
+    OS << "MIR2Vec Embedder Printer: Invalid vocabulary for function "
+       << MF.getName() << "\n";
+    return false;
+  }
+
+  auto Emb = mir2vec::MIREmbedder::create(MIR2VecEmbeddingKind, MF, MIRVocab);
+  if (!Emb) {
+    OS << "Error creating MIR2Vec embeddings for function " << MF.getName()
+       << "\n";
+    return false;
+  }
+
+  OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+  OS << "Machine Function vector: ";
+  Emb->getMFunctionVector().print(OS);
+
+  OS << "Machine basic block vectors:\n";
+  const auto &MBBMap = Emb->getMBBVecMap();
+  for (const MachineBasicBlock &MBB : MF) {
+    auto It = MBBMap.find(&MBB);
+    if (It != MBBMap.end()) {
+      OS << "Machine basic block: " << MBB.getFullName() << ":\n";
+      It->second.print(OS);
+    }
+  }
+
+  OS << "Machine instruction vectors:\n";
+  const auto &MInstMap = Emb->getMInstVecMap();
+  for (const MachineBasicBlock &MBB : MF) {
+    for (const MachineInstr &MI : MBB) {
+      // Skip debug instructions as they are not
+      // embedded
+      if (MI.isDebugInstr())
+        continue;
+
+      auto It = MInstMap.find(&MI);
+      if (It != MInstMap.end()) {
+        OS << "Machine instruction: ";
+        MI.print(OS);
+        It->second.print(OS);
+      }
+    }
+  }
+
+  return false;
+}
+
+MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass(raw_ostream &OS) {
+  return new MIR2VecPrinterLegacyPass(OS);
+}
diff --git a/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json b/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json
new file mode 100644
index 0000000000000..5de715bf80917
--- /dev/null
+++ b/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json
@@ -0,0 +1,22 @@
+{
+  "entities": {
+    "KILL": [0.1, 0.2, 0.3],
+    "MOV": [0.4, 0.5, 0.6],
+    "LEA": [0.7, 0.8, 0.9],
+    "RET": [1.0, 1.1, 1.2],
+    "ADD": [1.3, 1.4, 1.5],
+    "SUB": [1.6, 1.7, 1.8],
+    "IMUL": [1.9, 2.0, 2.1],
+    "AND": [2.2, 2.3, 2.4],
+    "OR": [2.5, 2.6, 2.7],
+    "XOR": [2.8, 2.9, 3.0],
+    "CMP": [3.1, 3.2, 3.3],
+    "TEST": [3.4, 3.5, 3.6],
+    "JMP": [3.7, 3.8, 3.9],
+    "CALL": [4.0, 4.1, 4.2],
+    "PUSH": [4.3, 4.4, 4.5],
+    "POP": [4.6, 4.7, 4.8],
+    "NOP": [4.9, 5.0, 5.1],
+    "COPY": [5.2, 5.3, 5.4]
+  }
+}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/MIR2Vec/if-else.mir b/llvm/test/CodeGen/MIR2Vec/if-else.mir
new file mode 100644
index 0000000000000..2accf476f7c4d
--- /dev/null
+++ b/llvm/test/CodeGen/MIR2Vec/if-else.mir
@@ -0,0 +1,146 @@
+# REQUIRES: x86-registered-target
+# RUN: llc -mtriple=x86_64-unknown-linux-gnu -run-pass=none -print-mir2vec -mir2vec-vocab-path=%S/Inputs/mir2vec_dummy_3D_vocab.json %s -o /dev/null 2>&1 | FileCheck %s
+
+--- |
+  target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+
+  define dso_local i32 @abc(i32 noundef %a, i32 noundef %b) {
+  entry:
+    %retval = alloca i32, align 4
+    %a.addr = alloca i32, align 4
+    %b.addr = alloca i32, align 4
+    store i32 %a, ptr %a.addr, align 4
+    store i32 %b, ptr %b.addr, align 4
+    %0 = load i32, ptr %a.addr, align 4
+    %1 = load i32, ptr %b.addr, align 4
+    %cmp = icmp sgt i32 %0, %1
+    br i1 %cmp, label %if.then, label %if.else
+
+  if.then:                                          ; preds = %entry
+    %2 = load i32, ptr %b.addr, align 4
+    store i32 %2, ptr %retval, align 4
+    br label %return
+
+  if.else:                                          ; preds = %entry
+    %3 = load i32, ptr %a.addr, align 4
+    store i32 %3, ptr %retval, align 4
+    br label %return
+
+  return:                                           ; preds = %if.else, %if.then
+    %4 = load i32, ptr %retval, align 4
+    ret i32 %4
+  }
+...
+---
+name:            abc
+alignment:       16
+exposesReturnsTwice: false
+legalized:       false
+regBankSelected: false
+selected:        false
+failedISel:      false
+tracksRegLiveness: true
+hasWinCFI:       false
+noPhis:          false
+isSSA:           true
+noVRegs:         false
+hasFakeUses:     false
+callsEHReturn:   false
+callsUnwindInit: false
+hasEHContTarget: false
+hasEHScopes:     false
+hasEHFunclets:   false
+isOutlined:      false
+debugInstrRef:   true
+failsVerification: false
+tracksDebugUserValues: false
+registers:
+  - { id: 0, class: gr32, preferred-register: '', flags: [  ] }
+  - { id: 1, class: gr32, preferred-register: '', flags: [  ] }
+  - { id: 2, class: gr32, preferred-register: '', flags: [  ] }
+  - { id: 3, class: gr32, preferred-register: '', flags: [  ] }
+  - { id: 4, class: gr32, preferred-register: '', flags: [  ] }
+  - { id: 5, class: gr32, preferred-register: '', flags: [  ] }
+liveins:
+  - { reg: '$edi', virtual-reg: '%0' }
+  - { reg: '$esi', virtual-reg: '%1' }
+frameInfo:
+  isFrameAddressTaken: false
+  isReturnAddressTaken: false
+  hasStackMap:     false
+  hasPatchPoint:   false
+  stackSize:       0
+  offsetAdjustment: 0
+  maxAlignment:    4
+  adjustsStack:    false
+  hasCalls:        false
+  stackProtector:  ''
+  functionContext: ''
+  maxCallFrameSize: 4294967295
+  cvBytesOfCalleeSavedRegisters: 0
+  hasOpaqueSPAdjustment: false
+  hasVAStart:      false
+  hasMustTailInVarArgFunc: false
+  hasTailCall:     false
+  isCalleeSavedInfoValid: false
+  localFrameSize:  0
+fixedStack:      []
+stack:
+  - { id: 0, name: retval, type: default, offset: 0, size: 4, alignment: 4,
+      stack-id: default, callee-saved-register: '', callee-saved-restored: true,
+      debug-info-variable: '', debug-info-expression: '', debug-info-location: '' }
+  - { id: 1, name: a.addr, type: default, offset: 0, size: 4, alignment: 4,
+      stack-id: default, callee-saved-register: '', callee-saved-restored: true,
+      debug-info-variable: '', debug-info-expression: '', debug-info-location: '' }
+  - { id: 2, name: b.addr, type: default, offset: 0, size: 4, alignment: 4,
+      stack-id: default, callee-saved-register: '', callee-saved-restored: true,
+      debug-info-variable: '', debug-info-expression: '', debug-info-location: '' }
+entry_values:    []
+callSites:       []
+debugValueSubstitutions: []
+constants:       []
+machineFunctionInfo:
+  amxProgModel:    None
+body:             |
+  bb.0.entry:
+    successors: %bb.1(0x40000000), %bb.2(0x40000000)
+    liveins: $edi, $esi
+
+    %1:gr32 = COPY $esi
+    %0:gr32 = COPY $edi
+    MOV32mr %stack.1.a.addr, 1, $noreg, 0, $noreg, %0 :: (store (s32) into %ir.a.addr)
+    MOV32mr %stack.2.b.addr, 1, $noreg, 0, $noreg, %1 :: (store (s32) into %ir.b.addr)
+    %2:gr32 = SUB32rr %0, %1, implicit-def $eflags
+    JCC_1 %bb.2, 14, implicit $eflags
+    JMP_1 %bb.1
+
+  bb.1.if.then:
+    successors: %bb.3(0x80000000)
+
+    %4:gr32 = MOV32rm %stack.2.b.addr, 1, $noreg, 0, $noreg :: (dereferenceable load (s32) from %ir.b.addr)
+    MOV32mr %stack.0.retval, 1, $noreg, 0, $noreg, killed %4 :: (store (s32) into %ir.retval)
+    JMP_1 %bb.3
+
+  bb.2.if.else:
+    successors: %bb.3(0x80000000)
+
+    %3:gr32 = MOV32rm %stack.1.a.addr, 1, $noreg, 0, $noreg :: (dereferenceable load (s32) from %ir.a.addr)
+    MOV32mr %stack.0.retval, 1, $noreg, 0, $noreg, killed %3 :: (store (s32) into %ir.retval)
+
+  bb.3.return:
+    %5:gr32 = MOV32rm %stack.0.retval, 1, $noreg, 0, $noreg :: (dereferenceable load (s32) from %ir.retval)
+    $eax = COPY %5
+    RET 0, $eax
+...
+
+# CHECK: MIR2Vec embeddings for machine function abc:
+# CHECK-NEXT: Machine Function vector:  [ 28.40  29.80  31.20 ]
+# CHECK-NEXT: Machine basic block vectors:
+# CHECK-NEXT: Machine basic block: abc:entry:
+# CHECK-NEXT:  [ 16.50  17.10  17.70 ]
+# CHECK-NEXT: Machine basic block: abc:if.then:
+# CHECK-NEXT:  [ 4.50  4.80  5.10 ]
+# CHECK-NEXT: Machine basic block: abc:if.else:
+# CHECK-NEXT:  [ 0.80  1.00  1.20 ]
+# CHECK-NEXT: Machine basic block: abc:return:
+# CHECK-NEXT:  [ 6.60  6.90  7.20 ]
diff --git a/llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir b/llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir
new file mode 100644
index 0000000000000..44240affb2206
--- /dev/null
+++ b/llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir
@@ -0,0 +1,76 @@
+# REQUIRES: x86-registered-target
+# RUN: llc -mtriple=x86_64-unknown-linux-gnu -run-pass=none -print-mir2vec -mir2vec-vocab-path=%S/Inputs/mir2vec_dummy_3D_vocab.json %s -o /dev/null 2>&1 | FileCheck %s
+
+--- |
+  target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+
+  define dso_local noundef i32 @add_function(i32 noundef %a, i32 noundef %b) {
+  entry:
+    %sum = add nsw i32 %a, %b
+    %result = mul nsw i32 %sum, 2
+    ret i32 %result
+  }
+
+  define dso_local void @simple_function() {
+  entry:
+    ret void
+  }
+...
+---
+name:            add_function
+alignment:       16
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: gr32 }
+  - { id: 1, class: gr32 }
+  - { id: 2, class: gr32 }
+  - { id: 3, class: gr32 }
+liveins:
+  - { reg: '$edi', virtual-reg: '%0' }
+  - { reg: '$esi', virtual-reg: '%1' }
+body:             |
+  bb.0.entry:
+    liveins: $edi, $esi
+
+    %1:gr32 = COPY $esi
+    %0:gr32 = COPY $edi
+    %2:gr32 = nsw ADD32rr %0, %1, implicit-def dead $eflags
+    %3:gr32 = ADD32rr %2, %2, implicit-def dead $eflags
+    $eax = COPY %3
+    RET 0, $eax
+
+---
+name:            simple_function
+alignment:       16
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    RET 0
+
+# CHECK: MIR2Vec embeddings for machine function add_function:
+# CHECK-NEXT: Machine Function vector:  [ 19.20  19.80  20.40 ]
+# CHECK-NEXT: Machine basic block vectors:
+# CHECK-NEXT: Machine basic block: add_function:entry:
+# CHECK-NEXT:  [ 19.20  19.80  20.40 ]
+# CHECK-NEXT: Machine instruction vectors:
+# CHECK-NEXT: Machine instruction: %1:gr32 = COPY $esi
+# CHECK-NEXT:  [ 5.20  5.30  5.40 ]
+# CHECK-NEXT: Machine instruction: %0:gr32 = COPY $edi
+# CHECK-NEXT:  [ 5.20  5.30  5.40 ]
+# CHECK-NEXT: Machine instruction: %2:gr32 = nsw ADD32rr %0:gr32(tied-def 0), %1:gr32, implicit-def dead $eflags
+# CHECK-NEXT:  [ 1.30  1.40  1.50 ]
+# CHECK-NEXT: Machine instruction: %3:gr32 = ADD32rr %2:gr32(tied-def 0), %2:gr32, implicit-def dead $eflags
+# CHECK-NEXT:  [ 1.30  1.40  1.50 ]
+# CHECK-NEXT: Machine instruction: $eax = COPY %3:gr32
+# CHECK-NEXT:  [ 5.20  5.30  5.40 ]
+# CHECK-NEXT: Machine instruction: RET 0, $eax
+# CHECK-NEXT:  [ 1.00  1.10  1.20 ]
+
+# CHECK: MIR2Vec embeddings for machine function simple_function:
+# CHECK-NEXT: Machine Function vector:  [ 1.00  1.10  1.20 ]
+# CHECK-NEXT: Machine basic block vectors:
+# CHECK-NEXT: Machine basic block: simple_function:entry:
+# CHECK-NEXT:  [ 1.00  1.10  1.20 ]
+# CHECK-NEXT: Machine instruction vectors:
+# CHECK-NEXT: Machine instruction: RET 0
+# CHECK-NEXT:  [ 1.00  1.10  1.20 ]
diff --git a/llvm/tools/llc/llc.cpp b/llvm/tools/llc/llc.cpp
index 7551a80fbd5d8..d85b632fa7c3d 100644
--- a/llvm/tools/llc/llc.cpp
+++ b/llvm/tools/llc/llc.cpp
@@ -171,6 +171,11 @@ static cl::opt<bool>
                       cl::desc("Print MIR2Vec vocabulary contents"),
                       cl::init(false));
 
+static cl::opt<bool>
+    PrintMIR2Vec("print-mir2vec", cl::Hidden,
+                 cl::desc("Print MIR2Vec embeddings for functions"),
+                 cl::init(false));
+
 static cl::list<std::string> IncludeDirs("I", cl::desc("include search path"));
 
 static cl::opt<bool> RemarksWithHotness(
@@ -736,6 +741,11 @@ static int compileModule(char **argv, LLVMContext &Context) {
         PM.add(createMIR2VecVocabPrinterLegacyPass(errs()));
       }
 
+      // Add MIR2Vec printer if requested
+      if (PrintMIR2Vec) {
+        PM.add(createMIR2VecPrinterLegacyPass(errs()));
+      }
+
       PM.add(createFreeMachineFunctionPass());
     } else {
       if (Target->addPassesToEmitFile(PM, *OS, DwoOut ? &DwoOut->os() : nullptr,
@@ -749,6 +759,11 @@ static int compileModule(char **argv, LLVMContext &Context) {
       if (PrintMIR2VecVocab) {
         PM.add(createMIR2VecVocabPrinterLegacyPass(errs()));
       }
+
+      // Add MIR2Vec printer if requested
+      if (PrintMIR2Vec) {
+        PM.add(createMIR2VecPrinterLegacyPass(errs()));
+      }
     }
 
     Target->getObjFileLowering()->Initialize(MMIWP->getMMI().getContext(),
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index d243d82c73fc7..c8aef8e7aba5f 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -81,6 +81,9 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
       return;
     }
 
+    // Set the data layout to match the target machine
+    M->setDataLayout(TM->createDataLayout());
+
     // Create a dummy function to get subtarget info
     FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false);
     Function *F =
@@ -93,16 +96,32 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
       return;
     }
   }
-};
 
-// Function to find an opcode by name
-static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) {
-  for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
-    if (TII->getName(Opcode) == Name)
-      return Opcode;
+  // Find an opcode by name
+  int findOpcodeByName(StringRef Name) {
+    for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
+      if (TII->getName(Opcode) == Name)
+        return Opcode;
+    }
+    return -1; // Not found
   }
-  return -1; // Not found
-}
+
+  // Create a vocabulary with specific opcodes and embeddings
+  MIRVocabulary
+  createTestVocab(std::initializer_list<std::pair<const char *, float>> opcodes,
+                  unsigned dimension = 2) {
+    VocabMap VMap;
+    for (const auto &[name, value] : opcodes)
+      VMap[name] = Embedding(dimension, value);
+    return MIRVocabulary(std::move(VMap), TII);
+  }
+
+  // Create empty/invalid vocabulary
+  MIRVocabulary createEmptyVocab() {
+    VocabMap EmptyVMap;
+    return MIRVocabulary(std::move(EmptyVMap), TII);
+  }
+};
 
 TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
   // Test that same base opcodes get same canonical indices
@@ -115,10 +134,8 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
 
   // Create a MIRVocabulary instance to test the mapping
   // Use a minimal MIRVocabulary to trigger canonical mapping construction
-  VocabMap VMap;
   Embedding Val = Embedding(64, 1.0f);
-  VMap["ADD"] = Val;
-  MIRVocabulary TestVocab(std::move(VMap), TII);
+  auto TestVocab = createTestVocab({{"ADD", 1.0f}}, 64);
 
   unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
   unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
@@ -149,16 +166,16 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
             6880u); // X86 has >6880 unique base opcodes
 
   // Check that the embeddings for opcodes not in the vocab are zero vectors
-  int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr");
+  int Add32rrOpcode = findOpcodeByName("ADD32rr");
   ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found";
   EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val));
 
-  int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr");
+  int Sub32rrOpcode = findOpcodeByName("SUB32rr");
   ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found";
   EXPECT_TRUE(
       TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
 
-  int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr");
+  int Mov32rrOpcode = findOpcodeByName("MOV32rr");
   ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found";
   EXPECT_TRUE(
       TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
@@ -171,15 +188,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
 
   // Create a MIRVocabulary instance to test deterministic mapping
   // Use a minimal MIRVocabulary to trigger canonical mapping construction
-  VocabMap VMap;
-  VMap["ADD"] = Embedding(64, 1.0f);
-  MIRVocabulary TestVocab(std::move(VMap), TII);
+  auto TestVocab = createTestVocab({{"ADD", 1.0f}}, 64);
 
   unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
   unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
   unsigned Index3 = TestVocab.getCanonicalIndexForBaseName(BaseName);
-
-  EXPECT_EQ(Index1, Index2);
   EXPECT_EQ(Index2, Index3);
 
   // Test across multiple runs
@@ -191,11 +204,8 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
 
 // Test MIRVocabulary construction
 TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
-  VocabMap VMap;
-  VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
-  VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
-
-  MIRVocabulary Vocab(std::move(VMap), TII);
+  // Test MIRVocabulary with embeddings via VocabMap
+  auto Vocab = createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, 128);
   EXPECT_TRUE(Vocab.isValid());
   EXPECT_EQ(Vocab.getDimension(), 128u);
 
@@ -214,4 +224,268 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
   EXPECT_GT(Count, 0u);
 }
 
-} // namespace
\ No newline at end of file
+// An empty vocabulary must report itself as invalid, with zero dimension.
+TEST_F(MIR2VecVocabTestFixture, InvalidVocabulary) {
+  // No entries at all: there is nothing to build embeddings from.
+  auto EmptyVocab = createEmptyVocab();
+
+  // Expect it to be flagged invalid and to report a zero dimension.
+  EXPECT_FALSE(EmptyVocab.isValid());
+  EXPECT_EQ(EmptyVocab.getDimension(), 0u);
+}
+
+// Fixture for embedding tests. SetUp() creates a MachineFunction for an IR
+// function named "test"; each test adds basic blocks and instructions to it.
+class MIR2VecEmbeddingTestFixture : public MIR2VecVocabTestFixture {
+protected:
+  std::unique_ptr<MachineModuleInfo> MMI;
+  MachineFunction *MF = nullptr; // Created through MMI in SetUp().
+
+  void SetUp() override {
+    MIR2VecVocabTestFixture::SetUp();
+
+    // A void() IR function is enough to host the MachineFunction.
+    FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false);
+    Function *F =
+        Function::Create(FT, Function::ExternalLinkage, "test", M.get());
+
+    MMI = std::make_unique<MachineModuleInfo>(TM.get());
+    MF = &MMI->getOrCreateMachineFunction(*F);
+  }
+
+  // Append an operand-less instruction with the given opcode to the end of
+  // MBB. Operands are omitted since they don't affect opcode-based embeddings.
+  MachineInstr *createMachineInstr(MachineBasicBlock &MBB, unsigned Opcode) {
+    return BuildMI(MBB, MBB.end(), DebugLoc(), TII->get(Opcode));
+  }
+
+  // Same, but looks the opcode up by name; returns nullptr if it is unknown.
+  MachineInstr *createMachineInstr(MachineBasicBlock &MBB,
+                                   const char *OpcodeName) {
+    int Opcode = findOpcodeByName(OpcodeName);
+    if (Opcode == -1)
+      return nullptr;
+    return createMachineInstr(MBB, Opcode);
+  }
+
+  // Append one instruction per name, in order. An unknown name is a fatal
+  // failure local to this helper; wrap calls in ASSERT_NO_FATAL_FAILURE().
+  void createMachineInstrs(MachineBasicBlock &MBB,
+                           std::initializer_list<const char *> Opcodes) {
+    for (const char *OpcodeName : Opcodes) {
+      ASSERT_NE(createMachineInstr(MBB, OpcodeName), nullptr)
+          << "opcode not found: " << OpcodeName;
+    }
+  }
+};
+
+// MIREmbedder::create() must build an embedder for the Symbolic kind.
+TEST_F(MIR2VecEmbeddingTestFixture, CreateSymbolicEmbedder) {
+  auto DummyVocab = MIRVocabulary::createDummyVocabForTest(*TII, 1);
+  auto Embedder = MIREmbedder::create(MIR2VecKind::Symbolic, *MF, DummyVocab);
+  EXPECT_NE(Embedder, nullptr);
+}
+
+TEST_F(MIR2VecEmbeddingTestFixture, CreateInvalidMode) {
+  // -1 is not a MIR2VecKind enumerator, so the factory must refuse it.
+  const auto BogusKind = static_cast<MIR2VecKind>(-1);
+  auto DummyVocab = MIRVocabulary::createDummyVocabForTest(*TII, 1);
+  auto Result = MIREmbedder::create(BogusKind, *MF, DummyVocab);
+  EXPECT_FALSE(static_cast<bool>(Result));
+}
+
+// SymbolicMIREmbedder on a single block holding NOOP, RET64 and TRAP:
+// checks instruction, block and function vectors against hand-computed sums.
+TEST_F(MIR2VecEmbeddingTestFixture, TestSymbolicEmbedder) {
+  // Every vocabulary entry is its scalar repeated across all Dim lanes.
+  // RET64 is expected to resolve to the "RET" entry.
+  constexpr unsigned Dim = 4;
+  auto Vocab = createTestVocab({{"NOOP", 1.0f},  // [1, 1, 1, 1]
+                                {"RET", 2.0f},   // [2, 2, 2, 2]
+                                {"TRAP", 3.0f}}, // [3, 3, 3, 3]
+                               Dim);
+
+  // A single block inside the fixture's MF.
+  MachineBasicBlock *Block = MF->CreateMachineBasicBlock();
+  MF->push_back(Block);
+
+  // Real X86 opcodes; createMachineInstr yields nullptr for unknown names.
+  MachineInstr *Noop = createMachineInstr(*Block, "NOOP");
+  ASSERT_TRUE(Noop != nullptr);
+
+  MachineInstr *Ret = createMachineInstr(*Block, "RET64");
+  ASSERT_TRUE(Ret != nullptr);
+
+  MachineInstr *Trap = createMachineInstr(*Block, "TRAP");
+  ASSERT_TRUE(Trap != nullptr);
+
+  ASSERT_FALSE(Noop->isPseudo()) << "NOOP is marked as pseudo instruction";
+  ASSERT_FALSE(Ret->isPseudo()) << "RET is marked as pseudo instruction";
+  ASSERT_FALSE(Trap->isPseudo()) << "TRAP is marked as pseudo instruction";
+
+  // Embed the function.
+  auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+  ASSERT_TRUE(Embedder != nullptr);
+
+  // Each instruction vector is expected to be its vocabulary entry scaled by
+  // the OpcWeight option.
+  const float Weight = ::OpcWeight;
+
+  // --- Instruction level: one vector per instruction ---
+  const auto &InstVecs = Embedder->getMInstVecMap();
+  EXPECT_EQ(InstVecs.size(), 3u);
+
+  const auto NoopIt = InstVecs.find(Noop);
+  const auto RetIt = InstVecs.find(Ret);
+  const auto TrapIt = InstVecs.find(Trap);
+
+  ASSERT_NE(NoopIt, InstVecs.end());
+  ASSERT_NE(RetIt, InstVecs.end());
+  ASSERT_NE(TrapIt, InstVecs.end());
+
+  EXPECT_TRUE(
+      NoopIt->second.approximatelyEquals(Embedding(Dim, 1.0f * Weight)));
+  EXPECT_TRUE(
+      RetIt->second.approximatelyEquals(Embedding(Dim, 2.0f * Weight)));
+  EXPECT_TRUE(
+      TrapIt->second.approximatelyEquals(Embedding(Dim, 3.0f * Weight)));
+
+  // --- Block level: sum of the instruction vectors ---
+  const auto &BlockVecs = Embedder->getMBBVecMap();
+  EXPECT_EQ(BlockVecs.size(), 1u);
+
+  const auto BlockIt = BlockVecs.find(Block);
+  ASSERT_NE(BlockIt, BlockVecs.end());
+
+  // NOOP + RET + TRAP = (1 + 2 + 3) * Weight = 6 * Weight in every lane.
+  const Embedding ExpectedBlockVec(Dim, 6.0f * Weight);
+  EXPECT_TRUE(BlockIt->second.approximatelyEquals(ExpectedBlockVec));
+
+  // --- Function level ---
+  // With only one block, the function vector must equal the block vector.
+  const Embedding &FuncVec = Embedder->getMFunctionVector();
+  EXPECT_TRUE(FuncVec.approximatelyEquals(ExpectedBlockVec));
+}
+
+// Embedder over two blocks: per-block sums and their function-level total.
+TEST_F(MIR2VecEmbeddingTestFixture, MultipleBasicBlocks) {
+  // TRAP != 2 * NOOP so a mix-up of the two block vectors would be caught.
+  auto Vocab = createTestVocab({{"NOOP", 1.0f}, {"TRAP", 3.0f}});
+
+  // Create two basic blocks using fixture's MF
+  MachineBasicBlock *MBB1 = MF->CreateMachineBasicBlock();
+  MachineBasicBlock *MBB2 = MF->CreateMachineBasicBlock();
+  MF->push_back(MBB1);
+  MF->push_back(MBB2);
+
+  // A fatal failure inside a helper only returns from the helper; propagate.
+  ASSERT_NO_FATAL_FAILURE(createMachineInstrs(*MBB1, {"NOOP", "NOOP"}));
+  ASSERT_NE(createMachineInstr(*MBB2, "TRAP"), nullptr) << "TRAP not found";
+
+  auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+  ASSERT_TRUE(Embedder != nullptr);
+
+  // Test basic block embeddings
+  const auto &MBBMap = Embedder->getMBBVecMap();
+  EXPECT_EQ(MBBMap.size(), 2u);
+
+  auto MBB1It = MBBMap.find(MBB1);
+  auto MBB2It = MBBMap.find(MBB2);
+  ASSERT_NE(MBB1It, MBBMap.end());
+  ASSERT_NE(MBB2It, MBBMap.end());
+
+  const float ExpectedWeight = ::OpcWeight;
+  // BB1: NOOP + NOOP = 2 * ([1, 1] * weight) = [2, 2] * weight
+  Embedding ExpectedBB1Vector(2, 2.0f * ExpectedWeight);
+  EXPECT_TRUE(MBB1It->second.approximatelyEquals(ExpectedBB1Vector));
+
+  // BB2: TRAP = [3, 3] * weight
+  Embedding ExpectedBB2Vector(2, 3.0f * ExpectedWeight);
+  EXPECT_TRUE(MBB2It->second.approximatelyEquals(ExpectedBB2Vector));
+
+  // Function embedding: BB1 + BB2 = [2+3, 2+3] * weight = [5, 5] * weight
+  const Embedding &MFVector = Embedder->getMFunctionVector();
+  Embedding ExpectedFuncVector(2, 5.0f * ExpectedWeight);
+  EXPECT_TRUE(MFVector.approximatelyEquals(ExpectedFuncVector));
+}
+
+// A block with no instructions embeds to the zero vector, and so does a
+// function consisting of just that block.
+TEST_F(MIR2VecEmbeddingTestFixture, EmptyBasicBlock) {
+  constexpr unsigned Dim = 2;
+  const Embedding Zero(Dim, 0.0f);
+
+  // The only block in MF; no instruction is ever inserted into it.
+  MachineBasicBlock *EmptyBlock = MF->CreateMachineBasicBlock();
+  MF->push_back(EmptyBlock);
+
+  auto DummyVocab = MIRVocabulary::createDummyVocabForTest(*TII, Dim);
+  auto Embedder = SymbolicMIREmbedder::create(*MF, DummyVocab);
+  ASSERT_TRUE(Embedder != nullptr);
+
+  // The empty block still gets an entry in the block map...
+  const auto &BlockVecs = Embedder->getMBBVecMap();
+  EXPECT_EQ(BlockVecs.size(), 1u);
+  const auto It = BlockVecs.find(EmptyBlock);
+  ASSERT_NE(It, BlockVecs.end());
+
+  // ...and that entry is all zeros.
+  EXPECT_TRUE(It->second.approximatelyEquals(Zero));
+
+  // The function-level vector is expected to be zero as well.
+  const Embedding &FuncVec = Embedder->getMFunctionVector();
+  EXPECT_TRUE(FuncVec.approximatelyEquals(Zero));
+}
+
+// Opcodes missing from the vocabulary contribute a zero vector instead of
+// failing: ADD is in the vocabulary, SUB deliberately is not.
+TEST_F(MIR2VecEmbeddingTestFixture, UnknownOpcodes) {
+  // Default-dimension vocabulary with a single "ADD" entry and no "SUB".
+  auto Vocab = createTestVocab({{"ADD", 1.0f}});
+
+  // One block that will hold ADD32rr followed by SUB32rr.
+  MachineBasicBlock *Block = MF->CreateMachineBasicBlock();
+  MF->push_back(Block);
+
+  // Resolve both opcodes by name; -1 means the target does not know them.
+  const int AddOpc = findOpcodeByName("ADD32rr");
+  const int SubOpc = findOpcodeByName("SUB32rr");
+
+  ASSERT_NE(AddOpc, -1) << "ADD32rr opcode not found";
+  ASSERT_NE(SubOpc, -1) << "SUB32rr opcode not found";
+
+  MachineInstr *Add = createMachineInstr(*Block, AddOpc);
+  MachineInstr *Sub = createMachineInstr(*Block, SubOpc);
+
+  // Embed the function.
+  auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+  ASSERT_TRUE(Embedder != nullptr);
+
+  // Both instructions must get an entry, known opcode or not.
+  const auto &InstVecs = Embedder->getMInstVecMap();
+
+  const auto AddIt = InstVecs.find(Add);
+  const auto SubIt = InstVecs.find(Sub);
+
+  ASSERT_NE(AddIt, InstVecs.end());
+  ASSERT_NE(SubIt, InstVecs.end());
+
+  // Expected vectors; the vocabulary one is scaled by the OpcWeight option.
+  const float Weight = ::OpcWeight;
+  const Embedding AddVec(2, 1.0f * Weight); // From the "ADD" entry.
+  const Embedding ZeroVec(2, 0.0f);         // Unknown opcode.
+
+  // ADD picks up its weighted vocabulary vector...
+  EXPECT_TRUE(AddIt->second.approximatelyEquals(AddVec));
+
+  // ...while SUB, absent from the vocabulary, falls back to zeros.
+  EXPECT_TRUE(SubIt->second.approximatelyEquals(ZeroVec));
+
+  // Block = ADD + SUB = [1, 1] * Weight + [0, 0] = [1, 1] * Weight, so it
+  // equals ADD's vector.
+  const auto &BlockVecs = Embedder->getMBBVecMap();
+  const auto BlockIt = BlockVecs.find(Block);
+  ASSERT_NE(BlockIt, BlockVecs.end());
+  EXPECT_TRUE(BlockIt->second.approximatelyEquals(AddVec));
+}
+} // namespace



More information about the llvm-commits mailing list