[llvm] [MIR2Vec] Added create factory methods for Vocabulary (PR #162569)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 8 16:12:32 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlgo
Author: S. VenkataKeerthy (svkeerthy)
<details>
<summary>Changes</summary>
Added factory methods for vocabulary creation. This would also fix the UB issue introduced by #<!-- -->161713.
---
Full diff: https://github.com/llvm/llvm-project/pull/162569.diff
4 Files Affected:
- (modified) llvm/include/llvm/CodeGen/MIR2Vec.h (+15-18)
- (modified) llvm/lib/CodeGen/MIR2Vec.cpp (+31-26)
- (modified) llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll (+8-8)
- (modified) llvm/unittests/CodeGen/MIR2VecTest.cpp (+31-4)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
index ea68b4594a2ad..dbffede50df81 100644
--- a/llvm/include/llvm/CodeGen/MIR2Vec.h
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -38,6 +38,7 @@
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorOr.h"
#include <map>
#include <set>
@@ -92,25 +93,12 @@ class MIRVocabulary {
/// Get the string key for a vocabulary entry at the given position
std::string getStringKey(unsigned Pos) const;
- MIRVocabulary() = delete;
- MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
- MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
- : Storage(std::move(Storage)), TII(TII) {}
-
- bool isValid() const {
- return UniqueBaseOpcodeNames.size() > 0 &&
- Layout.TotalEntries == Storage.size() && Storage.isValid();
- }
-
unsigned getDimension() const {
- if (!isValid())
- return 0;
return Storage.getDimension();
}
// Accessor methods
const Embedding &operator[](unsigned Opcode) const {
- assert(isValid() && "MIR2Vec Vocabulary is invalid");
unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
}
@@ -118,20 +106,30 @@ class MIRVocabulary {
// Iterator access
using const_iterator = ir2vec::VocabStorage::const_iterator;
const_iterator begin() const {
- assert(isValid() && "MIR2Vec Vocabulary is invalid");
return Storage.begin();
}
const_iterator end() const {
- assert(isValid() && "MIR2Vec Vocabulary is invalid");
return Storage.end();
}
/// Total number of entries in the vocabulary
size_t getCanonicalSize() const {
- assert(isValid() && "Invalid vocabulary");
return Storage.size();
}
+
+ MIRVocabulary() = delete;
+
+ /// Factory method to create MIRVocabulary from vocabulary map
+ static Expected<MIRVocabulary> create(VocabMap &&Entries, const TargetInstrInfo &TII);
+
+ /// Factory method to create MIRVocabulary from existing storage
+ static Expected<MIRVocabulary> create(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII);
+
+private:
+ MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
+ MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
+ : Storage(std::move(Storage)), TII(TII) {}
};
} // namespace mir2vec
@@ -145,7 +143,6 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
StringRef getPassName() const override;
Error readVocabulary();
- void emitError(Error Err, LLVMContext &Ctx);
protected:
void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -156,7 +153,7 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
public:
static char ID;
MIR2VecVocabLegacyAnalysis() : ImmutablePass(ID) {}
- mir2vec::MIRVocabulary getMIR2VecVocabulary(const Module &M);
+ Expected<mir2vec::MIRVocabulary> getMIR2VecVocabulary(const Module &M);
};
/// This pass prints the embeddings in the MIR2Vec vocabulary
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index 87565c0c77115..669c11d5f739c 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -49,14 +49,8 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
//===----------------------------------------------------------------------===//
MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
- const TargetInstrInfo *TII)
- : TII(*TII) {
- // Fixme: Use static factory methods for creating vocabularies instead of
- // public constructors
- // Early return for invalid inputs - creates empty/invalid vocabulary
- if (!TII || OpcodeEntries.empty())
- return;
-
+ const TargetInstrInfo &TII)
+ : TII(TII) {
buildCanonicalOpcodeMapping();
unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
@@ -67,6 +61,24 @@ MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
Layout.TotalEntries = Storage.size();
}
+Expected<MIRVocabulary> MIRVocabulary::create(VocabMap &&Entries,
+ const TargetInstrInfo &TII) {
+ if (Entries.empty())
+ return createStringError(errc::invalid_argument,
+ "Empty vocabulary entries provided");
+
+ return MIRVocabulary(std::move(Entries), TII);
+}
+
+Expected<MIRVocabulary> MIRVocabulary::create(ir2vec::VocabStorage &&Storage,
+ const TargetInstrInfo &TII) {
+ if (!Storage.isValid())
+ return createStringError(errc::invalid_argument,
+ "Invalid vocabulary storage provided");
+
+ return MIRVocabulary(std::move(Storage), TII);
+}
+
std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) {
// Extract base instruction name using regex to capture letters and
// underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE"
@@ -107,13 +119,11 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
}
unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
- assert(isValid() && "MIR2Vec Vocabulary is invalid");
auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
return getCanonicalIndexForBaseName(BaseOpcode);
}
std::string MIRVocabulary::getStringKey(unsigned Pos) const {
- assert(isValid() && "MIR2Vec Vocabulary is invalid");
assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
// For now, all entries are opcodes since we only have one section
@@ -232,16 +242,11 @@ Error MIR2VecVocabLegacyAnalysis::readVocabulary() {
return Error::success();
}
-void MIR2VecVocabLegacyAnalysis::emitError(Error Err, LLVMContext &Ctx) {
- Ctx.emitError(toString(std::move(Err)));
-}
-
-mir2vec::MIRVocabulary
+Expected<mir2vec::MIRVocabulary>
MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
if (StrVocabMap.empty()) {
if (Error Err = readVocabulary()) {
- emitError(std::move(Err), M.getContext());
- return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
+ return std::move(Err);
}
}
@@ -255,15 +260,13 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
if (auto *MF = MMI.getMachineFunction(F)) {
const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
- return mir2vec::MIRVocabulary(std::move(StrVocabMap), TII);
+ return mir2vec::MIRVocabulary::create(std::move(StrVocabMap), *TII);
}
}
- // No machine functions available - return invalid vocabulary
- emitError(make_error<StringError>("No machine functions found in module",
- inconvertibleErrorCode()),
- M.getContext());
- return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
+ // No machine functions available - return error
+ return createStringError(errc::invalid_argument,
+ "No machine functions found in module");
}
//===----------------------------------------------------------------------===//
@@ -284,13 +287,15 @@ bool MIR2VecVocabPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
bool MIR2VecVocabPrinterLegacyPass::doFinalization(Module &M) {
auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
- auto MIR2VecVocab = Analysis.getMIR2VecVocabulary(M);
+ auto MIR2VecVocabOrErr = Analysis.getMIR2VecVocabulary(M);
- if (!MIR2VecVocab.isValid()) {
- OS << "MIR2Vec Vocabulary Printer: Invalid vocabulary\n";
+ if (!MIR2VecVocabOrErr) {
+ OS << "MIR2Vec Vocabulary Printer: Failed to get vocabulary - "
+ << toString(MIR2VecVocabOrErr.takeError()) << "\n";
return false;
}
+ auto &MIR2VecVocab = *MIR2VecVocabOrErr;
unsigned Pos = 0;
for (const auto &Entry : MIR2VecVocab) {
OS << "Key: " << MIR2VecVocab.getStringKey(Pos++) << ": ";
diff --git a/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll b/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
index 1da516a6cd3b9..80b4048cea0c3 100644
--- a/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
+++ b/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
@@ -1,15 +1,15 @@
; REQUIRES: x86_64-linux
-; RUN: not llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
-; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
-; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
-; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS
+; RUN: llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
+; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
+; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
+; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS
define dso_local void @test() {
entry:
ret void
}
-; CHECK-INVALID: error: MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
-; CHECK-ZERO-DIM: error: Dimension of 'entities' section of the vocabulary is zero
-; CHECK-NO-ENTITIES: error: Missing 'entities' section in vocabulary file
-; CHECK-INCONSISTENT-DIMS: error: All vectors in the 'entities' section of the vocabulary are not of the same dimension
+; CHECK-INVALID: MIR2Vec Vocabulary Printer: Failed to get vocabulary - MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
+; CHECK-ZERO-DIM: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Dimension of 'entities' section of the vocabulary is zero
+; CHECK-NO-ENTITIES: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Missing 'entities' section in vocabulary file
+; CHECK-INCONSISTENT-DIMS: MIR2Vec Vocabulary Printer: Failed to get vocabulary - All vectors in the 'entities' section of the vocabulary are not of the same dimension
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index d243d82c73fc7..269e3b515c6fc 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -17,6 +17,7 @@
#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Triple.h"
@@ -118,7 +119,11 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
VocabMap VMap;
Embedding Val = Embedding(64, 1.0f);
VMap["ADD"] = Val;
- MIRVocabulary TestVocab(std::move(VMap), TII);
+ auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
+ << "Failed to create vocabulary: "
+ << toString(TestVocabOrErr.takeError());
+ auto &TestVocab = *TestVocabOrErr;
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
@@ -173,7 +178,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Use a minimal MIRVocabulary to trigger canonical mapping construction
VocabMap VMap;
VMap["ADD"] = Embedding(64, 1.0f);
- MIRVocabulary TestVocab(std::move(VMap), TII);
+ auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
+ << "Failed to create vocabulary: "
+ << toString(TestVocabOrErr.takeError());
+ auto &TestVocab = *TestVocabOrErr;
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
@@ -195,8 +204,10 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
- MIRVocabulary Vocab(std::move(VMap), TII);
- EXPECT_TRUE(Vocab.isValid());
+ auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
EXPECT_EQ(Vocab.getDimension(), 128u);
// Test iterator - iterates over individual embeddings
@@ -214,4 +225,20 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
EXPECT_GT(Count, 0u);
}
+// Test factory method with empty vocabulary
+TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
+ VocabMap EmptyVMap;
+
+ auto VocabOrErr = MIRVocabulary::create(std::move(EmptyVMap), *TII);
+ EXPECT_FALSE(static_cast<bool>(VocabOrErr))
+ << "Factory method should fail with empty vocabulary";
+
+ // Consume the error
+ if (!VocabOrErr) {
+ auto Err = VocabOrErr.takeError();
+ std::string ErrorMsg = toString(std::move(Err));
+ EXPECT_FALSE(ErrorMsg.empty());
+ }
+}
+
} // namespace
\ No newline at end of file
``````````
</details>
https://github.com/llvm/llvm-project/pull/162569
More information about the llvm-commits mailing list