[llvm] 3581e9b - [NFC][IR2Vec] Refactoring for Stateless Embedding Computation (#141811)

via llvm-commits llvm-commits at lists.llvm.org
Wed May 28 12:19:05 PDT 2025


Author: S. VenkataKeerthy
Date: 2025-05-28T12:19:02-07:00
New Revision: 3581e9bb4c7c37a1a277322d5389d4b11be0ac49

URL: https://github.com/llvm/llvm-project/commit/3581e9bb4c7c37a1a277322d5389d4b11be0ac49
DIFF: https://github.com/llvm/llvm-project/commit/3581e9bb4c7c37a1a277322d5389d4b11be0ac49.diff

LOG: [NFC][IR2Vec] Refactoring for Stateless Embedding Computation (#141811)

Currently, users have to invoke two APIs: `computeEmbeddings()` followed
by getters to access the embeddings. This PR refactors the code to
reduce this *stateful* access of APIs. Users can now directly invoke
the getters; internally, the getters compute the embeddings as needed.

Added: 
    

Modified: 
    llvm/docs/MLGO.rst
    llvm/include/llvm/Analysis/IR2Vec.h
    llvm/lib/Analysis/IR2Vec.cpp

Removed: 
    


################################################################################
diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index fa4b02cb11be7..549d8369d648d 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -490,24 +490,21 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
       std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
 
 3. **Compute and Access Embeddings**:
-   Call ``computeEmbeddings()`` on the embedder instance to compute the 
-   embeddings. Then the embeddings can be accessed using different getter 
-   methods. Currently, ``Embedder`` can generate embeddings at three levels:
-   Instructions, Basic Blocks, and Functions.
+   Call ``getFunctionVector()`` to get the embedding for the function. 
 
-   .. code-block:: c++
+   .. code-block:: c++
 
-      Emb->computeEmbeddings();
       const ir2vec::Embedding &FuncVector = Emb->getFunctionVector();
-      const ir2vec::InstEmbeddingsMap &InstVecMap = Emb->getInstVecMap();
-      const ir2vec::BBEmbeddingsMap &BBVecMap = Emb->getBBVecMap();
-
-      // Example: Iterate over instruction embeddings
-      for (const auto &Entry : InstVecMap) {
-        const Instruction *Inst = Entry.getFirst();
-        const ir2vec::Embedding &InstEmbedding = Entry.getSecond();
-        // Use Inst and InstEmbedding
-      }
+
+   Currently, ``Embedder`` can generate embeddings at three levels: Instructions,
+   Basic Blocks, and Functions, accessible via ``getInstVecMap()``,
+   ``getBBVecMap()``, and ``getFunctionVector()``, respectively.
+
+   .. note::
+
+      The validity of an ``Embedder`` instance (and of the embeddings it
+      generates) depends on its function remaining unchanged. If the function is
+      modified, the embeddings may become stale; create a new ``Embedder`` for it.
 
 4. **Working with Embeddings:**
    Embeddings are represented as ``std::vector<double>``. These

diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 3063040093402..43c95c5e89aed 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -80,12 +80,17 @@ class Embedder {
 
   // Utility maps - these are used to store the vector representations of
   // instructions, basic blocks and functions.
-  Embedding FuncVector;
-  BBEmbeddingsMap BBVecMap;
-  InstEmbeddingsMap InstVecMap;
+  mutable Embedding FuncVector;
+  mutable BBEmbeddingsMap BBVecMap;
+  mutable InstEmbeddingsMap InstVecMap;
 
   Embedder(const Function &F, const Vocab &Vocabulary, unsigned Dimension);
 
+  /// Helper function to compute embeddings. It generates embeddings for all
+  /// the instructions and basic blocks in the function F. Logic of computing
+  /// the embeddings is specific to the kind of embeddings being computed.
+  virtual void computeEmbeddings() const = 0;
+
   /// Lookup vocabulary for a given Key. If the key is not found, it returns a
   /// zero vector.
   Embedding lookupVocab(const std::string &Key) const;
@@ -100,25 +105,24 @@ class Embedder {
 public:
   virtual ~Embedder() = default;
 
-  /// Top level function to compute embeddings. It generates embeddings for all
-  /// the instructions and basic blocks in the function F. Logic of computing
-  /// the embeddings is specific to the kind of embeddings being computed.
-  virtual void computeEmbeddings() = 0;
-
   /// Factory method to create an Embedder object.
   static Expected<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
                                                     const Function &F,
                                                     const Vocab &Vocabulary,
                                                     unsigned Dimension);
 
-  /// Returns a map containing instructions and the corresponding embeddings.
-  const InstEmbeddingsMap &getInstVecMap() const { return InstVecMap; }
+  /// Returns a map containing instructions and the corresponding embeddings for
+  /// the function F if it has been computed. If not, it computes the embeddings
+  /// for the function and returns the map.
+  const InstEmbeddingsMap &getInstVecMap() const;
 
-  /// Returns a map containing basic block and the corresponding embeddings.
-  const BBEmbeddingsMap &getBBVecMap() const { return BBVecMap; }
+  /// Returns a map containing basic block and the corresponding embeddings for
+  /// the function F if it has been computed. If not, it computes the embeddings
+  /// for the function and returns the map.
+  const BBEmbeddingsMap &getBBVecMap() const;
 
-  /// Returns the embedding for the current function.
-  const Embedding &getFunctionVector() const { return FuncVector; }
+  /// Computes and returns the embedding for the current function.
+  const Embedding &getFunctionVector() const;
 };
 
 /// Class for computing the Symbolic embeddings of IR2Vec.
@@ -127,7 +131,7 @@ class Embedder {
 class SymbolicEmbedder : public Embedder {
 private:
   /// Utility function to compute the embedding for a given basic block.
-  Embedding computeBB2Vec(const BasicBlock &BB);
+  Embedding computeBB2Vec(const BasicBlock &BB) const;
 
   /// Utility function to compute the embedding for a given type.
   Embedding getTypeEmbedding(const Type *Ty) const;
@@ -135,13 +139,14 @@ class SymbolicEmbedder : public Embedder {
   /// Utility function to compute the embedding for a given operand.
   Embedding getOperandEmbedding(const Value *Op) const;
 
+  void computeEmbeddings() const override;
+
 public:
   SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
                    unsigned Dimension)
       : Embedder(F, Vocabulary, Dimension) {
     FuncVector = Embedding(Dimension, 0);
   }
-  void computeEmbeddings() override;
 };
 
 } // namespace ir2vec

diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index cc419c84e9881..5f3114dcdeeaa 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -103,6 +103,25 @@ Embedding Embedder::lookupVocab(const std::string &Key) const {
   return Vec;
 }
 
+const InstEmbeddingsMap &Embedder::getInstVecMap() const {
+  if (InstVecMap.empty())
+    computeEmbeddings();
+  return InstVecMap;
+}
+
+const BBEmbeddingsMap &Embedder::getBBVecMap() const {
+  if (BBVecMap.empty())
+    computeEmbeddings();
+  return BBVecMap;
+}
+
+const Embedding &Embedder::getFunctionVector() const {
+  // Unlike the map getters above, this recomputes the embeddings on every call.
+  // NOTE(review): confirm computeEmbeddings() resets FuncVector before reuse.
+  computeEmbeddings();
+  return FuncVector;
+}
+
 #define RETURN_LOOKUP_IF(CONDITION, KEY_STR)                                   \
   if (CONDITION)                                                               \
     return lookupVocab(KEY_STR);
@@ -132,7 +151,7 @@ Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
 
 #undef RETURN_LOOKUP_IF
 
-void SymbolicEmbedder::computeEmbeddings() {
+void SymbolicEmbedder::computeEmbeddings() const {
   if (F.isDeclaration())
     return;
   for (const auto &BB : F) {
@@ -142,7 +161,7 @@ void SymbolicEmbedder::computeEmbeddings() {
   }
 }
 
-Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) {
+Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
   Embedding BBVector(Dimension, 0);
 
   for (const auto &I : BB) {
@@ -271,7 +290,6 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
     }
 
     std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);
-    Emb->computeEmbeddings();
 
     OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
     OS << "Function vector: ";


        


More information about the llvm-commits mailing list