[llvm] 741136a - [NFC][IR2Vec] Removing Dimension from `Embedder::Create` (#142486)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 2 15:05:14 PDT 2025


Author: S. VenkataKeerthy
Date: 2025-06-02T15:05:11-07:00
New Revision: 741136a8ac924462da0e786a209e1bd4b9b247c6

URL: https://github.com/llvm/llvm-project/commit/741136a8ac924462da0e786a209e1bd4b9b247c6
DIFF: https://github.com/llvm/llvm-project/commit/741136a8ac924462da0e786a209e1bd4b9b247c6.diff

LOG: [NFC][IR2Vec] Removing Dimension from `Embedder::Create` (#142486)

This PR removes the need to know the dimension of the embeddings when invoking `Embedder::create`; the dimension is now derived from the vocabulary itself. Having the `Dimension` parameter introduced complexities for downstream consumers.

(Tracking issue - #141817)

Added: 
    

Modified: 
    llvm/docs/MLGO.rst
    llvm/include/llvm/Analysis/IR2Vec.h
    llvm/lib/Analysis/IR2Vec.cpp
    llvm/unittests/Analysis/IR2VecTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index 377c2aec44475..4f8fb3f59ca19 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -469,7 +469,6 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
         return;
       }
       const ir2vec::Vocab &Vocabulary = VocabRes.getVocabulary();
-      unsigned Dimension = VocabRes.getDimension();
 
    Note that ``IR2VecVocabAnalysis`` pass is immutable.
 
@@ -481,7 +480,7 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
       // Assuming F is an llvm::Function&
       // For example, using IR2VecKind::Symbolic:
       Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
-          ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary, Dimension);
+          ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
 
       if (auto Err = EmbOrErr.takeError()) {
         // Handle error in embedder creation

diff  --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 288753b3b3b8f..9fd1b0ae8e248 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -84,7 +84,7 @@ class Embedder {
   mutable BBEmbeddingsMap BBVecMap;
   mutable InstEmbeddingsMap InstVecMap;
 
-  Embedder(const Function &F, const Vocab &Vocabulary, unsigned Dimension);
+  Embedder(const Function &F, const Vocab &Vocabulary);
 
   /// Helper function to compute embeddings. It generates embeddings for all
   /// the instructions and basic blocks in the function F. Logic of computing
@@ -110,10 +110,8 @@ class Embedder {
   virtual ~Embedder() = default;
 
   /// Factory method to create an Embedder object.
-  static Expected<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
-                                                    const Function &F,
-                                                    const Vocab &Vocabulary,
-                                                    unsigned Dimension);
+  static Expected<std::unique_ptr<Embedder>>
+  create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
 
   /// Returns a map containing instructions and the corresponding embeddings for
   /// the function F if it has been computed. If not, it computes the embeddings
@@ -149,9 +147,8 @@ class SymbolicEmbedder : public Embedder {
   void computeEmbeddings(const BasicBlock &BB) const override;
 
 public:
-  SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
-                   unsigned Dimension)
-      : Embedder(F, Vocabulary, Dimension) {
+  SymbolicEmbedder(const Function &F, const Vocab &Vocabulary)
+      : Embedder(F, Vocabulary) {
     FuncVector = Embedding(Dimension, 0);
   }
 };

diff  --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 67af44dcac424..490db5fdcdf99 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -59,19 +59,16 @@ AnalysisKey IR2VecVocabAnalysis::Key;
 // Embedder and its subclasses
 //===----------------------------------------------------------------------===//
 
-Embedder::Embedder(const Function &F, const Vocab &Vocabulary,
-                   unsigned Dimension)
-    : F(F), Vocabulary(Vocabulary), Dimension(Dimension),
-      OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
-}
+Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
+    : F(F), Vocabulary(Vocabulary),
+      Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
+      TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
 
-Expected<std::unique_ptr<Embedder>> Embedder::create(IR2VecKind Mode,
-                                                     const Function &F,
-                                                     const Vocab &Vocabulary,
-                                                     unsigned Dimension) {
+Expected<std::unique_ptr<Embedder>>
+Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
   switch (Mode) {
   case IR2VecKind::Symbolic:
-    return std::make_unique<SymbolicEmbedder>(F, Vocabulary, Dimension);
+    return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
   }
   return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
 }
@@ -286,10 +283,9 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
   assert(IR2VecVocabResult.isValid() && "IR2Vec Vocabulary is invalid");
 
   auto Vocab = IR2VecVocabResult.getVocabulary();
-  auto Dim = IR2VecVocabResult.getDimension();
   for (Function &F : M) {
     Expected<std::unique_ptr<Embedder>> EmbOrErr =
-        Embedder::create(IR2VecKind::Symbolic, F, Vocab, Dim);
+        Embedder::create(IR2VecKind::Symbolic, F, Vocab);
     if (auto Err = EmbOrErr.takeError()) {
       handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
         OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";

diff  --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 0158038b59b6c..9e47b2cd8bedd 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -28,8 +28,7 @@ namespace {
 
 class TestableEmbedder : public Embedder {
 public:
-  TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim)
-      : Embedder(F, V, Dim) {}
+  TestableEmbedder(const Function &F, const Vocab &V) : Embedder(F, V) {}
   void computeEmbeddings() const override {}
   void computeEmbeddings(const BasicBlock &BB) const override {}
   using Embedder::lookupVocab;
@@ -50,7 +49,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
   FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
   Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
 
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
+  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
   EXPECT_TRUE(static_cast<bool>(Result));
 
   auto *Emb = Result->get();
@@ -66,7 +65,7 @@ TEST(IR2VecTest, CreateInvalidMode) {
   Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
 
   // static_cast an invalid int to IR2VecKind
-  auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V, 2);
+  auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
   EXPECT_FALSE(static_cast<bool>(Result));
 
   std::string ErrMsg;
@@ -123,7 +122,7 @@ TEST(IR2VecTest, LookupVocab) {
   FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
   Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
 
-  TestableEmbedder E(*F, V, 2);
+  TestableEmbedder E(*F, V);
   auto V_foo = E.lookupVocab("foo");
   EXPECT_EQ(V_foo.size(), 2u);
   EXPECT_THAT(V_foo, ElementsAre(1.0, 2.0));
@@ -190,7 +189,7 @@ struct GetterTestEnv {
     Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
     Ret = ReturnInst::Create(Ctx, Add, BB);
 
-    auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
+    auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
     EXPECT_TRUE(static_cast<bool>(Result));
     Emb = std::move(*Result);
   }


        


More information about the llvm-commits mailing list