[llvm] [IR2Vec] Refactor MIR vocabulary to use opcode-based indexing (PR #161713)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 7 13:48:40 PDT 2025


https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/161713

>From 6a43c342f53d65b581b4d5a35f0988097ff8e5fd Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 2 Oct 2025 18:14:53 +0000
Subject: [PATCH] [IR2Vec] Refactor MIR vocabulary to use opcode-based indexing

---
 llvm/include/llvm/CodeGen/MIR2Vec.h    | 31 +++++++++-------
 llvm/lib/CodeGen/MIR2Vec.cpp           | 18 ++++++----
 llvm/unittests/CodeGen/MIR2VecTest.cpp | 50 ++++++++++++++++----------
 3 files changed, 62 insertions(+), 37 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
index 0ccb24448a678..ea68b4594a2ad 100644
--- a/llvm/include/llvm/CodeGen/MIR2Vec.h
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -8,8 +8,8 @@
 ///
 /// \file
 /// This file defines the MIR2Vec vocabulary
-/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::Embedder interface
-/// for generating Machine IR embeddings, and related utilities.
+/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIREmbedder
+/// interface for generating Machine IR embeddings, and related utilities.
 ///
 /// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
 /// LLVM Machine IR as embeddings which can be used as input to machine learning
@@ -71,25 +71,34 @@ class MIRVocabulary {
     size_t TotalEntries = 0;
   } Layout;
 
+  enum class Section : unsigned { Opcodes = 0, MaxSections };
+
   ir2vec::VocabStorage Storage;
   mutable std::set<std::string> UniqueBaseOpcodeNames;
-  void generateStorage(const VocabMap &OpcodeMap, const TargetInstrInfo &TII);
-  void buildCanonicalOpcodeMapping(const TargetInstrInfo &TII);
+  /// Target instruction info used to map opcodes to base opcode names. It is
+  /// held by reference, so it must outlive this vocabulary; the reference
+  /// member also makes MIRVocabulary non-copy/move-assignable.
+  const TargetInstrInfo &TII;
+  void generateStorage(const VocabMap &OpcodeMap);
+  void buildCanonicalOpcodeMapping();
+
+  /// Get canonical index for a machine opcode
+  unsigned getCanonicalOpcodeIndex(unsigned Opcode) const;
 
 public:
-  /// Static helper method for extracting base opcode names (public for testing)
+  /// Static method for extracting base opcode names (public for testing)
   static std::string extractBaseOpcodeName(StringRef InstrName);
 
-  /// Helper method for getting canonical index for base name (public for
-  /// testing)
+  /// Get canonical index for base name (public for testing)
   unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;
 
   /// Get the string key for a vocabulary entry at the given position
   std::string getStringKey(unsigned Pos) const;
 
-  MIRVocabulary() = default;
+  MIRVocabulary() = delete;
   MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
-  MIRVocabulary(ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {}
+  MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
+      : Storage(std::move(Storage)), TII(TII) {}
 
   bool isValid() const {
     return UniqueBaseOpcodeNames.size() > 0 &&
@@ -103,11 +112,13 @@ class MIRVocabulary {
   }
 
   // Accessor methods
-  const Embedding &operator[](unsigned Index) const {
+  /// Returns the embedding for the target opcode \p Opcode (as accepted by
+  /// TargetInstrInfo::getName); this is not a position in the vocabulary.
+  /// Opcodes that share a base opcode name share the same embedding.
+  const Embedding &operator[](unsigned Opcode) const {
     assert(isValid() && "MIR2Vec Vocabulary is invalid");
-    assert(Index < Layout.TotalEntries && "Index out of bounds");
-    // Fixme: For now, use section 0 for all entries
-    return Storage[0][Index];
+    unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
+    return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
   }
 
   // Iterator access
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index 83c5646629b48..87565c0c77115 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -49,20 +49,24 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
 //===----------------------------------------------------------------------===//
 
+// TII must be non-null: it is bound to the TII reference member, and binding a
+// reference through a null pointer is undefined behavior.
 MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
-                             const TargetInstrInfo *TII) {
+                             const TargetInstrInfo *TII)
+    : TII((assert(TII && "MIRVocabulary requires a non-null TargetInstrInfo"),
+           *TII)) {
   // Fixme: Use static factory methods for creating vocabularies instead of
   // public constructors
-  // Early return for invalid inputs - creates empty/invalid vocabulary
-  if (!TII || OpcodeEntries.empty())
+  // Early return for empty inputs - creates empty/invalid vocabulary
+  if (OpcodeEntries.empty())
     return;
 
-  buildCanonicalOpcodeMapping(*TII);
+  buildCanonicalOpcodeMapping();
 
   unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
   assert(CanonicalOpcodeCount > 0 &&
          "No canonical opcodes found for target - invalid vocabulary");
   Layout.OperandBase = CanonicalOpcodeCount;
-  generateStorage(OpcodeEntries, *TII);
+  generateStorage(OpcodeEntries);
   Layout.TotalEntries = Storage.size();
 }
 
@@ -105,6 +109,16 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
   return std::distance(UniqueBaseOpcodeNames.begin(), It);
 }
 
+// Maps a target opcode to the index of its base opcode name within the
+// opcode section of the vocabulary. Opcodes that share a base name (as
+// computed by extractBaseOpcodeName) map to the same index.
+unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
+  assert(isValid() && "MIR2Vec Vocabulary is invalid");
+  assert(Opcode < TII.getNumOpcodes() && "Opcode out of range for target");
+  std::string BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
+  return getCanonicalIndexForBaseName(BaseOpcode);
+}
+
 std::string MIRVocabulary::getStringKey(unsigned Pos) const {
   assert(isValid() && "MIR2Vec Vocabulary is invalid");
   assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
@@ -121,8 +135,7 @@ std::string MIRVocabulary::getStringKey(unsigned Pos) const {
   return "";
 }
 
-void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
-                                    const TargetInstrInfo &TII) {
+void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) {
 
   // Helper for handling missing entities in the vocabulary.
   // Currently, we use a zero vector. In the future, we will throw an error to
@@ -168,7 +181,7 @@ void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
   Storage = ir2vec::VocabStorage(std::move(Sections));
 }
 
-void MIRVocabulary::buildCanonicalOpcodeMapping(const TargetInstrInfo &TII) {
+void MIRVocabulary::buildCanonicalOpcodeMapping() {
   // Check if already built
   if (!UniqueBaseOpcodeNames.empty())
     return;
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index 01f2eadbf4068..7b282c7abe68c 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -87,6 +87,16 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
   }
 };
 
+// Returns the opcode of the target instruction named \p Name, or -1 if the
+// target defines no instruction with that name.
+static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) {
+  for (unsigned Opcode = 0, E = TII->getNumOpcodes(); Opcode < E; ++Opcode) {
+    if (TII->getName(Opcode) == Name)
+      return static_cast<int>(Opcode);
+  }
+  return -1; // Not found
+}
+
 TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
   // Test that same base opcodes get same canonical indices
   std::string BaseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri");
