[llvm] Lazy BB Embeddings (PR #142033)
S. VenkataKeerthy via llvm-commits
llvm-commits at lists.llvm.org
Thu May 29 13:57:20 PDT 2025
https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/142033
>From b65c2d40224c949a2937331246afac89c5d70090 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 29 May 2025 20:51:57 +0000
Subject: [PATCH] Lazy BB Embeddings
---
llvm/include/llvm/Analysis/IR2Vec.h | 13 ++++++++---
llvm/lib/Analysis/IR2Vec.cpp | 31 ++++++++++++++++----------
llvm/unittests/Analysis/IR2VecTest.cpp | 10 +++++++++
3 files changed, 39 insertions(+), 15 deletions(-)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 43c95c5e89aed..288753b3b3b8f 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -91,6 +91,10 @@ class Embedder {
/// the embeddings is specific to the kind of embeddings being computed.
virtual void computeEmbeddings() const = 0;
+ /// Helper function to compute the embedding for a given basic block.
+ /// Specific to the kind of embeddings being computed.
+ virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
+
/// Lookup vocabulary for a given Key. If the key is not found, it returns a
/// zero vector.
Embedding lookupVocab(const std::string &Key) const;
@@ -121,6 +125,11 @@ class Embedder {
/// for the function and returns the map.
const BBEmbeddingsMap &getBBVecMap() const;
+ /// Returns the embedding for a given basic block in the function F if it has
+ /// been computed. If not, it computes the embedding for the basic block and
+ /// returns it.
+ const Embedding &getBBVector(const BasicBlock &BB) const;
+
/// Computes and returns the embedding for the current function.
const Embedding &getFunctionVector() const;
};
@@ -130,9 +139,6 @@ class Embedder {
/// representations obtained from the Vocabulary.
class SymbolicEmbedder : public Embedder {
private:
- /// Utility function to compute the embedding for a given basic block.
- Embedding computeBB2Vec(const BasicBlock &BB) const;
-
/// Utility function to compute the embedding for a given type.
Embedding getTypeEmbedding(const Type *Ty) const;
@@ -140,6 +146,7 @@ class SymbolicEmbedder : public Embedder {
Embedding getOperandEmbedding(const Value *Op) const;
void computeEmbeddings() const override;
+ void computeEmbeddings(const BasicBlock &BB) const override;
public:
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 683f05d5beb04..67af44dcac424 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -116,6 +116,14 @@ const BBEmbeddingsMap &Embedder::getBBVecMap() const {
return BBVecMap;
}
+const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
+  if (!BBVecMap.count(&BB))
+    computeEmbeddings(BB); // Lazily compute and cache on first request.
+  auto It = BBVecMap.find(&BB); // Not operator[]: must not default-insert.
+  assert(It != BBVecMap.end() && "BB embedding was not computed");
+  return It->second;
+}
+
const Embedding &Embedder::getFunctionVector() const {
// Currently, we always (re)compute the embeddings for the function.
// This is cheaper than caching the vector.
@@ -152,17 +160,7 @@ Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
#undef RETURN_LOOKUP_IF
-void SymbolicEmbedder::computeEmbeddings() const {
- if (F.isDeclaration())
- return;
- for (const auto &BB : F) {
- auto [It, WasInserted] = BBVecMap.try_emplace(&BB, computeBB2Vec(BB));
- assert(WasInserted && "Basic block already exists in the map");
- addVectors(FuncVector, It->second);
- }
-}
-
-Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
+void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
Embedding BBVector(Dimension, 0);
for (const auto &I : BB) {
@@ -184,7 +182,16 @@ Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
InstVecMap[&I] = InstVector;
addVectors(BBVector, InstVector);
}
- return BBVector;
+ BBVecMap[&BB] = BBVector;
+}
+
+void SymbolicEmbedder::computeEmbeddings() const {
+  if (F.isDeclaration()) // Function has no body here: nothing to embed.
+    return;
+  for (const auto &BB : F) { // Always (re)computes every block of F.
+    computeEmbeddings(BB); // Per-BB overload; stores result in BBVecMap[&BB].
+    addVectors(FuncVector, BBVecMap[&BB]); // Accumulate into FuncVector.
+  }
}
// ==----------------------------------------------------------------------===//
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 5fb4da9f5fe20..0158038b59b6c 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -31,6 +31,7 @@ class TestableEmbedder : public Embedder {
TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim)
: Embedder(F, V, Dim) {}
void computeEmbeddings() const override {}
+ void computeEmbeddings(const BasicBlock &BB) const override {}
using Embedder::lookupVocab;
static void addVectors(Embedding &Dst, const Embedding &Src) {
Embedder::addVectors(Dst, Src);
@@ -229,6 +230,15 @@ TEST(IR2VecTest, GetBBVecMap) {
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}
+TEST(IR2VecTest, GetBBVector) {
+  GetterTestEnv Env;
+  const auto &BBVec = Env.Emb->getBBVector(*Env.BB);
+  EXPECT_EQ(&BBVec, &Env.Emb->getBBVector(*Env.BB)); // Same stored entry.
+  EXPECT_EQ(BBVec.size(), 2u);
+  EXPECT_THAT(BBVec,
+              ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
+}
+
TEST(IR2VecTest, GetFunctionVector) {
GetterTestEnv Env;
const auto &FuncVec = Env.Emb->getFunctionVector();
More information about the llvm-commits mailing list