[llvm] Lazy BB Embeddings (PR #142033)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Thu May 29 13:57:20 PDT 2025


https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/142033

>From b65c2d40224c949a2937331246afac89c5d70090 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 29 May 2025 20:51:57 +0000
Subject: [PATCH] Lazy BB Embeddings

---
 llvm/include/llvm/Analysis/IR2Vec.h    | 13 ++++++++---
 llvm/lib/Analysis/IR2Vec.cpp           | 31 ++++++++++++++++----------
 llvm/unittests/Analysis/IR2VecTest.cpp | 10 +++++++++
 3 files changed, 39 insertions(+), 15 deletions(-)

diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 43c95c5e89aed..288753b3b3b8f 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -91,6 +91,10 @@ class Embedder {
   /// the embeddings is specific to the kind of embeddings being computed.
   virtual void computeEmbeddings() const = 0;
 
+  /// Helper function to compute the embedding for a given basic block.
+  /// Specific to the kind of embeddings being computed.
+  virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
+
   /// Lookup vocabulary for a given Key. If the key is not found, it returns a
   /// zero vector.
   Embedding lookupVocab(const std::string &Key) const;
@@ -121,6 +125,11 @@ class Embedder {
   /// for the function and returns the map.
   const BBEmbeddingsMap &getBBVecMap() const;
 
+  /// Returns the embedding for a given basic block in the function F if it has
+  /// been computed. If not, it computes the embedding for the basic block and
+  /// returns it.
+  const Embedding &getBBVector(const BasicBlock &BB) const;
+
   /// Computes and returns the embedding for the current function.
   const Embedding &getFunctionVector() const;
 };
@@ -130,9 +139,6 @@ class Embedder {
 /// representations obtained from the Vocabulary.
 class SymbolicEmbedder : public Embedder {
 private:
-  /// Utility function to compute the embedding for a given basic block.
-  Embedding computeBB2Vec(const BasicBlock &BB) const;
-
   /// Utility function to compute the embedding for a given type.
   Embedding getTypeEmbedding(const Type *Ty) const;
 
@@ -140,6 +146,7 @@ class SymbolicEmbedder : public Embedder {
   Embedding getOperandEmbedding(const Value *Op) const;
 
   void computeEmbeddings() const override;
+  void computeEmbeddings(const BasicBlock &BB) const override;
 
 public:
   SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 683f05d5beb04..67af44dcac424 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -116,6 +116,14 @@ const BBEmbeddingsMap &Embedder::getBBVecMap() const {
   return BBVecMap;
 }
 
+const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
+  auto It = BBVecMap.find(&BB);
+  if (It != BBVecMap.end())
+    return It->second;
+  computeEmbeddings(BB);
+  return BBVecMap[&BB];
+}
+
 const Embedding &Embedder::getFunctionVector() const {
   // Currently, we always (re)compute the embeddings for the function.
   // This is cheaper than caching the vector.
@@ -152,17 +160,7 @@ Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
 
 #undef RETURN_LOOKUP_IF
 
-void SymbolicEmbedder::computeEmbeddings() const {
-  if (F.isDeclaration())
-    return;
-  for (const auto &BB : F) {
-    auto [It, WasInserted] = BBVecMap.try_emplace(&BB, computeBB2Vec(BB));
-    assert(WasInserted && "Basic block already exists in the map");
-    addVectors(FuncVector, It->second);
-  }
-}
-
-Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
+void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
   Embedding BBVector(Dimension, 0);
 
   for (const auto &I : BB) {
@@ -184,7 +182,16 @@ Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
     InstVecMap[&I] = InstVector;
     addVectors(BBVector, InstVector);
   }
-  return BBVector;
+  BBVecMap[&BB] = BBVector;
+}
+
+void SymbolicEmbedder::computeEmbeddings() const {
+  if (F.isDeclaration())
+    return;
+  for (const auto &BB : F) {
+    computeEmbeddings(BB);
+    addVectors(FuncVector, BBVecMap[&BB]);
+  }
 }
 
 // ==----------------------------------------------------------------------===//
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 5fb4da9f5fe20..0158038b59b6c 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -31,6 +31,7 @@ class TestableEmbedder : public Embedder {
   TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim)
       : Embedder(F, V, Dim) {}
   void computeEmbeddings() const override {}
+  void computeEmbeddings(const BasicBlock &BB) const override {}
   using Embedder::lookupVocab;
   static void addVectors(Embedding &Dst, const Embedding &Src) {
     Embedder::addVectors(Dst, Src);
@@ -229,6 +230,15 @@ TEST(IR2VecTest, GetBBVecMap) {
               ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
 }
 
+TEST(IR2VecTest, GetBBVector) {
+  GetterTestEnv Env;
+  const auto &BBVec = Env.Emb->getBBVector(*Env.BB);
+
+  EXPECT_EQ(BBVec.size(), 2u);
+  EXPECT_THAT(BBVec,
+              ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
+}
+
 TEST(IR2VecTest, GetFunctionVector) {
   GetterTestEnv Env;
   const auto &FuncVec = Env.Emb->getFunctionVector();



More information about the llvm-commits mailing list