[llvm] [MIR2Vec] Added create factory methods for Vocabulary (PR #162569)
S. VenkataKeerthy via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 9 00:07:09 PDT 2025
https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/162569
>From 1b9aab1424abdfaa59d65fefcad0753ad1f66e1a Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Wed, 8 Oct 2025 23:10:15 +0000
Subject: [PATCH 1/3] Added create factory methods for MIR2Vec Vocabulary
---
llvm/include/llvm/CodeGen/MIR2Vec.h | 33 +++++------
llvm/lib/CodeGen/MIR2Vec.cpp | 57 ++++++++++---------
.../CodeGen/MIR2Vec/vocab-error-handling.ll | 16 +++---
llvm/unittests/CodeGen/MIR2VecTest.cpp | 35 ++++++++++--
4 files changed, 85 insertions(+), 56 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
index ea68b4594a2ad..dbffede50df81 100644
--- a/llvm/include/llvm/CodeGen/MIR2Vec.h
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -38,6 +38,7 @@
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorOr.h"
#include <map>
#include <set>
@@ -92,25 +93,12 @@ class MIRVocabulary {
/// Get the string key for a vocabulary entry at the given position
std::string getStringKey(unsigned Pos) const;
- MIRVocabulary() = delete;
- MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
- MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
- : Storage(std::move(Storage)), TII(TII) {}
-
- bool isValid() const {
- return UniqueBaseOpcodeNames.size() > 0 &&
- Layout.TotalEntries == Storage.size() && Storage.isValid();
- }
-
unsigned getDimension() const {
- if (!isValid())
- return 0;
return Storage.getDimension();
}
// Accessor methods
const Embedding &operator[](unsigned Opcode) const {
- assert(isValid() && "MIR2Vec Vocabulary is invalid");
unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
}
@@ -118,20 +106,30 @@ class MIRVocabulary {
// Iterator access
using const_iterator = ir2vec::VocabStorage::const_iterator;
const_iterator begin() const {
- assert(isValid() && "MIR2Vec Vocabulary is invalid");
return Storage.begin();
}
const_iterator end() const {
- assert(isValid() && "MIR2Vec Vocabulary is invalid");
return Storage.end();
}
/// Total number of entries in the vocabulary
size_t getCanonicalSize() const {
- assert(isValid() && "Invalid vocabulary");
return Storage.size();
}
+
+ MIRVocabulary() = delete;
+
+ /// Factory method to create MIRVocabulary from vocabulary map
+ static Expected<MIRVocabulary> create(VocabMap &&Entries, const TargetInstrInfo &TII);
+
+ /// Factory method to create MIRVocabulary from existing storage
+ static Expected<MIRVocabulary> create(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII);
+
+private:
+ MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
+ MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
+ : Storage(std::move(Storage)), TII(TII) {}
};
} // namespace mir2vec
@@ -145,7 +143,6 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
StringRef getPassName() const override;
Error readVocabulary();
- void emitError(Error Err, LLVMContext &Ctx);
protected:
void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -156,7 +153,7 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
public:
static char ID;
MIR2VecVocabLegacyAnalysis() : ImmutablePass(ID) {}
- mir2vec::MIRVocabulary getMIR2VecVocabulary(const Module &M);
+ Expected<mir2vec::MIRVocabulary> getMIR2VecVocabulary(const Module &M);
};
/// This pass prints the embeddings in the MIR2Vec vocabulary
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index 87565c0c77115..669c11d5f739c 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -49,14 +49,8 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
//===----------------------------------------------------------------------===//
MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
- const TargetInstrInfo *TII)
- : TII(*TII) {
- // Fixme: Use static factory methods for creating vocabularies instead of
- // public constructors
- // Early return for invalid inputs - creates empty/invalid vocabulary
- if (!TII || OpcodeEntries.empty())
- return;
-
+ const TargetInstrInfo &TII)
+ : TII(TII) {
buildCanonicalOpcodeMapping();
unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
@@ -67,6 +61,24 @@ MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
Layout.TotalEntries = Storage.size();
}
+Expected<MIRVocabulary> MIRVocabulary::create(VocabMap &&Entries,
+ const TargetInstrInfo &TII) {
+ if (Entries.empty())
+ return createStringError(errc::invalid_argument,
+ "Empty vocabulary entries provided");
+
+ return MIRVocabulary(std::move(Entries), TII);
+}
+
+Expected<MIRVocabulary> MIRVocabulary::create(ir2vec::VocabStorage &&Storage,
+ const TargetInstrInfo &TII) {
+ if (!Storage.isValid())
+ return createStringError(errc::invalid_argument,
+ "Invalid vocabulary storage provided");
+
+ return MIRVocabulary(std::move(Storage), TII);
+}
+
std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) {
// Extract base instruction name using regex to capture letters and
// underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE"
@@ -107,13 +119,11 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
}
unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
- assert(isValid() && "MIR2Vec Vocabulary is invalid");
auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
return getCanonicalIndexForBaseName(BaseOpcode);
}
std::string MIRVocabulary::getStringKey(unsigned Pos) const {
- assert(isValid() && "MIR2Vec Vocabulary is invalid");
assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
// For now, all entries are opcodes since we only have one section
@@ -232,16 +242,11 @@ Error MIR2VecVocabLegacyAnalysis::readVocabulary() {
return Error::success();
}
-void MIR2VecVocabLegacyAnalysis::emitError(Error Err, LLVMContext &Ctx) {
- Ctx.emitError(toString(std::move(Err)));
-}
-
-mir2vec::MIRVocabulary
+Expected<mir2vec::MIRVocabulary>
MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
if (StrVocabMap.empty()) {
if (Error Err = readVocabulary()) {
- emitError(std::move(Err), M.getContext());
- return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
+ return std::move(Err);
}
}
@@ -255,15 +260,13 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
if (auto *MF = MMI.getMachineFunction(F)) {
const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
- return mir2vec::MIRVocabulary(std::move(StrVocabMap), TII);
+ return mir2vec::MIRVocabulary::create(std::move(StrVocabMap), *TII);
}
}
- // No machine functions available - return invalid vocabulary
- emitError(make_error<StringError>("No machine functions found in module",
- inconvertibleErrorCode()),
- M.getContext());
- return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
+ // No machine functions available - return error
+ return createStringError(errc::invalid_argument,
+ "No machine functions found in module");
}
//===----------------------------------------------------------------------===//
@@ -284,13 +287,15 @@ bool MIR2VecVocabPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
bool MIR2VecVocabPrinterLegacyPass::doFinalization(Module &M) {
auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
- auto MIR2VecVocab = Analysis.getMIR2VecVocabulary(M);
+ auto MIR2VecVocabOrErr = Analysis.getMIR2VecVocabulary(M);
- if (!MIR2VecVocab.isValid()) {
- OS << "MIR2Vec Vocabulary Printer: Invalid vocabulary\n";
+ if (!MIR2VecVocabOrErr) {
+ OS << "MIR2Vec Vocabulary Printer: Failed to get vocabulary - "
+ << toString(MIR2VecVocabOrErr.takeError()) << "\n";
return false;
}
+ auto &MIR2VecVocab = *MIR2VecVocabOrErr;
unsigned Pos = 0;
for (const auto &Entry : MIR2VecVocab) {
OS << "Key: " << MIR2VecVocab.getStringKey(Pos++) << ": ";
diff --git a/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll b/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
index 1da516a6cd3b9..80b4048cea0c3 100644
--- a/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
+++ b/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
@@ -1,15 +1,15 @@
; REQUIRES: x86_64-linux
-; RUN: not llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
-; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
-; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
-; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS
+; RUN: llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
+; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
+; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
+; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS
define dso_local void @test() {
entry:
ret void
}
-; CHECK-INVALID: error: MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
-; CHECK-ZERO-DIM: error: Dimension of 'entities' section of the vocabulary is zero
-; CHECK-NO-ENTITIES: error: Missing 'entities' section in vocabulary file
-; CHECK-INCONSISTENT-DIMS: error: All vectors in the 'entities' section of the vocabulary are not of the same dimension
+; CHECK-INVALID: MIR2Vec Vocabulary Printer: Failed to get vocabulary - MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
+; CHECK-ZERO-DIM: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Dimension of 'entities' section of the vocabulary is zero
+; CHECK-NO-ENTITIES: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Missing 'entities' section in vocabulary file
+; CHECK-INCONSISTENT-DIMS: MIR2Vec Vocabulary Printer: Failed to get vocabulary - All vectors in the 'entities' section of the vocabulary are not of the same dimension
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index d243d82c73fc7..269e3b515c6fc 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -17,6 +17,7 @@
#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Triple.h"
@@ -118,7 +119,11 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
VocabMap VMap;
Embedding Val = Embedding(64, 1.0f);
VMap["ADD"] = Val;
- MIRVocabulary TestVocab(std::move(VMap), TII);
+ auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
+ << "Failed to create vocabulary: "
+ << toString(TestVocabOrErr.takeError());
+ auto &TestVocab = *TestVocabOrErr;
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
@@ -173,7 +178,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Use a minimal MIRVocabulary to trigger canonical mapping construction
VocabMap VMap;
VMap["ADD"] = Embedding(64, 1.0f);
- MIRVocabulary TestVocab(std::move(VMap), TII);
+ auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
+ << "Failed to create vocabulary: "
+ << toString(TestVocabOrErr.takeError());
+ auto &TestVocab = *TestVocabOrErr;
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
@@ -195,8 +204,10 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
- MIRVocabulary Vocab(std::move(VMap), TII);
- EXPECT_TRUE(Vocab.isValid());
+ auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
EXPECT_EQ(Vocab.getDimension(), 128u);
// Test iterator - iterates over individual embeddings
@@ -214,4 +225,20 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
EXPECT_GT(Count, 0u);
}
+// Test factory method with empty vocabulary
+TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
+ VocabMap EmptyVMap;
+
+ auto VocabOrErr = MIRVocabulary::create(std::move(EmptyVMap), *TII);
+ EXPECT_FALSE(static_cast<bool>(VocabOrErr))
+ << "Factory method should fail with empty vocabulary";
+
+ // Consume the error
+ if (!VocabOrErr) {
+ auto Err = VocabOrErr.takeError();
+ std::string ErrorMsg = toString(std::move(Err));
+ EXPECT_FALSE(ErrorMsg.empty());
+ }
+}
+
} // namespace
\ No newline at end of file
>From 7214a4794cf3a4ec7fa508b1bb4757940c2ccb8b Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Wed, 8 Oct 2025 23:27:34 +0000
Subject: [PATCH 2/3] Removed unused create and formatting fixes
---
llvm/include/llvm/CodeGen/MIR2Vec.h | 24 ++++++------------------
llvm/lib/CodeGen/MIR2Vec.cpp | 9 ---------
llvm/unittests/CodeGen/MIR2VecTest.cpp | 2 +-
3 files changed, 7 insertions(+), 28 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
index dbffede50df81..7b1b5d9aee15d 100644
--- a/llvm/include/llvm/CodeGen/MIR2Vec.h
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -93,9 +93,7 @@ class MIRVocabulary {
/// Get the string key for a vocabulary entry at the given position
std::string getStringKey(unsigned Pos) const;
- unsigned getDimension() const {
- return Storage.getDimension();
- }
+ unsigned getDimension() const { return Storage.getDimension(); }
// Accessor methods
const Embedding &operator[](unsigned Opcode) const {
@@ -105,31 +103,21 @@ class MIRVocabulary {
// Iterator access
using const_iterator = ir2vec::VocabStorage::const_iterator;
- const_iterator begin() const {
- return Storage.begin();
- }
+ const_iterator begin() const { return Storage.begin(); }
- const_iterator end() const {
- return Storage.end();
- }
+ const_iterator end() const { return Storage.end(); }
/// Total number of entries in the vocabulary
- size_t getCanonicalSize() const {
- return Storage.size();
- }
+ size_t getCanonicalSize() const { return Storage.size(); }
MIRVocabulary() = delete;
/// Factory method to create MIRVocabulary from vocabulary map
- static Expected<MIRVocabulary> create(VocabMap &&Entries, const TargetInstrInfo &TII);
-
- /// Factory method to create MIRVocabulary from existing storage
- static Expected<MIRVocabulary> create(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII);
+ static Expected<MIRVocabulary> create(VocabMap &&Entries,
+ const TargetInstrInfo &TII);
private:
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
- MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
- : Storage(std::move(Storage)), TII(TII) {}
};
} // namespace mir2vec
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index 669c11d5f739c..e85976547a2c2 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -70,15 +70,6 @@ Expected<MIRVocabulary> MIRVocabulary::create(VocabMap &&Entries,
return MIRVocabulary(std::move(Entries), TII);
}
-Expected<MIRVocabulary> MIRVocabulary::create(ir2vec::VocabStorage &&Storage,
- const TargetInstrInfo &TII) {
- if (!Storage.isValid())
- return createStringError(errc::invalid_argument,
- "Invalid vocabulary storage provided");
-
- return MIRVocabulary(std::move(Storage), TII);
-}
-
std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) {
// Extract base instruction name using regex to capture letters and
// underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE"
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index 269e3b515c6fc..6ce791695e3e4 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -241,4 +241,4 @@ TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
}
}
-} // namespace
\ No newline at end of file
+} // namespace
>From 43e7b9a0f29e6b7aa74f219b83c54f0d1cf8171d Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 9 Oct 2025 06:59:28 +0000
Subject: [PATCH 3/3] Addressing review comments
---
llvm/unittests/CodeGen/MIR2VecTest.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index 6ce791695e3e4..11222b4d02fa3 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -53,7 +53,7 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
std::unique_ptr<LLVMContext> Ctx;
std::unique_ptr<Module> M;
std::unique_ptr<TargetMachine> TM;
- const TargetInstrInfo *TII;
+ const TargetInstrInfo *TII = nullptr;
static void SetUpTestCase() {
InitializeAllTargets();
@@ -94,6 +94,8 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
return;
}
}
+
+ void TearDown() override { TII = nullptr; }
};
// Function to find an opcode by name
More information about the llvm-commits mailing list