@@ -98,10 +108,10 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
 
   // Create a MIRVocabulary instance to test the mapping
   // Use a minimal MIRVocabulary to trigger canonical mapping construction
-  VocabMap VM;
+  VocabMap VMap;
   Embedding Val = Embedding(64, 1.0f);
-  VM["ADD"] = Val;
-  MIRVocabulary TestVocab(std::move(VM), TII);
+  VMap["ADD"] = Val;
+  MIRVocabulary TestVocab(std::move(VMap), TII);
 
   unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
   unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
@@ -132,9 +142,21 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
             6880u); // X86 has >6880 unique base opcodes
 
-  // Check that the embeddings for opcodes not in the vocab are zero vectors
-  EXPECT_TRUE(TestVocab[AddIndex].approximatelyEquals(Val));
-  EXPECT_TRUE(TestVocab[SubIndex].approximatelyEquals(Embedding(64, 0.0f)));
-  EXPECT_TRUE(TestVocab[MovIndex].approximatelyEquals(Embedding(64, 0.0f)));
+  // operator[] takes a target opcode. ADD32rr has base name "ADD", which is in
+  // the vocabulary, so it gets the stored embedding.
+  int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr");
+  ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found";
+  EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val));
+
+  // Opcodes whose base names are not in the vocabulary map to zero vectors.
+  int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr");
+  ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found";
+  EXPECT_TRUE(
+      TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
+
+  int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr");
+  ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found";
+  EXPECT_TRUE(
+      TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
 }
 
 // Test deterministic mapping
@@ -144,9 +166,9 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
 
   // Create a MIRVocabulary instance to test deterministic mapping
   // Use a minimal MIRVocabulary to trigger canonical mapping construction
-  VocabMap VM;
-  VM["ADD"] = Embedding(64, 1.0f);
-  MIRVocabulary TestVocab(std::move(VM), TII);
+  VocabMap VMap;
+  VMap["ADD"] = Embedding(64, 1.0f);
+  MIRVocabulary TestVocab(std::move(VMap), TII);
 
   unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
   unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
@@ -164,16 +186,16 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
 
 // Test MIRVocabulary construction
 TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
   // Test empty MIRVocabulary
-  MIRVocabulary EmptyVocab;
+  MIRVocabulary EmptyVocab(VocabMap{}, TII);
   EXPECT_FALSE(EmptyVocab.isValid());
 
   // Test MIRVocabulary with embeddings via VocabMap
-  VocabMap VM;
-  VM["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
-  VM["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
+  VocabMap VMap;
+  VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
+  VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
 
-  MIRVocabulary Vocab(std::move(VM), TII);
+  MIRVocabulary Vocab(std::move(VMap), TII);
   EXPECT_TRUE(Vocab.isValid());
   EXPECT_EQ(Vocab.getDimension(), 128u);




More information about the llvm-commits mailing list