[llvm] [MIR2Vec] Add embedder for machine instructions (PR #162161)
S. VenkataKeerthy via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 9 15:20:50 PDT 2025
https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/162161
>From 19bdf147ecf26c985f94d024c17d4818c9db775d Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Mon, 6 Oct 2025 21:15:14 +0000
Subject: [PATCH] MIR2Vec embedding
---
llvm/include/llvm/CodeGen/MIR2Vec.h | 113 ++++++-
llvm/include/llvm/CodeGen/Passes.h | 4 +
llvm/include/llvm/InitializePasses.h | 1 +
llvm/lib/CodeGen/CodeGen.cpp | 1 +
llvm/lib/CodeGen/MIR2Vec.cpp | 155 ++++++++-
.../Inputs/mir2vec_dummy_3D_vocab.json | 22 ++
llvm/test/CodeGen/MIR2Vec/if-else.mir | 144 +++++++++
.../MIR2Vec/mir2vec-basic-symbolic.mir | 76 +++++
llvm/tools/llc/llc.cpp | 15 +
llvm/unittests/CodeGen/MIR2VecTest.cpp | 295 ++++++++++++++++--
10 files changed, 797 insertions(+), 29 deletions(-)
create mode 100644 llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json
create mode 100644 llvm/test/CodeGen/MIR2Vec/if-else.mir
create mode 100644 llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir
diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
index 7b1b5d9aee15d..f6b0571f5dac6 100644
--- a/llvm/include/llvm/CodeGen/MIR2Vec.h
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -52,11 +52,21 @@ class LLVMContext;
class MIR2VecVocabLegacyAnalysis;
class TargetInstrInfo;
+enum class MIR2VecKind { Symbolic };
+
namespace mir2vec {
+
+// Forward declarations
+class MIREmbedder;
+class SymbolicMIREmbedder;
+
extern llvm::cl::OptionCategory MIR2VecCategory;
extern cl::opt<float> OpcWeight;
using Embedding = ir2vec::Embedding;
+using MachineInstEmbeddingsMap = DenseMap<const MachineInstr *, Embedding>;
+using MachineBlockEmbeddingsMap =
+ DenseMap<const MachineBasicBlock *, Embedding>;
/// Class for storing and accessing the MIR2Vec vocabulary.
/// The MIRVocabulary class manages seed embeddings for LLVM Machine IR
@@ -107,19 +117,91 @@ class MIRVocabulary {
const_iterator end() const { return Storage.end(); }
- /// Total number of entries in the vocabulary
- size_t getCanonicalSize() const { return Storage.size(); }
-
MIRVocabulary() = delete;
/// Factory method to create MIRVocabulary from vocabulary map
static Expected<MIRVocabulary> create(VocabMap &&Entries,
const TargetInstrInfo &TII);
+ /// Create a dummy vocabulary for testing purposes.
+ static Expected<MIRVocabulary>
+ createDummyVocabForTest(const TargetInstrInfo &TII, unsigned Dim = 1);
+
+ /// Total number of entries in the vocabulary
+ size_t getCanonicalSize() const { return Storage.size(); }
+
private:
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
};
+/// Base class for MIR embedders; refers to the function and vocabulary it uses
+class MIREmbedder {
+protected:
+ const MachineFunction &MF;
+ const MIRVocabulary &Vocab;
+
+ /// Dimension of the embeddings; captured from the vocabulary at construction
+ const unsigned Dimension;
+
+ /// Weight for opcode embeddings; captured from the mir2vec::OpcWeight option
+ const float OpcWeight;
+
+ MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
+ : MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()),
+ OpcWeight(mir2vec::OpcWeight) {}
+
+ /// Function to compute the embedding for the whole machine function MF.
+ Embedding computeEmbeddings() const;
+
+ /// Function to compute the embedding for a given machine basic block.
+ Embedding computeEmbeddings(const MachineBasicBlock &MBB) const;
+
+ /// Function to compute the embedding for a given machine instruction.
+ /// Pure virtual: each embedding kind supplies its own implementation.
+ virtual Embedding computeEmbeddings(const MachineInstr &MI) const = 0;
+
+public:
+ virtual ~MIREmbedder() = default;
+
+ /// Factory method to create an Embedder object of the specified kind.
+ /// Returns nullptr if the requested kind is not supported.
+ static std::unique_ptr<MIREmbedder> create(MIR2VecKind Mode,
+ const MachineFunction &MF,
+ const MIRVocabulary &Vocab);
+
+ /// Computes and returns the embedding for a given machine instruction MI in
+ /// the machine function MF. Recomputed on every call.
+ Embedding getMInstVector(const MachineInstr &MI) const {
+ return computeEmbeddings(MI);
+ }
+
+ /// Computes and returns the embedding for a given machine basic block in the
+ /// machine function MF. Recomputed on every call.
+ Embedding getMBBVector(const MachineBasicBlock &MBB) const {
+ return computeEmbeddings(MBB);
+ }
+
+ /// Computes and returns the embedding for the current machine function.
+ Embedding getMFunctionVector() const {
+ // Currently, we always (re)compute the embeddings for the function. This is
+ // cheaper than caching the vector.
+ return computeEmbeddings();
+ }
+};
+
+/// Class for computing Symbolic embeddings.
+/// Symbolic embeddings are constructed based on the entity-level
+/// representations obtained from the MIR Vocabulary.
+class SymbolicMIREmbedder : public MIREmbedder {
+private:
+ /// Returns the embedding of \p MI as looked up in the vocabulary.
+ Embedding computeEmbeddings(const MachineInstr &MI) const override;
+
+public:
+ /// Constructs an embedder for \p MF that reads its embeddings from \p Vocab.
+ SymbolicMIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab);
+
+ /// Convenience factory; equivalent to constructing the embedder directly and
+ /// wrapping it in a std::unique_ptr.
+ static std::unique_ptr<SymbolicMIREmbedder>
+ create(const MachineFunction &MF, const MIRVocabulary &Vocab);
+};
+
} // namespace mir2vec
/// Pass to analyze and populate MIR2Vec vocabulary from a module
@@ -166,6 +248,31 @@ class MIR2VecVocabPrinterLegacyPass : public MachineFunctionPass {
}
};
+/// This pass prints the MIR2Vec embeddings for machine functions, basic blocks,
+/// and instructions to the output stream it was constructed with.
+class MIR2VecPrinterLegacyPass : public MachineFunctionPass {
+ raw_ostream &OS;
+
+public:
+ static char ID;
+ explicit MIR2VecPrinterLegacyPass(raw_ostream &OS)
+ : MachineFunctionPass(ID), OS(OS) {}
+
+ bool runOnMachineFunction(MachineFunction &MF) override;
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<MIR2VecVocabLegacyAnalysis>();
+ AU.setPreservesAll();
+ MachineFunctionPass::getAnalysisUsage(AU);
+ }
+
+ StringRef getPassName() const override {
+ return "MIR2Vec Embedder Printer Pass";
+ }
+};
+
+/// Create a machine pass that prints MIR2Vec embeddings
+MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
+
} // namespace llvm
#endif // LLVM_CODEGEN_MIR2VEC_H
\ No newline at end of file
diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
index 272b4acf950c5..7fae550d8d170 100644
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -93,6 +93,10 @@ createMachineFunctionPrinterPass(raw_ostream &OS,
LLVM_ABI MachineFunctionPass *
createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS);
+/// MIR2VecPrinter pass - This pass prints out the MIR2Vec embeddings for
+/// machine functions, basic blocks and instructions.
+LLVM_ABI MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
+
/// StackFramePrinter pass - This pass prints out the machine function's
/// stack frame to the given stream as a debugging tool.
LLVM_ABI MachineFunctionPass *createStackFrameLayoutAnalysisPass();
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
index cd774e7888e64..d507ba267d791 100644
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -222,6 +222,7 @@ LLVM_ABI void
initializeMachineSanitizerBinaryMetadataLegacyPass(PassRegistry &);
LLVM_ABI void initializeMIR2VecVocabLegacyAnalysisPass(PassRegistry &);
LLVM_ABI void initializeMIR2VecVocabPrinterLegacyPassPass(PassRegistry &);
+LLVM_ABI void initializeMIR2VecPrinterLegacyPassPass(PassRegistry &);
LLVM_ABI void initializeMachineSchedulerLegacyPass(PassRegistry &);
LLVM_ABI void initializeMachineSinkingLegacyPass(PassRegistry &);
LLVM_ABI void initializeMachineTraceMetricsWrapperPassPass(PassRegistry &);
diff --git a/llvm/lib/CodeGen/CodeGen.cpp b/llvm/lib/CodeGen/CodeGen.cpp
index c438eaeb29d1e..9795a0b707fd3 100644
--- a/llvm/lib/CodeGen/CodeGen.cpp
+++ b/llvm/lib/CodeGen/CodeGen.cpp
@@ -98,6 +98,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) {
initializeMachineUniformityAnalysisPassPass(Registry);
initializeMIR2VecVocabLegacyAnalysisPass(Registry);
initializeMIR2VecVocabPrinterLegacyPassPass(Registry);
+ initializeMIR2VecPrinterLegacyPassPass(Registry);
initializeMachineUniformityInfoPrinterPassPass(Registry);
initializeMachineVerifierLegacyPassPass(Registry);
initializeObjCARCContractLegacyPassPass(Registry);
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index e85976547a2c2..2df14a75bf623 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/MIR2Vec.h"
+#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/IR/Module.h"
@@ -41,11 +42,18 @@ static cl::opt<std::string>
cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
cl::desc("Weight for machine opcode embeddings"),
cl::cat(MIR2VecCategory));
+cl::opt<MIR2VecKind> MIR2VecEmbeddingKind(
+ "mir2vec-kind", cl::Optional,
+ cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic",
+ "Generate symbolic embeddings for MIR")),
+ cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"),
+ cl::cat(MIR2VecCategory));
+
} // namespace mir2vec
} // namespace llvm
//===----------------------------------------------------------------------===//
-// Vocabulary Implementation
+// Vocabulary
//===----------------------------------------------------------------------===//
MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
@@ -191,6 +199,30 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() {
<< " unique base opcodes\n");
}
+Expected<MIRVocabulary>
+MIRVocabulary::createDummyVocabForTest(const TargetInstrInfo &TII,
+ unsigned Dim) {
+ assert(Dim > 0 && "Dimension must be greater than zero");
+
+ float DummyVal = 0.1f;
+
+ // Use a temporary, empty vocabulary only to obtain the canonical opcode names
+ MIRVocabulary TempVocab({}, TII);
+ TempVocab.buildCanonicalOpcodeMapping();
+
+ // Give every canonical opcode name its own value: 0.1, then +0.1 per entry
+ VocabMap DummyVocabMap;
+ for (const auto &COpcodeName : TempVocab.UniqueBaseOpcodeNames) {
+ // All Dim components of this entry hold the current DummyVal
+ Embedding DummyEmbedding(Dim, DummyVal);
+ DummyVocabMap[COpcodeName] = DummyEmbedding;
+ DummyVal += 0.1f;
+ }
+
+ // Build the real vocabulary from the dummy map via the regular factory
+ return MIRVocabulary::create(std::move(DummyVocabMap), TII);
+}
+
//===----------------------------------------------------------------------===//
// MIR2VecVocabLegacyAnalysis Implementation
//===----------------------------------------------------------------------===//
@@ -261,7 +293,73 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
}
//===----------------------------------------------------------------------===//
-// Printer Passes Implementation
+// MIREmbedder and its subclasses
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode,
+ const MachineFunction &MF,
+ const MIRVocabulary &Vocab) {
+ switch (Mode) {
+ case MIR2VecKind::Symbolic:
+ return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
+ }
+ return nullptr;
+}
+
+Embedding MIREmbedder::computeEmbeddings(const MachineBasicBlock &MBB) const {
+ Embedding MBBVector(Dimension, 0);
+
+ // Require a TargetInstrInfo; without one, diagnose and return a zero vector
+ const auto &Subtarget = MF.getSubtarget();
+ const auto *TII = Subtarget.getInstrInfo();
+ if (!TII) {
+ MF.getFunction().getContext().emitError(
+ "MIR2Vec: No TargetInstrInfo available; cannot compute embeddings");
+ return MBBVector;
+ }
+
+ // Accumulate the embedding of each machine instruction in the basic block
+ for (const auto &MI : MBB) {
+ // Skip debug instructions; they do not contribute to the block vector
+ if (MI.isDebugInstr())
+ continue;
+ MBBVector += computeEmbeddings(MI);
+ }
+
+ return MBBVector;
+}
+
+Embedding MIREmbedder::computeEmbeddings() const {
+ Embedding MFuncVector(Dimension, 0);
+
+ // Sum the vectors of blocks found by a depth-first walk; unreachable ones are left out
+ for (const auto *MBB : depth_first(&MF))
+ MFuncVector += computeEmbeddings(*MBB);
+ return MFuncVector;
+}
+
+SymbolicMIREmbedder::SymbolicMIREmbedder(const MachineFunction &MF,
+ const MIRVocabulary &Vocab)
+ : MIREmbedder(MF, Vocab) {}
+
+std::unique_ptr<SymbolicMIREmbedder>
+SymbolicMIREmbedder::create(const MachineFunction &MF,
+ const MIRVocabulary &Vocab) {
+ return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
+}
+
+Embedding SymbolicMIREmbedder::computeEmbeddings(const MachineInstr &MI) const {
+ // Debug instructions get a zero embedding of the vocabulary's dimension
+ if (MI.isDebugInstr())
+ return Embedding(Dimension, 0);
+
+ // TODO: Add operand/argument contributions
+
+ return Vocab[MI.getOpcode()];
+}
+
+//===----------------------------------------------------------------------===//
+// Printer Passes
//===----------------------------------------------------------------------===//
char MIR2VecVocabPrinterLegacyPass::ID = 0;
@@ -300,3 +398,56 @@ MachineFunctionPass *
llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) {
return new MIR2VecVocabPrinterLegacyPass(OS);
}
+
+char MIR2VecPrinterLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(MIR2VecPrinterLegacyPass, "print-mir2vec",
+ "MIR2Vec Embedder Printer Pass", false, true)
+INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
+INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
+INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, "print-mir2vec",
+ "MIR2Vec Embedder Printer Pass", false, true)
+
+/// Prints the MIR2Vec embeddings of \p MF to OS: first the function vector,
+/// then one vector per basic block, then one vector per non-debug instruction.
+/// Never modifies \p MF, so it always returns false.
+bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
+ auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
+ auto VocabOrErr =
+ Analysis.getMIR2VecVocabulary(*MF.getFunction().getParent());
+ // Obtaining the vocabulary can fail at runtime, so this must not be an
+ // assert: with assertions compiled out, an Expected in the error state would
+ // be dereferenced and its Error never consumed. Report and bail out instead.
+ if (!VocabOrErr) {
+ OS << "MIR2Vec Embedder Printer: Failed to get vocabulary - "
+ << toString(VocabOrErr.takeError()) << "\n";
+ return false;
+ }
+ auto &MIRVocab = *VocabOrErr;
+
+ auto Emb = mir2vec::MIREmbedder::create(MIR2VecEmbeddingKind, MF, MIRVocab);
+ if (!Emb) {
+ OS << "Error creating MIR2Vec embeddings for function " << MF.getName()
+ << "\n";
+ return false;
+ }
+
+ OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+ OS << "Machine Function vector: ";
+ Emb->getMFunctionVector().print(OS);
+
+ // Every block in MF is listed here, in function order.
+ OS << "Machine basic block vectors:\n";
+ for (const MachineBasicBlock &MBB : MF) {
+ OS << "Machine basic block: " << MBB.getFullName() << ":\n";
+ Emb->getMBBVector(MBB).print(OS);
+ }
+
+ OS << "Machine instruction vectors:\n";
+ for (const MachineBasicBlock &MBB : MF) {
+ for (const MachineInstr &MI : MBB) {
+ // Skip debug instructions as they are not
+ // embedded
+ if (MI.isDebugInstr())
+ continue;
+
+ OS << "Machine instruction: ";
+ MI.print(OS);
+ Emb->getMInstVector(MI).print(OS);
+ }
+ }
+
+ return false;
+}
+
+MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass(raw_ostream &OS) {
+ return new MIR2VecPrinterLegacyPass(OS);
+}
diff --git a/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json b/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json
new file mode 100644
index 0000000000000..5de715bf80917
--- /dev/null
+++ b/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json
@@ -0,0 +1,22 @@
+{
+ "entities": {
+ "KILL": [0.1, 0.2, 0.3],
+ "MOV": [0.4, 0.5, 0.6],
+ "LEA": [0.7, 0.8, 0.9],
+ "RET": [1.0, 1.1, 1.2],
+ "ADD": [1.3, 1.4, 1.5],
+ "SUB": [1.6, 1.7, 1.8],
+ "IMUL": [1.9, 2.0, 2.1],
+ "AND": [2.2, 2.3, 2.4],
+ "OR": [2.5, 2.6, 2.7],
+ "XOR": [2.8, 2.9, 3.0],
+ "CMP": [3.1, 3.2, 3.3],
+ "TEST": [3.4, 3.5, 3.6],
+ "JMP": [3.7, 3.8, 3.9],
+ "CALL": [4.0, 4.1, 4.2],
+ "PUSH": [4.3, 4.4, 4.5],
+ "POP": [4.6, 4.7, 4.8],
+ "NOP": [4.9, 5.0, 5.1],
+ "COPY": [5.2, 5.3, 5.4]
+ }
+}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/MIR2Vec/if-else.mir b/llvm/test/CodeGen/MIR2Vec/if-else.mir
new file mode 100644
index 0000000000000..2accf476f7c4d
--- /dev/null
+++ b/llvm/test/CodeGen/MIR2Vec/if-else.mir
@@ -0,0 +1,144 @@
+# REQUIRES: x86_64-linux
+# RUN: llc -run-pass=none -print-mir2vec -mir2vec-vocab-path=%S/Inputs/mir2vec_dummy_3D_vocab.json %s -o /dev/null 2>&1 | FileCheck %s
+
+--- |
+ target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+
+ define dso_local i32 @abc(i32 noundef %a, i32 noundef %b) {
+ entry:
+ %retval = alloca i32, align 4
+ %a.addr = alloca i32, align 4
+ %b.addr = alloca i32, align 4
+ store i32 %a, ptr %a.addr, align 4
+ store i32 %b, ptr %b.addr, align 4
+ %0 = load i32, ptr %a.addr, align 4
+ %1 = load i32, ptr %b.addr, align 4
+ %cmp = icmp sgt i32 %0, %1
+ br i1 %cmp, label %if.then, label %if.else
+
+ if.then: ; preds = %entry
+ %2 = load i32, ptr %b.addr, align 4
+ store i32 %2, ptr %retval, align 4
+ br label %return
+
+ if.else: ; preds = %entry
+ %3 = load i32, ptr %a.addr, align 4
+ store i32 %3, ptr %retval, align 4
+ br label %return
+
+ return: ; preds = %if.else, %if.then
+ %4 = load i32, ptr %retval, align 4
+ ret i32 %4
+ }
+...
+---
+name: abc
+alignment: 16
+exposesReturnsTwice: false
+legalized: false
+regBankSelected: false
+selected: false
+failedISel: false
+tracksRegLiveness: true
+hasWinCFI: false
+noPhis: false
+isSSA: true
+noVRegs: false
+hasFakeUses: false
+callsEHReturn: false
+callsUnwindInit: false
+hasEHContTarget: false
+hasEHScopes: false
+hasEHFunclets: false
+isOutlined: false
+debugInstrRef: true
+failsVerification: false
+tracksDebugUserValues: false
+registers:
+ - { id: 0, class: gr32, preferred-register: '', flags: [ ] }
+ - { id: 1, class: gr32, preferred-register: '', flags: [ ] }
+ - { id: 2, class: gr32, preferred-register: '', flags: [ ] }
+ - { id: 3, class: gr32, preferred-register: '', flags: [ ] }
+ - { id: 4, class: gr32, preferred-register: '', flags: [ ] }
+ - { id: 5, class: gr32, preferred-register: '', flags: [ ] }
+liveins:
+ - { reg: '$edi', virtual-reg: '%0' }
+ - { reg: '$esi', virtual-reg: '%1' }
+frameInfo:
+ isFrameAddressTaken: false
+ isReturnAddressTaken: false
+ hasStackMap: false
+ hasPatchPoint: false
+ stackSize: 0
+ offsetAdjustment: 0
+ maxAlignment: 4
+ adjustsStack: false
+ hasCalls: false
+ stackProtector: ''
+ functionContext: ''
+ maxCallFrameSize: 4294967295
+ cvBytesOfCalleeSavedRegisters: 0
+ hasOpaqueSPAdjustment: false
+ hasVAStart: false
+ hasMustTailInVarArgFunc: false
+ hasTailCall: false
+ isCalleeSavedInfoValid: false
+ localFrameSize: 0
+fixedStack: []
+stack:
+ - { id: 0, name: retval, type: default, offset: 0, size: 4, alignment: 4,
+ stack-id: default, callee-saved-register: '', callee-saved-restored: true,
+ debug-info-variable: '', debug-info-expression: '', debug-info-location: '' }
+ - { id: 1, name: a.addr, type: default, offset: 0, size: 4, alignment: 4,
+ stack-id: default, callee-saved-register: '', callee-saved-restored: true,
+ debug-info-variable: '', debug-info-expression: '', debug-info-location: '' }
+ - { id: 2, name: b.addr, type: default, offset: 0, size: 4, alignment: 4,
+ stack-id: default, callee-saved-register: '', callee-saved-restored: true,
+ debug-info-variable: '', debug-info-expression: '', debug-info-location: '' }
+entry_values: []
+callSites: []
+debugValueSubstitutions: []
+constants: []
+machineFunctionInfo:
+ amxProgModel: None
+body: |
+ bb.0.entry:
+ successors: %bb.1(0x40000000), %bb.2(0x40000000)
+ liveins: $edi, $esi
+
+ %1:gr32 = COPY $esi
+ %0:gr32 = COPY $edi
+ MOV32mr %stack.1.a.addr, 1, $noreg, 0, $noreg, %0 :: (store (s32) into %ir.a.addr)
+ MOV32mr %stack.2.b.addr, 1, $noreg, 0, $noreg, %1 :: (store (s32) into %ir.b.addr)
+ %2:gr32 = SUB32rr %0, %1, implicit-def $eflags
+ JCC_1 %bb.2, 14, implicit $eflags
+ JMP_1 %bb.1
+
+ bb.1.if.then:
+ successors: %bb.3(0x80000000)
+
+ %4:gr32 = MOV32rm %stack.2.b.addr, 1, $noreg, 0, $noreg :: (dereferenceable load (s32) from %ir.b.addr)
+ MOV32mr %stack.0.retval, 1, $noreg, 0, $noreg, killed %4 :: (store (s32) into %ir.retval)
+ JMP_1 %bb.3
+
+ bb.2.if.else:
+ successors: %bb.3(0x80000000)
+
+ %3:gr32 = MOV32rm %stack.1.a.addr, 1, $noreg, 0, $noreg :: (dereferenceable load (s32) from %ir.a.addr)
+ MOV32mr %stack.0.retval, 1, $noreg, 0, $noreg, killed %3 :: (store (s32) into %ir.retval)
+
+ bb.3.return:
+ %5:gr32 = MOV32rm %stack.0.retval, 1, $noreg, 0, $noreg :: (dereferenceable load (s32) from %ir.retval)
+ $eax = COPY %5
+ RET 0, $eax
+...
+
+# CHECK: Machine basic block vectors:
+# CHECK-NEXT: Machine basic block: abc:entry:
+# CHECK-NEXT: [ 16.50 17.10 17.70 ]
+# CHECK-NEXT: Machine basic block: abc:if.then:
+# CHECK-NEXT: [ 4.50 4.80 5.10 ]
+# CHECK-NEXT: Machine basic block: abc:if.else:
+# CHECK-NEXT: [ 0.80 1.00 1.20 ]
+# CHECK-NEXT: Machine basic block: abc:return:
+# CHECK-NEXT: [ 6.60 6.90 7.20 ]
\ No newline at end of file
diff --git a/llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir b/llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir
new file mode 100644
index 0000000000000..44240affb2206
--- /dev/null
+++ b/llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir
@@ -0,0 +1,76 @@
+# REQUIRES: x86_64-linux
+# RUN: llc -run-pass=none -print-mir2vec -mir2vec-vocab-path=%S/Inputs/mir2vec_dummy_3D_vocab.json %s -o /dev/null 2>&1 | FileCheck %s
+
+--- |
+ target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+
+ define dso_local noundef i32 @add_function(i32 noundef %a, i32 noundef %b) {
+ entry:
+ %sum = add nsw i32 %a, %b
+ %result = mul nsw i32 %sum, 2
+ ret i32 %result
+ }
+
+ define dso_local void @simple_function() {
+ entry:
+ ret void
+ }
+...
+---
+name: add_function
+alignment: 16
+tracksRegLiveness: true
+registers:
+ - { id: 0, class: gr32 }
+ - { id: 1, class: gr32 }
+ - { id: 2, class: gr32 }
+ - { id: 3, class: gr32 }
+liveins:
+ - { reg: '$edi', virtual-reg: '%0' }
+ - { reg: '$esi', virtual-reg: '%1' }
+body: |
+ bb.0.entry:
+ liveins: $edi, $esi
+
+ %1:gr32 = COPY $esi
+ %0:gr32 = COPY $edi
+ %2:gr32 = nsw ADD32rr %0, %1, implicit-def dead $eflags
+ %3:gr32 = ADD32rr %2, %2, implicit-def dead $eflags
+ $eax = COPY %3
+ RET 0, $eax
+
+---
+name: simple_function
+alignment: 16
+tracksRegLiveness: true
+body: |
+ bb.0.entry:
+ RET 0
+
+# CHECK: MIR2Vec embeddings for machine function add_function:
+# CHECK: Function vector: [ 19.20 19.80 20.40 ]
+# CHECK-NEXT: Machine basic block vectors:
+# CHECK-NEXT: Machine basic block: add_function:entry:
+# CHECK-NEXT: [ 19.20 19.80 20.40 ]
+# CHECK-NEXT: Machine instruction vectors:
+# CHECK-NEXT: Machine instruction: %1:gr32 = COPY $esi
+# CHECK-NEXT: [ 5.20 5.30 5.40 ]
+# CHECK-NEXT: Machine instruction: %0:gr32 = COPY $edi
+# CHECK-NEXT: [ 5.20 5.30 5.40 ]
+# CHECK-NEXT: Machine instruction: %2:gr32 = nsw ADD32rr %0:gr32(tied-def 0), %1:gr32, implicit-def dead $eflags
+# CHECK-NEXT: [ 1.30 1.40 1.50 ]
+# CHECK-NEXT: Machine instruction: %3:gr32 = ADD32rr %2:gr32(tied-def 0), %2:gr32, implicit-def dead $eflags
+# CHECK-NEXT: [ 1.30 1.40 1.50 ]
+# CHECK-NEXT: Machine instruction: $eax = COPY %3:gr32
+# CHECK-NEXT: [ 5.20 5.30 5.40 ]
+# CHECK-NEXT: Machine instruction: RET 0, $eax
+# CHECK-NEXT: [ 1.00 1.10 1.20 ]
+
+# CHECK: MIR2Vec embeddings for machine function simple_function:
+# CHECK-NEXT:Function vector: [ 1.00 1.10 1.20 ]
+# CHECK-NEXT: Machine basic block vectors:
+# CHECK-NEXT: Machine basic block: simple_function:entry:
+# CHECK-NEXT: [ 1.00 1.10 1.20 ]
+# CHECK-NEXT: Machine instruction vectors:
+# CHECK-NEXT: Machine instruction: RET 0
+# CHECK-NEXT: [ 1.00 1.10 1.20 ]
\ No newline at end of file
diff --git a/llvm/tools/llc/llc.cpp b/llvm/tools/llc/llc.cpp
index f04b256e2e6c9..f4441ccb896b1 100644
--- a/llvm/tools/llc/llc.cpp
+++ b/llvm/tools/llc/llc.cpp
@@ -172,6 +172,11 @@ static cl::opt<bool>
cl::desc("Print MIR2Vec vocabulary contents"),
cl::init(false));
+static cl::opt<bool>
+ PrintMIR2Vec("print-mir2vec", cl::Hidden,
+ cl::desc("Print MIR2Vec embeddings for functions"),
+ cl::init(false));
+
static cl::list<std::string> IncludeDirs("I", cl::desc("include search path"));
static cl::opt<bool> RemarksWithHotness(
@@ -776,6 +781,11 @@ static int compileModule(char **argv, LLVMContext &Context) {
PM.add(createMIR2VecVocabPrinterLegacyPass(errs()));
}
+ // Add MIR2Vec printer if requested
+ if (PrintMIR2Vec) {
+ PM.add(createMIR2VecPrinterLegacyPass(errs()));
+ }
+
PM.add(createFreeMachineFunctionPass());
} else {
if (Target->addPassesToEmitFile(PM, *OS, DwoOut ? &DwoOut->os() : nullptr,
@@ -789,6 +799,11 @@ static int compileModule(char **argv, LLVMContext &Context) {
if (PrintMIR2VecVocab) {
PM.add(createMIR2VecVocabPrinterLegacyPass(errs()));
}
+
+ // Add MIR2Vec printer if requested
+ if (PrintMIR2Vec) {
+ PM.add(createMIR2VecPrinterLegacyPass(errs()));
+ }
}
Target->getObjFileLowering()->Initialize(MMIWP->getMMI().getContext(),
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index 11222b4d02fa3..8cd9d5ac9f6be 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -82,6 +82,9 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
return;
}
+ // Set the data layout to match the target machine
+ M->setDataLayout(TM->createDataLayout());
+
// Create a dummy function to get subtarget info
FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false);
Function *F =
@@ -96,16 +99,27 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
}
void TearDown() override { TII = nullptr; }
-};
-// Function to find an opcode by name
-static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) {
- for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
- if (TII->getName(Opcode) == Name)
- return Opcode;
+ // Find an opcode by its exact name; opcode 0 is not considered
+ int findOpcodeByName(StringRef Name) {
+ for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
+ if (TII->getName(Opcode) == Name)
+ return Opcode;
+ }
+ return -1; // Not found
+ }
- return -1; // Not found
-}
+
+ // Create a vocabulary in which each named opcode gets a constant-valued embedding
+ Expected<MIRVocabulary>
+ createTestVocab(std::initializer_list<std::pair<const char *, float>> opcodes,
+ unsigned dimension = 2) {
+ assert(TII && "TargetInstrInfo not initialized");
+ VocabMap VMap;
+ for (const auto &[name, value] : opcodes)
+ VMap[name] = Embedding(dimension, value);
+ return MIRVocabulary::create(std::move(VMap), *TII);
+ }
+};
TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
// Test that same base opcodes get same canonical indices
@@ -118,10 +132,8 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
// Create a MIRVocabulary instance to test the mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
- VocabMap VMap;
Embedding Val = Embedding(64, 1.0f);
- VMap["ADD"] = Val;
- auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, 64);
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
<< "Failed to create vocabulary: "
<< toString(TestVocabOrErr.takeError());
@@ -156,16 +168,16 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
6880u); // X86 has >6880 unique base opcodes
// Check that the embeddings for opcodes not in the vocab are zero vectors
- int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr");
+ int Add32rrOpcode = findOpcodeByName("ADD32rr");
ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found";
EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val));
- int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr");
+ int Sub32rrOpcode = findOpcodeByName("SUB32rr");
ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found";
EXPECT_TRUE(
TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
- int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr");
+ int Mov32rrOpcode = findOpcodeByName("MOV32rr");
ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found";
EXPECT_TRUE(
TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
@@ -178,9 +190,7 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Create a MIRVocabulary instance to test deterministic mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
- VocabMap VMap;
- VMap["ADD"] = Embedding(64, 1.0f);
- auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, 64);
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
<< "Failed to create vocabulary: "
<< toString(TestVocabOrErr.takeError());
@@ -189,8 +199,6 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index3 = TestVocab.getCanonicalIndexForBaseName(BaseName);
-
- EXPECT_EQ(Index1, Index2);
EXPECT_EQ(Index2, Index3);
// Test across multiple runs
@@ -202,11 +210,7 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Test MIRVocabulary construction
TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
- VocabMap VMap;
- VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
- VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
-
- auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ auto VocabOrErr = createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, 128);
ASSERT_TRUE(static_cast<bool>(VocabOrErr))
<< "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
auto &Vocab = *VocabOrErr;
@@ -243,4 +247,247 @@ TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
}
}
+// Fixture for the embedder tests. It extends the vocabulary fixture with a
+// MachineModuleInfo and a MachineFunction, and offers helpers that append
+// instructions to a machine basic block.
+class MIR2VecEmbeddingTestFixture : public MIR2VecVocabTestFixture {
+protected:
+  std::unique_ptr<MachineModuleInfo> MMI;
+  MachineFunction *MF = nullptr;
+
+  void SetUp() override {
+    MIR2VecVocabTestFixture::SetUp();
+
+    // The MachineFunction is created from an IR function; a `void test()`
+    // declaration with external linkage is used as that function.
+    Function *F =
+        Function::Create(FunctionType::get(Type::getVoidTy(*Ctx), false),
+                         Function::ExternalLinkage, "test", M.get());
+
+    MMI = std::make_unique<MachineModuleInfo>(TM.get());
+    MF = &MMI->getOrCreateMachineFunction(*F);
+  }
+
+  void TearDown() override { MIR2VecVocabTestFixture::TearDown(); }
+
+  // Append an instruction with the given opcode at the end of MBB and return
+  // it. No operands are added - they don't affect opcode-based embeddings.
+  MachineInstr *createMachineInstr(MachineBasicBlock &MBB, unsigned Opcode) {
+    return BuildMI(MBB, MBB.end(), DebugLoc(), TII->get(Opcode));
+  }
+
+  // Same as above, but the opcode is looked up by name first. Returns nullptr
+  // when findOpcodeByName reports -1 for the name.
+  MachineInstr *createMachineInstr(MachineBasicBlock &MBB,
+                                   const char *OpcodeName) {
+    const int Opc = findOpcodeByName(OpcodeName);
+    return Opc == -1 ? nullptr : createMachineInstr(MBB, Opc);
+  }
+
+  // Append one instruction per name, in list order, asserting that each one
+  // could be created.
+  void createMachineInstrs(MachineBasicBlock &MBB,
+                           std::initializer_list<const char *> Opcodes) {
+    for (const char *Name : Opcodes)
+      ASSERT_TRUE(createMachineInstr(MBB, Name) != nullptr);
+  }
+};
+
+// The generic factory must return a non-null embedder for the Symbolic kind
+TEST_F(MIR2VecEmbeddingTestFixture, CreateSymbolicEmbedder) {
+  auto DummyVocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 1);
+  ASSERT_TRUE(static_cast<bool>(DummyVocabOrErr))
+      << "Failed to create vocabulary: "
+      << toString(DummyVocabOrErr.takeError());
+
+  auto SymbolicEmb =
+      MIREmbedder::create(MIR2VecKind::Symbolic, *MF, *DummyVocabOrErr);
+  EXPECT_NE(SymbolicEmb, nullptr);
+}
+
+// A kind value that is not a MIR2VecKind enumerator must make the factory
+// return something that converts to false.
+TEST_F(MIR2VecEmbeddingTestFixture, CreateInvalidMode) {
+  auto DummyVocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 1);
+  ASSERT_TRUE(static_cast<bool>(DummyVocabOrErr))
+      << "Failed to create vocabulary: "
+      << toString(DummyVocabOrErr.takeError());
+
+  const auto BogusKind = static_cast<MIR2VecKind>(-1);
+  auto Result = MIREmbedder::create(BogusKind, *MF, *DummyVocabOrErr);
+  EXPECT_FALSE(static_cast<bool>(Result));
+}
+
+// Exercise SymbolicMIREmbedder on one block holding three simple target
+// opcodes and check the instruction, block and function level vectors.
+TEST_F(MIR2VecEmbeddingTestFixture, TestSymbolicEmbedder) {
+  // Vocabulary of dimension 4: NOOP -> [1, 1, 1, 1], RET -> [2, 2, 2, 2],
+  // TRAP -> [3, 3, 3, 3]
+  auto VocabOrErr =
+      createTestVocab({{"NOOP", 1.0f}, {"RET", 2.0f}, {"TRAP", 3.0f}}, 4);
+  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+  auto &Vocab = *VocabOrErr;
+
+  // A single block appended to the fixture's machine function
+  MachineBasicBlock *Block = MF->CreateMachineBasicBlock();
+  MF->push_back(Block);
+
+  // Real X86 opcodes that should exist and not be pseudo
+  MachineInstr *Noop = createMachineInstr(*Block, "NOOP");
+  ASSERT_TRUE(Noop != nullptr);
+  MachineInstr *Ret = createMachineInstr(*Block, "RET64");
+  ASSERT_TRUE(Ret != nullptr);
+  MachineInstr *Trap = createMachineInstr(*Block, "TRAP");
+  ASSERT_TRUE(Trap != nullptr);
+
+  ASSERT_FALSE(Noop->isPseudo()) << "NOOP is marked as pseudo instruction";
+  ASSERT_FALSE(Ret->isPseudo()) << "RET is marked as pseudo instruction";
+  ASSERT_FALSE(Trap->isPseudo()) << "TRAP is marked as pseudo instruction";
+
+  auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+  ASSERT_TRUE(Embedder != nullptr);
+
+  // Per-instruction vectors
+  auto NoopVec = Embedder->getMInstVector(*Noop);
+  auto RetVec = Embedder->getMInstVector(*Ret);
+  auto TrapVec = Embedder->getMInstVector(*Trap);
+
+  // Each expected vector is a constant 4-vector scaled by the opcode weight
+  // (global weight from the command line).
+  const float Weight = mir2vec::OpcWeight;
+  auto Scaled = [Weight](float Value) {
+    return Embedding(4, Value * Weight);
+  };
+  EXPECT_TRUE(NoopVec.approximatelyEquals(Scaled(1.0f)));
+  EXPECT_TRUE(RetVec.approximatelyEquals(Scaled(2.0f)));
+  EXPECT_TRUE(TrapVec.approximatelyEquals(Scaled(3.0f)));
+
+  // Block vector: NOOP + RET + TRAP = [6, 6, 6, 6] * weight
+  auto BlockVec = Embedder->getMBBVector(*Block);
+  const Embedding ExpectedSum = Scaled(6.0f);
+  EXPECT_TRUE(BlockVec.approximatelyEquals(ExpectedSum));
+
+  // With a single block the function vector equals the block vector
+  auto FuncVec = Embedder->getMFunctionVector();
+  EXPECT_TRUE(FuncVec.approximatelyEquals(ExpectedSum));
+}
+
+// Test embedder with multiple basic blocks: per-block vectors, and the
+// function vector both before and after the second block becomes reachable.
+TEST_F(MIR2VecEmbeddingTestFixture, MultipleBasicBlocks) {
+  // Create a test vocabulary. The values are chosen so that the two block
+  // vectors below differ (BB1 = 2 * NOOP = 2, BB2 = TRAP = 3); otherwise the
+  // reachability check could not tell which block was summed.
+  auto VocabOrErr = createTestVocab({{"NOOP", 1.0f}, {"TRAP", 3.0f}});
+  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+  auto &Vocab = *VocabOrErr;
+
+  // Create two basic blocks using fixture's MF
+  MachineBasicBlock *MBB1 = MF->CreateMachineBasicBlock();
+  MachineBasicBlock *MBB2 = MF->CreateMachineBasicBlock();
+  MF->push_back(MBB1);
+  MF->push_back(MBB2);
+
+  // The helper signals failure with ASSERT_*, which only returns from the
+  // helper itself; propagate a fatal failure to this test explicitly.
+  ASSERT_NO_FATAL_FAILURE(createMachineInstrs(*MBB1, {"NOOP", "NOOP"}));
+  ASSERT_TRUE(createMachineInstr(*MBB2, "TRAP") != nullptr)
+      << "TRAP opcode not found";
+
+  // Create embedder
+  auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+  ASSERT_TRUE(Embedder != nullptr);
+
+  // Test basic block embeddings
+  auto MBB1Vector = Embedder->getMBBVector(*MBB1);
+  auto MBB2Vector = Embedder->getMBBVector(*MBB2);
+
+  float ExpectedWeight = mir2vec::OpcWeight;
+  // BB1: NOOP + NOOP = 2 * ([1, 1] * weight)
+  Embedding ExpectedMBB1Vector(2, 2.0f * ExpectedWeight);
+  EXPECT_TRUE(MBB1Vector.approximatelyEquals(ExpectedMBB1Vector));
+
+  // BB2: TRAP = [3, 3] * weight
+  Embedding ExpectedMBB2Vector(2, 3.0f * ExpectedWeight);
+  EXPECT_TRUE(MBB2Vector.approximatelyEquals(ExpectedMBB2Vector));
+
+  // No edge leads to BB2 yet, so it is unreachable and the function embedding
+  // is expected to be the BB1 embedding alone: [2, 2] * weight
+  auto MFuncVector = Embedder->getMFunctionVector();
+  EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedMBB1Vector));
+
+  // Add a CFG edge from BB1 to BB2 to make both reachable; now the function
+  // embedding should be BB1 + BB2 = [2+3, 2+3] * weight = [5, 5] * weight
+  MBB1->addSuccessor(MBB2);
+  auto NewMFuncVector = Embedder->getMFunctionVector(); // Recompute embeddings
+  Embedding ExpectedFuncVector = MBB1Vector + MBB2Vector;
+  EXPECT_TRUE(NewMFuncVector.approximatelyEquals(ExpectedFuncVector));
+}
+
+// A block without instructions must embed to the zero vector, and so must a
+// function whose only block is that one.
+TEST_F(MIR2VecEmbeddingTestFixture, EmptyBasicBlock) {
+  auto DummyVocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 2);
+  ASSERT_TRUE(static_cast<bool>(DummyVocabOrErr))
+      << "Failed to create vocabulary: "
+      << toString(DummyVocabOrErr.takeError());
+
+  // The function's only block; no instructions are ever added to it
+  MachineBasicBlock *EmptyMBB = MF->CreateMachineBasicBlock();
+  MF->push_back(EmptyMBB);
+
+  auto Embedder = SymbolicMIREmbedder::create(*MF, *DummyVocabOrErr);
+  ASSERT_TRUE(Embedder != nullptr);
+
+  const Embedding ZeroVector(2, 0.0f);
+
+  // Block level
+  auto BlockVec = Embedder->getMBBVector(*EmptyMBB);
+  EXPECT_TRUE(BlockVec.approximatelyEquals(ZeroVector));
+
+  // Function level
+  auto FuncVec = Embedder->getMFunctionVector();
+  EXPECT_TRUE(FuncVec.approximatelyEquals(ZeroVector));
+}
+
+// An opcode without a vocabulary entry is expected to embed to the zero
+// vector and therefore to add nothing to its block's vector.
+TEST_F(MIR2VecEmbeddingTestFixture, UnknownOpcodes) {
+  // Vocabulary that knows ADD only; SUB is deliberately left out
+  auto VocabOrErr = createTestVocab({{"ADD", 1.0f}});
+  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+  auto &Vocab = *VocabOrErr;
+
+  MachineBasicBlock *Block = MF->CreateMachineBasicBlock();
+  MF->push_back(Block);
+
+  // Resolve both opcodes before touching the block
+  const int AddOpc = findOpcodeByName("ADD32rr");
+  const int SubOpc = findOpcodeByName("SUB32rr");
+  ASSERT_NE(AddOpc, -1) << "ADD32rr opcode not found";
+  ASSERT_NE(SubOpc, -1) << "SUB32rr opcode not found";
+
+  // Block contents: ADD32rr followed by SUB32rr
+  MachineInstr *Add = createMachineInstr(*Block, AddOpc);
+  MachineInstr *Sub = createMachineInstr(*Block, SubOpc);
+
+  auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+  ASSERT_TRUE(Embedder != nullptr);
+
+  auto AddVec = Embedder->getMInstVector(*Add);
+  auto SubVec = Embedder->getMInstVector(*Sub);
+
+  // Known opcode: vocabulary entry scaled by the opcode weight
+  const float Weight = mir2vec::OpcWeight;
+  const Embedding WeightedAdd(2, 1.0f * Weight);
+  EXPECT_TRUE(AddVec.approximatelyEquals(WeightedAdd));
+
+  // Unknown opcode: zero vector
+  EXPECT_TRUE(SubVec.approximatelyEquals(Embedding(2, 0.0f)));
+
+  // Block: ADD + SUB = [1, 1] * weight + [0, 0] = [1, 1] * weight
+  const auto &BlockVec = Embedder->getMBBVector(*Block);
+  EXPECT_TRUE(BlockVec.approximatelyEquals(WeightedAdd));
+}
} // namespace
More information about the llvm-commits
mailing list