[llvm] [IR2Vec] Refactor MIR vocabulary to use opcode-based indexing (PR #161713)
S. VenkataKeerthy via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 7 13:48:40 PDT 2025
https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/161713
>From 6a43c342f53d65b581b4d5a35f0988097ff8e5fd Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 2 Oct 2025 18:14:53 +0000
Subject: [PATCH] MIRVocabulary changes
---
llvm/include/llvm/CodeGen/MIR2Vec.h | 31 +++++++++-------
llvm/lib/CodeGen/MIR2Vec.cpp | 18 ++++++----
llvm/unittests/CodeGen/MIR2VecTest.cpp | 50 ++++++++++++++++----------
3 files changed, 62 insertions(+), 37 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
index 0ccb24448a678..ea68b4594a2ad 100644
--- a/llvm/include/llvm/CodeGen/MIR2Vec.h
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -8,8 +8,8 @@
///
/// \file
/// This file defines the MIR2Vec vocabulary
-/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::Embedder interface
-/// for generating Machine IR embeddings, and related utilities.
+/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIREmbedder
+/// interface for generating Machine IR embeddings, and related utilities.
///
/// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
/// LLVM Machine IR as embeddings which can be used as input to machine learning
@@ -71,25 +71,31 @@ class MIRVocabulary {
size_t TotalEntries = 0;
} Layout;
+ enum class Section : unsigned { Opcodes = 0, MaxSections };
+
ir2vec::VocabStorage Storage;
mutable std::set<std::string> UniqueBaseOpcodeNames;
- void generateStorage(const VocabMap &OpcodeMap, const TargetInstrInfo &TII);
- void buildCanonicalOpcodeMapping(const TargetInstrInfo &TII);
+ const TargetInstrInfo &TII;
+ void generateStorage(const VocabMap &OpcodeMap);
+ void buildCanonicalOpcodeMapping();
+
+ /// Get canonical index for a machine opcode
+ unsigned getCanonicalOpcodeIndex(unsigned Opcode) const;
public:
- /// Static helper method for extracting base opcode names (public for testing)
+ /// Static method for extracting base opcode names (public for testing)
static std::string extractBaseOpcodeName(StringRef InstrName);
- /// Helper method for getting canonical index for base name (public for
- /// testing)
+ /// Get canonical index for base name (public for testing)
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;
/// Get the string key for a vocabulary entry at the given position
std::string getStringKey(unsigned Pos) const;
- MIRVocabulary() = default;
+ MIRVocabulary() = delete;
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
- MIRVocabulary(ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {}
+ MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
+ : Storage(std::move(Storage)), TII(TII) {} // TII is kept by reference, so it must outlive this object
bool isValid() const {
return UniqueBaseOpcodeNames.size() > 0 &&
@@ -103,11 +109,10 @@ class MIRVocabulary {
}
// Accessor methods
- const Embedding &operator[](unsigned Index) const {
+ const Embedding &operator[](unsigned Opcode) const {
assert(isValid() && "MIR2Vec Vocabulary is invalid");
- assert(Index < Layout.TotalEntries && "Index out of bounds");
- // Fixme: For now, use section 0 for all entries
- return Storage[0][Index];
+ unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode); // machine opcode -> canonical index
+ return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
}
// Iterator access
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index 83c5646629b48..87565c0c77115 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -49,20 +49,21 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
//===----------------------------------------------------------------------===//
MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
- const TargetInstrInfo *TII) {
+ const TargetInstrInfo *TII)
+ : TII(*TII) {
// Fixme: Use static factory methods for creating vocabularies instead of
// public constructors
// Early return for invalid inputs - creates empty/invalid vocabulary
// NOTE(review): the mem-initializer above dereferences the TII pointer before
// the `!TII` test below runs, so passing a null TII is undefined behavior and
// this guard cannot protect against it. Consider asserting non-null at the
// call site or taking `const TargetInstrInfo &` directly.
if (!TII || OpcodeEntries.empty())
return;
- buildCanonicalOpcodeMapping(*TII);
+ buildCanonicalOpcodeMapping();
unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
assert(CanonicalOpcodeCount > 0 &&
"No canonical opcodes found for target - invalid vocabulary");
Layout.OperandBase = CanonicalOpcodeCount;
- generateStorage(OpcodeEntries, *TII);
+ generateStorage(OpcodeEntries);
Layout.TotalEntries = Storage.size();
}
@@ -105,6 +106,12 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
return std::distance(UniqueBaseOpcodeNames.begin(), It);
}
+unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
+ assert(isValid() && "MIR2Vec Vocabulary is invalid");
+ auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode)); // opcode -> name -> base name
+ return getCanonicalIndexForBaseName(BaseOpcode); // base name -> canonical index
+}
+
std::string MIRVocabulary::getStringKey(unsigned Pos) const {
assert(isValid() && "MIR2Vec Vocabulary is invalid");
assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
@@ -121,8 +128,7 @@ std::string MIRVocabulary::getStringKey(unsigned Pos) const {
return "";
}
-void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
- const TargetInstrInfo &TII) {
+void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) {
// Helper for handling missing entities in the vocabulary.
// Currently, we use a zero vector. In the future, we will throw an error to
@@ -168,7 +174,7 @@ void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
Storage = ir2vec::VocabStorage(std::move(Sections));
}
-void MIRVocabulary::buildCanonicalOpcodeMapping(const TargetInstrInfo &TII) {
+void MIRVocabulary::buildCanonicalOpcodeMapping() {
// Check if already built
if (!UniqueBaseOpcodeNames.empty())
return;
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index 01f2eadbf4068..7b282c7abe68c 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -87,6 +87,15 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
}
};
+// Returns the first opcode in [1, getNumOpcodes()) whose TII name is Name.
+static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) {
+ for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) { // opcode 0 is not searched
+ if (TII->getName(Opcode) == Name)
+ return Opcode;
+ }
+ return -1; // Not found; callers are expected to check for -1
+}
+
TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
// Test that same base opcodes get same canonical indices
std::string BaseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri");
@@ -98,10 +107,10 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
// Create a MIRVocabulary instance to test the mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
- VocabMap VM;
+ VocabMap VMap;
Embedding Val = Embedding(64, 1.0f);
- VM["ADD"] = Val;
- MIRVocabulary TestVocab(std::move(VM), TII);
+ VMap["ADD"] = Val;
+ MIRVocabulary TestVocab(std::move(VMap), TII);
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
@@ -132,9 +141,19 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
6880u); // X86 has >6880 unique base opcodes
// Check that the embeddings for opcodes not in the vocab are zero vectors
- EXPECT_TRUE(TestVocab[AddIndex].approximatelyEquals(Val));
- EXPECT_TRUE(TestVocab[SubIndex].approximatelyEquals(Embedding(64, 0.0f)));
- EXPECT_TRUE(TestVocab[MovIndex].approximatelyEquals(Embedding(64, 0.0f)));
+ int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr");
+ ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found";
+ EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val));
+
+ int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr");
+ ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found";
+ EXPECT_TRUE(
+ TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
+
+ int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr");
+ ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found";
+ EXPECT_TRUE(
+ TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
}
// Test deterministic mapping
@@ -144,9 +163,9 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Create a MIRVocabulary instance to test deterministic mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
- VocabMap VM;
- VM["ADD"] = Embedding(64, 1.0f);
- MIRVocabulary TestVocab(std::move(VM), TII);
+ VocabMap VMap;
+ VMap["ADD"] = Embedding(64, 1.0f);
+ MIRVocabulary TestVocab(std::move(VMap), TII);
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
@@ -164,16 +183,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Test MIRVocabulary construction
TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
- // Test empty MIRVocabulary
- MIRVocabulary EmptyVocab;
- EXPECT_FALSE(EmptyVocab.isValid());
-
- // Test MIRVocabulary with embeddings via VocabMap
- VocabMap VM;
- VM["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
- VM["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
+ VocabMap VMap;
+ VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
+ VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
- MIRVocabulary Vocab(std::move(VM), TII);
+ MIRVocabulary Vocab(std::move(VMap), TII);
EXPECT_TRUE(Vocab.isValid());
EXPECT_EQ(Vocab.getDimension(), 128u);
More information about the llvm-commits
mailing list