[llvm] [IR2Vec] Support for lazy computation of BB Embeddings (PR #141694)
via llvm-commits
llvm-commits at lists.llvm.org
Tue May 27 17:53:52 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
Author: S. VenkataKeerthy (svkeerthy)
<details>
<summary>Changes</summary>
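This PR exposes interfaces to compute embeddings at the basic-block (BB) level, which is needed for delta-patching the embeddings in the MLInliner. `computeEmbeddings(BB)` computes and stores the embedding for a single block, and `getBBVector(BB)` returns it, or an error if that block's embedding has not been computed yet.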
---
Full diff: https://github.com/llvm/llvm-project/pull/141694.diff
2 Files Affected:
- (modified) llvm/include/llvm/Analysis/IR2Vec.h (+7-3)
- (modified) llvm/lib/Analysis/IR2Vec.cpp (+19-12)
``````````diff
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 3063040093402..d13e470bd239a 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -105,6 +105,9 @@ class Embedder {
/// the embeddings is specific to the kind of embeddings being computed.
virtual void computeEmbeddings() = 0;
+ /// Function to compute the embedding for a given basic block.
+ virtual void computeEmbeddings(const BasicBlock &BB) = 0;
+
/// Factory method to create an Embedder object.
static Expected<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
const Function &F,
@@ -117,6 +120,9 @@ class Embedder {
/// Returns a map containing basic block and the corresponding embeddings.
const BBEmbeddingsMap &getBBVecMap() const { return BBVecMap; }
+ /// Returns the embedding for a given basic block.
+ Expected<const Embedding &> getBBVector(const BasicBlock &BB) const;
+
/// Returns the embedding for the current function.
const Embedding &getFunctionVector() const { return FuncVector; }
};
@@ -126,9 +132,6 @@ class Embedder {
/// representations obtained from the Vocabulary.
class SymbolicEmbedder : public Embedder {
private:
- /// Utility function to compute the embedding for a given basic block.
- Embedding computeBB2Vec(const BasicBlock &BB);
-
/// Utility function to compute the embedding for a given type.
Embedding getTypeEmbedding(const Type *Ty) const;
@@ -142,6 +145,7 @@ class SymbolicEmbedder : public Embedder {
FuncVector = Embedding(Dimension, 0);
}
void computeEmbeddings() override;
+ void computeEmbeddings(const BasicBlock &BB) override;
};
} // namespace ir2vec
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index cc419c84e9881..1690a918cca89 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -76,6 +76,14 @@ Expected<std::unique_ptr<Embedder>> Embedder::create(IR2VecKind Mode,
return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
}
+Expected<const Embedding &> Embedder::getBBVector(const BasicBlock &BB) const {
+ auto It = BBVecMap.find(&BB);
+ if (It == BBVecMap.end())
+ return createStringError(inconvertibleErrorCode(),
+ "BB embedding not computed");
+ return It->second;
+}
+
void Embedder::addVectors(Embedding &Dst, const Embedding &Src) {
std::transform(Dst.begin(), Dst.end(), Src.begin(), Dst.begin(),
std::plus<double>());
@@ -132,17 +140,7 @@ Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
#undef RETURN_LOOKUP_IF
-void SymbolicEmbedder::computeEmbeddings() {
- if (F.isDeclaration())
- return;
- for (const auto &BB : F) {
- auto [It, WasInserted] = BBVecMap.try_emplace(&BB, computeBB2Vec(BB));
- assert(WasInserted && "Basic block already exists in the map");
- addVectors(FuncVector, It->second);
- }
-}
-
-Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) {
+void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) {
Embedding BBVector(Dimension, 0);
for (const auto &I : BB) {
@@ -164,7 +162,16 @@ Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) {
InstVecMap[&I] = InstVector;
addVectors(BBVector, InstVector);
}
- return BBVector;
+ BBVecMap[&BB] = BBVector;
+}
+
+void SymbolicEmbedder::computeEmbeddings() {
+ if (F.isDeclaration())
+ return;
+ for (const auto &BB : F) {
+ computeEmbeddings(BB);
+ addVectors(FuncVector, BBVecMap[&BB]);
+ }
}
// ==----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/141694
More information about the llvm-commits mailing list