[llvm] [NFC][IR2Vec] Refactoring for Stateless Embedding Computation (PR #141811)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Wed May 28 12:18:26 PDT 2025


https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/141811

>From 06b4d1b14d15e313c46581828c58c528b7b57719 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Wed, 28 May 2025 17:43:40 +0000
Subject: [PATCH] Reducing state

---
 llvm/docs/MLGO.rst                  | 27 ++++++++++-----------
 llvm/include/llvm/Analysis/IR2Vec.h | 37 ++++++++++++++++-------------
 llvm/lib/Analysis/IR2Vec.cpp        | 24 ++++++++++++++++---
 3 files changed, 54 insertions(+), 34 deletions(-)

diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index fa4b02cb11be7..549d8369d648d 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -490,24 +490,21 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
       std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
 
 3. **Compute and Access Embeddings**:
-   Call ``computeEmbeddings()`` on the embedder instance to compute the 
-   embeddings. Then the embeddings can be accessed using different getter 
-   methods. Currently, ``Embedder`` can generate embeddings at three levels:
-   Instructions, Basic Blocks, and Functions.
+   Call ``getFunctionVector()`` to get the embedding for the function.
 
-   .. code-block:: c++
+   .. code-block:: c++
 
-      Emb->computeEmbeddings();
       const ir2vec::Embedding &FuncVector = Emb->getFunctionVector();
-      const ir2vec::InstEmbeddingsMap &InstVecMap = Emb->getInstVecMap();
-      const ir2vec::BBEmbeddingsMap &BBVecMap = Emb->getBBVecMap();
-
-      // Example: Iterate over instruction embeddings
-      for (const auto &Entry : InstVecMap) {
-        const Instruction *Inst = Entry.getFirst();
-        const ir2vec::Embedding &InstEmbedding = Entry.getSecond();
-        // Use Inst and InstEmbedding
-      }
+
+   Currently, ``Embedder`` can generate embeddings at three levels: Instructions,
+   Basic Blocks, and Functions, accessible via ``getInstVecMap()``,
+   ``getBBVecMap()``, and ``getFunctionVector()`` respectively.
+
+   .. note::
+
+      The validity of an ``Embedder`` instance and its embeddings depends on the
+      associated function remaining unchanged. If the function is modified, the
+      embeddings may be stale; create a new ``Embedder`` to recompute them.
 
 4. **Working with Embeddings:**
    Embeddings are represented as ``std::vector<double>``. These
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 3063040093402..43c95c5e89aed 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -80,12 +80,17 @@ class Embedder {
 
   // Utility maps - these are used to store the vector representations of
   // instructions, basic blocks and functions.
-  Embedding FuncVector;
-  BBEmbeddingsMap BBVecMap;
-  InstEmbeddingsMap InstVecMap;
+  mutable Embedding FuncVector;
+  mutable BBEmbeddingsMap BBVecMap;
+  mutable InstEmbeddingsMap InstVecMap;
 
   Embedder(const Function &F, const Vocab &Vocabulary, unsigned Dimension);
 
+  /// Helper function to compute embeddings. It generates embeddings for all
+  /// the instructions and basic blocks in the function F. Logic of computing
+  /// the embeddings is specific to the kind of embeddings being computed.
+  virtual void computeEmbeddings() const = 0;
+
   /// Lookup vocabulary for a given Key. If the key is not found, it returns a
   /// zero vector.
   Embedding lookupVocab(const std::string &Key) const;
@@ -100,25 +105,24 @@ class Embedder {
 public:
   virtual ~Embedder() = default;
 
-  /// Top level function to compute embeddings. It generates embeddings for all
-  /// the instructions and basic blocks in the function F. Logic of computing
-  /// the embeddings is specific to the kind of embeddings being computed.
-  virtual void computeEmbeddings() = 0;
-
   /// Factory method to create an Embedder object.
   static Expected<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
                                                     const Function &F,
                                                     const Vocab &Vocabulary,
                                                     unsigned Dimension);
 
-  /// Returns a map containing instructions and the corresponding embeddings.
-  const InstEmbeddingsMap &getInstVecMap() const { return InstVecMap; }
+  /// Returns a map containing instructions and the corresponding embeddings for
+  /// the function F if it has been computed. If not, it computes the embeddings
+  /// for the function and returns the map.
+  const InstEmbeddingsMap &getInstVecMap() const;
 
-  /// Returns a map containing basic block and the corresponding embeddings.
-  const BBEmbeddingsMap &getBBVecMap() const { return BBVecMap; }
+  /// Returns a map containing basic block and the corresponding embeddings for
+  /// the function F if it has been computed. If not, it computes the embeddings
+  /// for the function and returns the map.
+  const BBEmbeddingsMap &getBBVecMap() const;
 
-  /// Returns the embedding for the current function.
-  const Embedding &getFunctionVector() const { return FuncVector; }
+  /// Computes and returns the embedding for the current function.
+  const Embedding &getFunctionVector() const;
 };
 
 /// Class for computing the Symbolic embeddings of IR2Vec.
@@ -127,7 +131,7 @@ class Embedder {
 class SymbolicEmbedder : public Embedder {
 private:
   /// Utility function to compute the embedding for a given basic block.
-  Embedding computeBB2Vec(const BasicBlock &BB);
+  Embedding computeBB2Vec(const BasicBlock &BB) const;
 
   /// Utility function to compute the embedding for a given type.
   Embedding getTypeEmbedding(const Type *Ty) const;
@@ -135,13 +139,14 @@ class SymbolicEmbedder : public Embedder {
   /// Utility function to compute the embedding for a given operand.
   Embedding getOperandEmbedding(const Value *Op) const;
 
+  void computeEmbeddings() const override;
+
 public:
   SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
                    unsigned Dimension)
       : Embedder(F, Vocabulary, Dimension) {
     FuncVector = Embedding(Dimension, 0);
   }
-  void computeEmbeddings() override;
 };
 
 } // namespace ir2vec
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index cc419c84e9881..5f3114dcdeeaa 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -103,6 +103,25 @@ Embedding Embedder::lookupVocab(const std::string &Key) const {
   return Vec;
 }
 
+const InstEmbeddingsMap &Embedder::getInstVecMap() const {
+  if (InstVecMap.empty())
+    computeEmbeddings();
+  return InstVecMap;
+}
+
+const BBEmbeddingsMap &Embedder::getBBVecMap() const {
+  if (BBVecMap.empty())
+    computeEmbeddings();
+  return BBVecMap;
+}
+
+const Embedding &Embedder::getFunctionVector() const {
+  // NOTE(review): this recomputes on every call. SymbolicEmbedder zeroes
+  // FuncVector in its ctor; confirm computeEmbeddings() doesn't accumulate.
+  computeEmbeddings();
+  return FuncVector;
+}
+
 #define RETURN_LOOKUP_IF(CONDITION, KEY_STR)                                   \
   if (CONDITION)                                                               \
     return lookupVocab(KEY_STR);
@@ -132,7 +151,7 @@ Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
 
 #undef RETURN_LOOKUP_IF
 
-void SymbolicEmbedder::computeEmbeddings() {
+void SymbolicEmbedder::computeEmbeddings() const {
   if (F.isDeclaration())
     return;
   for (const auto &BB : F) {
@@ -142,7 +161,7 @@ void SymbolicEmbedder::computeEmbeddings() {
   }
 }
 
-Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) {
+Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
   Embedding BBVector(Dimension, 0);
 
   for (const auto &I : BB) {
@@ -271,7 +290,6 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
     }
 
     std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);
-    Emb->computeEmbeddings();
 
     OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
     OS << "Function vector: ";



More information about the llvm-commits mailing list