[llvm] [MIR2Vec] Added create factory methods for Vocabulary (PR #162569)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 9 00:07:09 PDT 2025


https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/162569

From 1b9aab1424abdfaa59d65fefcad0753ad1f66e1a Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Wed, 8 Oct 2025 23:10:15 +0000
Subject: [PATCH 1/3] Added create factory methods for MIR2Vec Vocabulary

---
 llvm/include/llvm/CodeGen/MIR2Vec.h           | 33 +++++------
 llvm/lib/CodeGen/MIR2Vec.cpp                  | 57 ++++++++++---------
 .../CodeGen/MIR2Vec/vocab-error-handling.ll   | 16 +++---
 llvm/unittests/CodeGen/MIR2VecTest.cpp        | 35 ++++++++++--
 4 files changed, 85 insertions(+), 56 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
index ea68b4594a2ad..dbffede50df81 100644
--- a/llvm/include/llvm/CodeGen/MIR2Vec.h
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -38,6 +38,7 @@
 #include "llvm/IR/PassManager.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/ErrorOr.h"
 #include <map>
 #include <set>
@@ -92,25 +93,12 @@ class MIRVocabulary {
   /// Get the string key for a vocabulary entry at the given position
   std::string getStringKey(unsigned Pos) const;
 
-  MIRVocabulary() = delete;
-  MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
-  MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
-      : Storage(std::move(Storage)), TII(TII) {}
-
-  bool isValid() const {
-    return UniqueBaseOpcodeNames.size() > 0 &&
-           Layout.TotalEntries == Storage.size() && Storage.isValid();
-  }
-
   unsigned getDimension() const {
-    if (!isValid())
-      return 0;
     return Storage.getDimension();
   }
 
   // Accessor methods
   const Embedding &operator[](unsigned Opcode) const {
-    assert(isValid() && "MIR2Vec Vocabulary is invalid");
     unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
     return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
   }
@@ -118,20 +106,30 @@ class MIRVocabulary {
   // Iterator access
   using const_iterator = ir2vec::VocabStorage::const_iterator;
   const_iterator begin() const {
-    assert(isValid() && "MIR2Vec Vocabulary is invalid");
     return Storage.begin();
   }
 
   const_iterator end() const {
-    assert(isValid() && "MIR2Vec Vocabulary is invalid");
     return Storage.end();
   }
 
   /// Total number of entries in the vocabulary
   size_t getCanonicalSize() const {
-    assert(isValid() && "Invalid vocabulary");
     return Storage.size();
   }
+
+  MIRVocabulary() = delete;
+
+  /// Factory method to create MIRVocabulary from vocabulary map
+  static Expected<MIRVocabulary> create(VocabMap &&Entries, const TargetInstrInfo &TII);
+  
+  /// Factory method to create MIRVocabulary from existing storage
+  static Expected<MIRVocabulary> create(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII);
+
+private:
+  MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
+  MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
+      : Storage(std::move(Storage)), TII(TII) {}
 };
 
 } // namespace mir2vec
@@ -145,7 +143,6 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
 
   StringRef getPassName() const override;
   Error readVocabulary();
-  void emitError(Error Err, LLVMContext &Ctx);
 
 protected:
   void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -156,7 +153,7 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
 public:
   static char ID;
   MIR2VecVocabLegacyAnalysis() : ImmutablePass(ID) {}
-  mir2vec::MIRVocabulary getMIR2VecVocabulary(const Module &M);
+  Expected<mir2vec::MIRVocabulary> getMIR2VecVocabulary(const Module &M);
 };
 
 /// This pass prints the embeddings in the MIR2Vec vocabulary
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index 87565c0c77115..669c11d5f739c 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -49,14 +49,8 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
 //===----------------------------------------------------------------------===//
 
 MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
-                             const TargetInstrInfo *TII)
-    : TII(*TII) {
-  // Fixme: Use static factory methods for creating vocabularies instead of
-  // public constructors
-  // Early return for invalid inputs - creates empty/invalid vocabulary
-  if (!TII || OpcodeEntries.empty())
-    return;
-
+                             const TargetInstrInfo &TII)
+    : TII(TII) {
   buildCanonicalOpcodeMapping();
 
   unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
@@ -67,6 +61,24 @@ MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
   Layout.TotalEntries = Storage.size();
 }
 
+Expected<MIRVocabulary> MIRVocabulary::create(VocabMap &&Entries,
+                                              const TargetInstrInfo &TII) {
+  if (Entries.empty())
+    return createStringError(errc::invalid_argument,
+                             "Empty vocabulary entries provided");
+
+  return MIRVocabulary(std::move(Entries), TII);
+}
+
+Expected<MIRVocabulary> MIRVocabulary::create(ir2vec::VocabStorage &&Storage,
+                                              const TargetInstrInfo &TII) {
+  if (!Storage.isValid())
+    return createStringError(errc::invalid_argument,
+                             "Invalid vocabulary storage provided");
+
+  return MIRVocabulary(std::move(Storage), TII);
+}
+
 std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) {
   // Extract base instruction name using regex to capture letters and
   // underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE"
@@ -107,13 +119,11 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
 }
 
 unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
