[llvm] [IR2Vec] Support for lazy computation of BB Embeddings (PR #141694)

via llvm-commits llvm-commits at lists.llvm.org
Tue May 27 17:53:52 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: S. VenkataKeerthy (svkeerthy)

<details>
<summary>Changes</summary>

This PR exposes interfaces for computing embeddings at the basic-block (BB) level, which will be needed for delta-patching the embeddings in the MLInliner.

---
Full diff: https://github.com/llvm/llvm-project/pull/141694.diff


2 Files Affected:

- (modified) llvm/include/llvm/Analysis/IR2Vec.h (+7-3) 
- (modified) llvm/lib/Analysis/IR2Vec.cpp (+19-12) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 3063040093402..d13e470bd239a 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -105,6 +105,9 @@ class Embedder {
   /// the embeddings is specific to the kind of embeddings being computed.
   virtual void computeEmbeddings() = 0;
 
+  /// Function to compute the embedding for a given basic block.
+  virtual void computeEmbeddings(const BasicBlock &BB) = 0;
+
   /// Factory method to create an Embedder object.
   static Expected<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
                                                     const Function &F,
@@ -117,6 +120,9 @@ class Embedder {
   /// Returns a map containing basic block and the corresponding embeddings.
   const BBEmbeddingsMap &getBBVecMap() const { return BBVecMap; }
 
+  /// Returns the embedding for a given basic block.
+  Expected<const Embedding &> getBBVector(const BasicBlock &BB) const;
+
   /// Returns the embedding for the current function.
   const Embedding &getFunctionVector() const { return FuncVector; }
 };
@@ -126,9 +132,6 @@ class Embedder {
 /// representations obtained from the Vocabulary.
 class SymbolicEmbedder : public Embedder {
 private:
-  /// Utility function to compute the embedding for a given basic block.
-  Embedding computeBB2Vec(const BasicBlock &BB);
-
   /// Utility function to compute the embedding for a given type.
   Embedding getTypeEmbedding(const Type *Ty) const;
 
@@ -142,6 +145,7 @@ class SymbolicEmbedder : public Embedder {
     FuncVector = Embedding(Dimension, 0);
   }
   void computeEmbeddings() override;
+  void computeEmbeddings(const BasicBlock &BB) override;
 };
 
 } // namespace ir2vec
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index cc419c84e9881..1690a918cca89 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -76,6 +76,14 @@ Expected<std::unique_ptr<Embedder>> Embedder::create(IR2VecKind Mode,
   return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
 }
 
+Expected<const Embedding &> Embedder::getBBVector(const BasicBlock &BB) const {
+  auto It = BBVecMap.find(&BB);
+  if (It == BBVecMap.end())
+    return createStringError(inconvertibleErrorCode(),
+                             "BB embedding not computed");
+  return It->second;
+}
+
 void Embedder::addVectors(Embedding &Dst, const Embedding &Src) {
   std::transform(Dst.begin(), Dst.end(), Src.begin(), Dst.begin(),
                  std::plus<double>());
@@ -132,17 +140,7 @@ Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
 
 #undef RETURN_LOOKUP_IF
 
-void SymbolicEmbedder::computeEmbeddings() {
-  if (F.isDeclaration())
-    return;
-  for (const auto &BB : F) {
-    auto [It, WasInserted] = BBVecMap.try_emplace(&BB, computeBB2Vec(BB));
-    assert(WasInserted && "Basic block already exists in the map");
-    addVectors(FuncVector, It->second);
-  }
-}
-
-Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) {
+void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) {
   Embedding BBVector(Dimension, 0);
 
   for (const auto &I : BB) {
@@ -164,7 +162,16 @@ Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) {
     InstVecMap[&I] = InstVector;
     addVectors(BBVector, InstVector);
   }
-  return BBVector;
+  BBVecMap[&BB] = BBVector;
+}
+
+void SymbolicEmbedder::computeEmbeddings() {
+  if (F.isDeclaration())
+    return;
+  for (const auto &BB : F) {
+    computeEmbeddings(BB);
+    addVectors(FuncVector, BBVecMap[&BB]);
+  }
 }
 
 // ==----------------------------------------------------------------------===//

``````````

</details>


https://github.com/llvm/llvm-project/pull/141694


More information about the llvm-commits mailing list