[llvm-branch-commits] [llvm] [IR2Vec] Simplifying creation of Embedder (PR #143999)

S. VenkataKeerthy via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jun 13 10:47:01 PDT 2025


https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/143999

>From 0d921416a0f81e5634705dc9dfc5363d721a55bf Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 12 Jun 2025 23:54:10 +0000
Subject: [PATCH] Simplifying creation of Embedder

---
 llvm/docs/MLGO.rst                            |  7 +--
 llvm/include/llvm/Analysis/IR2Vec.h           |  4 +-
 .../Analysis/FunctionPropertiesAnalysis.cpp   | 10 ++---
 llvm/lib/Analysis/IR2Vec.cpp                  | 17 +++----
 .../FunctionPropertiesAnalysisTest.cpp        |  7 ++-
 llvm/unittests/Analysis/IR2VecTest.cpp        | 44 +++++++------------
 6 files changed, 33 insertions(+), 56 deletions(-)

diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index 4f8fb3f59ca19..e7bba9995b75b 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -479,14 +479,9 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
 
       // Assuming F is an llvm::Function&
       // For example, using IR2VecKind::Symbolic:
-      Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
+      std::unique_ptr<ir2vec::Embedder> Emb =
           ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
 
-      if (auto Err = EmbOrErr.takeError()) {
-        // Handle error in embedder creation
-        return;
-      }
-      std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
 
 3. **Compute and Access Embeddings**:
    Call ``getFunctionVector()`` to get the embedding for the function. 
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index f1aaf4cd2e013..6efa6eac56af9 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -170,8 +170,8 @@ class Embedder {
   virtual ~Embedder() = default;
 
   /// Factory method to create an Embedder object.
-  static Expected<std::unique_ptr<Embedder>>
-  create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
+  static std::unique_ptr<Embedder> create(IR2VecKind Mode, const Function &F,
+                                          const Vocab &Vocabulary);
 
   /// Returns a map containing instructions and the corresponding embeddings for
   /// the function F if it has been computed. If not, it computes the embeddings
diff --git a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
index 29d3aaf46dc06..dd4eb7f0df053 100644
--- a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
+++ b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
@@ -204,16 +204,12 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
     // We instantiate the IR2Vec embedder each time, as having an unique
     // pointer to the embedder as member of the class would make it
     // non-copyable. Instantiating the embedder in itself is not costly.
-    auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
+    auto Embedder = ir2vec::Embedder::create(IR2VecKind::Symbolic,
                                              *BB.getParent(), *IR2VecVocab);
-    if (Error Err = EmbOrErr.takeError()) {
-      handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
-        BB.getContext().emitError("Error creating IR2Vec embeddings: " +
-                                  EI.message());
-      });
+    if (!Embedder) {
+      BB.getContext().emitError("Error creating IR2Vec embeddings");
       return;
     }
-    auto Embedder = std::move(*EmbOrErr);
     const auto &BBEmbedding = Embedder->getBBVector(BB);
     // Subtract BBEmbedding from Function embedding if the direction is -1,
     // and add it if the direction is +1.
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index de9c2db9531e8..308c3d86a7668 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -123,13 +123,14 @@ Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
       Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
       TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
 
-Expected<std::unique_ptr<Embedder>>
-Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
+std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
+                                           const Vocab &Vocabulary) {
   switch (Mode) {
   case IR2VecKind::Symbolic:
     return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
   }
-  return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
+  llvm_unreachable("Unknown IR2Vec kind");
+  return nullptr;
 }
 
 // FIXME: Currently lookups are string based. Use numeric Keys
@@ -388,17 +389,13 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
 
   auto Vocab = IR2VecVocabResult.getVocabulary();
   for (Function &F : M) {
-    Expected<std::unique_ptr<Embedder>> EmbOrErr =
+    std::unique_ptr<Embedder> Emb =
         Embedder::create(IR2VecKind::Symbolic, F, Vocab);
-    if (auto Err = EmbOrErr.takeError()) {
-      handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
-        OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";
-      });
+    if (!Emb) {
+      OS << "Error creating IR2Vec embeddings \n";
       continue;
     }
 