-  assert(isValid() && "MIR2Vec Vocabulary is invalid");
   auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
   return getCanonicalIndexForBaseName(BaseOpcode);
 }
 
 std::string MIRVocabulary::getStringKey(unsigned Pos) const {
-  assert(isValid() && "MIR2Vec Vocabulary is invalid");
   assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
 
   // For now, all entries are opcodes since we only have one section
@@ -232,16 +242,11 @@ Error MIR2VecVocabLegacyAnalysis::readVocabulary() {
   return Error::success();
 }
 
-void MIR2VecVocabLegacyAnalysis::emitError(Error Err, LLVMContext &Ctx) {
-  Ctx.emitError(toString(std::move(Err)));
-}
-
-mir2vec::MIRVocabulary
+Expected<mir2vec::MIRVocabulary>
 MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
   if (StrVocabMap.empty()) {
     if (Error Err = readVocabulary()) {
-      emitError(std::move(Err), M.getContext());
-      return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
+      return std::move(Err);
     }
   }
 
@@ -255,15 +260,13 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
 
     if (auto *MF = MMI.getMachineFunction(F)) {
       const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
-      return mir2vec::MIRVocabulary(std::move(StrVocabMap), TII);
+      return mir2vec::MIRVocabulary::create(std::move(StrVocabMap), *TII);
     }
   }
 
-  // No machine functions available - return invalid vocabulary
-  emitError(make_error<StringError>("No machine functions found in module",
-                                    inconvertibleErrorCode()),
-            M.getContext());
-  return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
+  // No machine functions available - return error
+  return createStringError(errc::invalid_argument,
+                           "No machine functions found in module");
 }
 
 //===----------------------------------------------------------------------===//
@@ -284,13 +287,15 @@ bool MIR2VecVocabPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
 
 bool MIR2VecVocabPrinterLegacyPass::doFinalization(Module &M) {
   auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
-  auto MIR2VecVocab = Analysis.getMIR2VecVocabulary(M);
+  auto MIR2VecVocabOrErr = Analysis.getMIR2VecVocabulary(M);
 
-  if (!MIR2VecVocab.isValid()) {
-    OS << "MIR2Vec Vocabulary Printer: Invalid vocabulary\n";
+  if (!MIR2VecVocabOrErr) {
+    OS << "MIR2Vec Vocabulary Printer: Failed to get vocabulary - "
+       << toString(MIR2VecVocabOrErr.takeError()) << "\n";
     return false;
   }
 
+  auto &MIR2VecVocab = *MIR2VecVocabOrErr;
   unsigned Pos = 0;
   for (const auto &Entry : MIR2VecVocab) {
     OS << "Key: " << MIR2VecVocab.getStringKey(Pos++) << ": ";
diff --git a/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll b/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
index 1da516a6cd3b9..80b4048cea0c3 100644
--- a/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
+++ b/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
@@ -1,15 +1,15 @@
 ; REQUIRES: x86_64-linux
-; RUN: not llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
-; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
-; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
-; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS
+; RUN: llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
+; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
+; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
+; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS
 
 define dso_local void @test() {
   entry:
     ret void
 }
 
-; CHECK-INVALID: error: MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
-; CHECK-ZERO-DIM: error: Dimension of 'entities' section of the vocabulary is zero
-; CHECK-NO-ENTITIES: error: Missing 'entities' section in vocabulary file
-; CHECK-INCONSISTENT-DIMS: error: All vectors in the 'entities' section of the vocabulary are not of the same dimension
+; CHECK-INVALID: MIR2Vec Vocabulary Printer: Failed to get vocabulary - MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
+; CHECK-ZERO-DIM: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Dimension of 'entities' section of the vocabulary is zero
+; CHECK-NO-ENTITIES: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Missing 'entities' section in vocabulary file
+; CHECK-INCONSISTENT-DIMS: MIR2Vec Vocabulary Printer: Failed to get vocabulary - All vectors in the 'entities' section of the vocabulary are not of the same dimension
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index d243d82c73fc7..269e3b515c6fc 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -17,6 +17,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetOptions.h"
 #include "llvm/TargetParser/Triple.h"
@@ -118,7 +119,11 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
   VocabMap VMap;
   Embedding Val = Embedding(64, 1.0f);
   VMap["ADD"] = Val;
-  MIRVocabulary TestVocab(std::move(VMap), TII);
+  auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+  ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
+      << "Failed to create vocabulary: "
+      << toString(TestVocabOrErr.takeError());
+  auto &TestVocab = *TestVocabOrErr;
 
   unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
   unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
@@ -173,7 +178,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
   // Use a minimal MIRVocabulary to trigger canonical mapping construction
   VocabMap VMap;
   VMap["ADD"] = Embedding(64, 1.0f);
-  MIRVocabulary TestVocab(std::move(VMap), TII);
+  auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+  ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
+      << "Failed to create vocabulary: "
+      << toString(TestVocabOrErr.takeError());
+  auto &TestVocab = *TestVocabOrErr;
 
   unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
   unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
@@ -195,8 +204,10 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
   VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
   VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
 
-  MIRVocabulary Vocab(std::move(VMap), TII);
-  EXPECT_TRUE(Vocab.isValid());
+  auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+  auto &Vocab = *VocabOrErr;
   EXPECT_EQ(Vocab.getDimension(), 128u);
 
   // Test iterator - iterates over individual embeddings
@@ -214,4 +225,20 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
   EXPECT_GT(Count, 0u);
 }
 
+// Test factory method with empty vocabulary
+TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
+  VocabMap EmptyVMap;
+
+  auto VocabOrErr = MIRVocabulary::create(std::move(EmptyVMap), *TII);
+  EXPECT_FALSE(static_cast<bool>(VocabOrErr))
+      << "Factory method should fail with empty vocabulary";
+
+  // Consume the error
+  if (!VocabOrErr) {
+    auto Err = VocabOrErr.takeError();
+    std::string ErrorMsg = toString(std::move(Err));
+    EXPECT_FALSE(ErrorMsg.empty());
+  }
+}
+
 } // namespace
