[llvm-branch-commits] [llvm] Simplifying creation of Embedder (PR #143999)
S. VenkataKeerthy via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jun 12 16:54:35 PDT 2025
https://github.com/svkeerthy created https://github.com/llvm/llvm-project/pull/143999
None
>From cc133a17f78f9ed2082930617ed4d94dbbf9bf97 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 12 Jun 2025 23:54:10 +0000
Subject: [PATCH] Simplifying creation of Embedder
---
llvm/docs/MLGO.rst | 7 +--
llvm/include/llvm/Analysis/IR2Vec.h | 4 +-
.../Analysis/FunctionPropertiesAnalysis.cpp | 10 ++---
llvm/lib/Analysis/IR2Vec.cpp | 19 ++++----
.../FunctionPropertiesAnalysisTest.cpp | 7 ++-
llvm/unittests/Analysis/IR2VecTest.cpp | 44 +++++++------------
6 files changed, 34 insertions(+), 57 deletions(-)
diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index 4f8fb3f59ca19..e7bba9995b75b 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -479,14 +479,9 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
// Assuming F is an llvm::Function&
// For example, using IR2VecKind::Symbolic:
- Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
+ std::unique_ptr<ir2vec::Embedder> Emb =
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
- if (auto Err = EmbOrErr.takeError()) {
- // Handle error in embedder creation
- return;
- }
- std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
3. **Compute and Access Embeddings**:
Call ``getFunctionVector()`` to get the embedding for the function.
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index f1aaf4cd2e013..6efa6eac56af9 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -170,8 +170,8 @@ class Embedder {
virtual ~Embedder() = default;
/// Factory method to create an Embedder object.
- static Expected<std::unique_ptr<Embedder>>
- create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
+ static std::unique_ptr<Embedder> create(IR2VecKind Mode, const Function &F,
+ const Vocab &Vocabulary);
/// Returns a map containing instructions and the corresponding embeddings for
/// the function F if it has been computed. If not, it computes the embeddings
diff --git a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
index 29d3aaf46dc06..dd4eb7f0df053 100644
--- a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
+++ b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
@@ -204,16 +204,12 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
// We instantiate the IR2Vec embedder each time, as having an unique
// pointer to the embedder as member of the class would make it
// non-copyable. Instantiating the embedder in itself is not costly.
- auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
+ auto Embedder = ir2vec::Embedder::create(IR2VecKind::Symbolic,
*BB.getParent(), *IR2VecVocab);
- if (Error Err = EmbOrErr.takeError()) {
- handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
- BB.getContext().emitError("Error creating IR2Vec embeddings: " +
- EI.message());
- });
+ if (!Embedder) {
+ BB.getContext().emitError("Error creating IR2Vec embeddings");
return;
}
- auto Embedder = std::move(*EmbOrErr);
const auto &BBEmbedding = Embedder->getBBVector(BB);
// Subtract BBEmbedding from Function embedding if the direction is -1,
// and add it if the direction is +1.
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index f51d3252d6606..68026618449d8 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -123,13 +123,14 @@ Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
-Expected<std::unique_ptr<Embedder>>
-Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
+std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
+ const Vocab &Vocabulary) {
switch (Mode) {
case IR2VecKind::Symbolic:
return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
}
- return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
+ llvm_unreachable("Unknown IR2Vec kind");
+ return nullptr;
}
// FIXME: Currently lookups are string based. Use numeric Keys
@@ -389,17 +390,13 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
auto Vocab = IR2VecVocabResult.getVocabulary();
for (Function &F : M) {
- Expected<std::unique_ptr<Embedder>> EmbOrErr =
+ std::unique_ptr<Embedder> Emb =
Embedder::create(IR2VecKind::Symbolic, F, Vocab);
- if (auto Err = EmbOrErr.takeError()) {
- handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
- OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";
- });
+ if (!Emb) {
+ OS << "Error creating IR2Vec embeddings \n";
continue;
}
- std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);
-
OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
OS << "Function vector: ";
Emb->getFunctionVector().print(OS);
@@ -442,4 +439,4 @@ PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
}
return PreservedAnalyses::all();
-}
\ No newline at end of file
+}
diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
index e50486bcbcb27..ca4f5d0f63026 100644
--- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
@@ -127,10 +127,9 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
}
std::unique_ptr<ir2vec::Embedder> createEmbedder(const Function &F) {
- auto EmbResult =
- ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
- EXPECT_TRUE(static_cast<bool>(EmbResult));
- return std::move(*EmbResult);
+ auto Emb = ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
+ EXPECT_TRUE(static_cast<bool>(Emb));
+ return std::move(Emb);
}
};
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index c3ed6e90cd8fc..05af55b59323b 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -216,10 +216,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- EXPECT_TRUE(static_cast<bool>(Result));
-
- auto *Emb = Result->get();
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_NE(Emb, nullptr);
}
@@ -231,15 +228,16 @@ TEST(IR2VecTest, CreateInvalidMode) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
- // static_cast an invalid int to IR2VecKind
+// static_cast an invalid int to IR2VecKind
+#ifndef NDEBUG
+#if GTEST_HAS_DEATH_TEST
+ EXPECT_DEATH(Embedder::create(static_cast<IR2VecKind>(-1), *F, V),
+ "Unknown IR2Vec kind");
+#endif // GTEST_HAS_DEATH_TEST
+#else
auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
EXPECT_FALSE(static_cast<bool>(Result));
-
- std::string ErrMsg;
- llvm::handleAllErrors(
- Result.takeError(),
- [&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
- EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
+#endif // NDEBUG
}
TEST(IR2VecTest, LookupVocab) {
@@ -298,10 +296,6 @@ class IR2VecTestFixture : public ::testing::Test {
Instruction *AddInst = nullptr;
Instruction *RetInst = nullptr;
- float OriginalOpcWeight = ::OpcWeight;
- float OriginalTypeWeight = ::TypeWeight;
- float OriginalArgWeight = ::ArgWeight;
-
void SetUp() override {
V = {{"add", {1.0, 2.0}},
{"integerTy", {0.25, 0.25}},
@@ -325,9 +319,8 @@ class IR2VecTestFixture : public ::testing::Test {
};
TEST_F(IR2VecTestFixture, GetInstVecMap) {
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- ASSERT_TRUE(static_cast<bool>(Result));
- auto Emb = std::move(*Result);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
const auto &InstMap = Emb->getInstVecMap();
@@ -348,9 +341,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) {
}
TEST_F(IR2VecTestFixture, GetBBVecMap) {
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- ASSERT_TRUE(static_cast<bool>(Result));
- auto Emb = std::move(*Result);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBMap = Emb->getBBVecMap();
@@ -365,9 +357,8 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) {
}
TEST_F(IR2VecTestFixture, GetBBVector) {
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- ASSERT_TRUE(static_cast<bool>(Result));
- auto Emb = std::move(*Result);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBVec = Emb->getBBVector(*BB);
@@ -377,9 +368,8 @@ TEST_F(IR2VecTestFixture, GetBBVector) {
}
TEST_F(IR2VecTestFixture, GetFunctionVector) {
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- ASSERT_TRUE(static_cast<bool>(Result));
- auto Emb = std::move(*Result);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
const auto &FuncVec = Emb->getFunctionVector();
More information about the llvm-branch-commits
mailing list