-    std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);
-
     OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
     OS << "Function vector: ";
     Emb->getFunctionVector().print(OS);
diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
index e50486bcbcb27..ca4f5d0f63026 100644
--- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
@@ -127,10 +127,9 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
   }
 
   std::unique_ptr<ir2vec::Embedder> createEmbedder(const Function &F) {
-    auto EmbResult =
-        ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
-    EXPECT_TRUE(static_cast<bool>(EmbResult));
-    return std::move(*EmbResult);
+    auto Emb = ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
+    EXPECT_TRUE(static_cast<bool>(Emb));
+    return std::move(Emb);
   }
 };
 
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index c3ed6e90cd8fc..05af55b59323b 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -216,10 +216,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
   FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
   Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
 
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  EXPECT_TRUE(static_cast<bool>(Result));
-
-  auto *Emb = Result->get();
+  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
   EXPECT_NE(Emb, nullptr);
 }
 
@@ -231,15 +228,16 @@ TEST(IR2VecTest, CreateInvalidMode) {
   FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
   Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
 
-  // static_cast an invalid int to IR2VecKind
+// static_cast an invalid int to IR2VecKind
+#ifndef NDEBUG
+#if GTEST_HAS_DEATH_TEST
+  EXPECT_DEATH(Embedder::create(static_cast<IR2VecKind>(-1), *F, V),
+               "Unknown IR2Vec kind");
+#endif // GTEST_HAS_DEATH_TEST
+#else
   auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
   EXPECT_FALSE(static_cast<bool>(Result));
-
-  std::string ErrMsg;
-  llvm::handleAllErrors(
-      Result.takeError(),
-      [&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
-  EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
+#endif // NDEBUG
 }
 
 TEST(IR2VecTest, LookupVocab) {
@@ -298,10 +296,6 @@ class IR2VecTestFixture : public ::testing::Test {
   Instruction *AddInst = nullptr;
   Instruction *RetInst = nullptr;
 
-  float OriginalOpcWeight = ::OpcWeight;
-  float OriginalTypeWeight = ::TypeWeight;
-  float OriginalArgWeight = ::ArgWeight;
-
   void SetUp() override {
     V = {{"add", {1.0, 2.0}},
          {"integerTy", {0.25, 0.25}},
@@ -325,9 +319,8 @@ class IR2VecTestFixture : public ::testing::Test {
 };
 
 TEST_F(IR2VecTestFixture, GetInstVecMap) {
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  ASSERT_TRUE(static_cast<bool>(Result));
-  auto Emb = std::move(*Result);
+  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
 
   const auto &InstMap = Emb->getInstVecMap();
 
@@ -348,9 +341,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) {
 }
 
 TEST_F(IR2VecTestFixture, GetBBVecMap) {
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  ASSERT_TRUE(static_cast<bool>(Result));
-  auto Emb = std::move(*Result);
+  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
 
   const auto &BBMap = Emb->getBBVecMap();
 
@@ -365,9 +357,8 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) {
 }
 
 TEST_F(IR2VecTestFixture, GetBBVector) {
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  ASSERT_TRUE(static_cast<bool>(Result));
-  auto Emb = std::move(*Result);
+  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
 
   const auto &BBVec = Emb->getBBVector(*BB);
 
@@ -377,9 +368,8 @@ TEST_F(IR2VecTestFixture, GetBBVector) {
 }
 
 TEST_F(IR2VecTestFixture, GetFunctionVector) {
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  ASSERT_TRUE(static_cast<bool>(Result));
-  auto Emb = std::move(*Result);
+  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
 
   const auto &FuncVec = Emb->getFunctionVector();
 



More information about the llvm-branch-commits mailing list