\ No newline at end of file

From 7214a4794cf3a4ec7fa508b1bb4757940c2ccb8b Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Wed, 8 Oct 2025 23:27:34 +0000
Subject: [PATCH 2/3] Removed unused create and formatting fixes

---
 llvm/include/llvm/CodeGen/MIR2Vec.h    | 24 ++++++------------------
 llvm/lib/CodeGen/MIR2Vec.cpp           |  9 ---------
 llvm/unittests/CodeGen/MIR2VecTest.cpp |  2 +-
 3 files changed, 7 insertions(+), 28 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
index dbffede50df81..7b1b5d9aee15d 100644
--- a/llvm/include/llvm/CodeGen/MIR2Vec.h
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -93,9 +93,7 @@ class MIRVocabulary {
   /// Get the string key for a vocabulary entry at the given position
   std::string getStringKey(unsigned Pos) const;
 
-  unsigned getDimension() const {
-    return Storage.getDimension();
-  }
+  unsigned getDimension() const { return Storage.getDimension(); }
 
   // Accessor methods
   const Embedding &operator[](unsigned Opcode) const {
@@ -105,31 +103,21 @@ class MIRVocabulary {
 
   // Iterator access
   using const_iterator = ir2vec::VocabStorage::const_iterator;
-  const_iterator begin() const {
-    return Storage.begin();
-  }
+  const_iterator begin() const { return Storage.begin(); }
 
-  const_iterator end() const {
-    return Storage.end();
-  }
+  const_iterator end() const { return Storage.end(); }
 
   /// Total number of entries in the vocabulary
-  size_t getCanonicalSize() const {
-    return Storage.size();
-  }
+  size_t getCanonicalSize() const { return Storage.size(); }
 
   MIRVocabulary() = delete;
 
   /// Factory method to create MIRVocabulary from vocabulary map
-  static Expected<MIRVocabulary> create(VocabMap &&Entries, const TargetInstrInfo &TII);
-  
-  /// Factory method to create MIRVocabulary from existing storage
-  static Expected<MIRVocabulary> create(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII);
+  static Expected<MIRVocabulary> create(VocabMap &&Entries,
+                                        const TargetInstrInfo &TII);
 
 private:
   MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
-  MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
-      : Storage(std::move(Storage)), TII(TII) {}
 };
 
 } // namespace mir2vec
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index 669c11d5f739c..e85976547a2c2 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -70,15 +70,6 @@ Expected<MIRVocabulary> MIRVocabulary::create(VocabMap &&Entries,
   return MIRVocabulary(std::move(Entries), TII);
 }
 
-Expected<MIRVocabulary> MIRVocabulary::create(ir2vec::VocabStorage &&Storage,
-                                              const TargetInstrInfo &TII) {
-  if (!Storage.isValid())
-    return createStringError(errc::invalid_argument,
-                             "Invalid vocabulary storage provided");
-
-  return MIRVocabulary(std::move(Storage), TII);
-}
-
 std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) {
   // Extract base instruction name using regex to capture letters and
   // underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE"
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index 269e3b515c6fc..6ce791695e3e4 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -241,4 +241,4 @@ TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
   }
 }
 
-} // namespace
\ No newline at end of file
+} // namespace

From 43e7b9a0f29e6b7aa74f219b83c54f0d1cf8171d Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 9 Oct 2025 06:59:28 +0000
Subject: [PATCH 3/3] Addressing review comments

---
 llvm/unittests/CodeGen/MIR2VecTest.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index 6ce791695e3e4..11222b4d02fa3 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -53,7 +53,7 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
   std::unique_ptr<LLVMContext> Ctx;
   std::unique_ptr<Module> M;
   std::unique_ptr<TargetMachine> TM;
-  const TargetInstrInfo *TII;
+  const TargetInstrInfo *TII = nullptr;
 
   static void SetUpTestCase() {
     InitializeAllTargets();
@@ -94,6 +94,8 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
       return;
     }
   }
+
+  void TearDown() override { TII = nullptr; }
 };
 
 // Function to find an opcode by name



More information about the llvm-commits mailing list