[llvm] [NFC][IR2Vec] Refactoring for Stateless Embedding Computation (PR #141811)
S. VenkataKeerthy via llvm-commits
llvm-commits at lists.llvm.org
Wed May 28 12:11:57 PDT 2025
https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/141811
>From 5cb422ed6092241b37ef917f0cb2cc9bcd1a076d Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Wed, 28 May 2025 17:43:40 +0000
Subject: [PATCH] Reducing state
---
llvm/docs/MLGO.rst | 27 ++++++++++-----------
llvm/include/llvm/Analysis/IR2Vec.h | 37 ++++++++++++++++-------------
llvm/lib/Analysis/IR2Vec.cpp | 24 ++++++++++++++++---
3 files changed, 54 insertions(+), 34 deletions(-)
diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index fa4b02cb11be7..549d8369d648d 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -490,24 +490,21 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
3. **Compute and Access Embeddings**:
- Call ``computeEmbeddings()`` on the embedder instance to compute the
- embeddings. Then the embeddings can be accessed using different getter
- methods. Currently, ``Embedder`` can generate embeddings at three levels:
- Instructions, Basic Blocks, and Functions.
+ Call ``getFunctionVector()`` to get the embedding for the function.
- .. code-block:: c++
+ .. code-block:: c++
- Emb->computeEmbeddings();
const ir2vec::Embedding &FuncVector = Emb->getFunctionVector();
- const ir2vec::InstEmbeddingsMap &InstVecMap = Emb->getInstVecMap();
- const ir2vec::BBEmbeddingsMap &BBVecMap = Emb->getBBVecMap();
-
- // Example: Iterate over instruction embeddings
- for (const auto &Entry : InstVecMap) {
- const Instruction *Inst = Entry.getFirst();
- const ir2vec::Embedding &InstEmbedding = Entry.getSecond();
- // Use Inst and InstEmbedding
- }
+
+ Currently, ``Embedder`` can generate embeddings at three levels: Instructions,
+ Basic Blocks, and Functions. Appropriate getters are provided to access the
+ embeddings at these levels.
+
+ .. note::
+
+ An ``Embedder`` instance (and the embeddings it generates) is valid only as
+ long as the function it is associated with remains unchanged. If the function
+ is modified, the embeddings may become stale and should be recomputed accordingly.
4. **Working with Embeddings:**
Embeddings are represented as ``std::vector<double>``. These
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 3063040093402..43c95c5e89aed 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -80,12 +80,17 @@ class Embedder {
// Utility maps - these are used to store the vector representations of
// instructions, basic blocks and functions.
- Embedding FuncVector;
- BBEmbeddingsMap BBVecMap;
- InstEmbeddingsMap InstVecMap;
+ mutable Embedding FuncVector;
+ mutable BBEmbeddingsMap BBVecMap;
+ mutable InstEmbeddingsMap InstVecMap;
Embedder(const Function &F, const Vocab &Vocabulary, unsigned Dimension);
+ /// Helper function to compute embeddings. It generates embeddings for all
+ /// the instructions and basic blocks in the function F. Logic of computing
+ /// the embeddings is specific to the kind of embeddings being computed.
+ virtual void computeEmbeddings() const = 0;
+
/// Lookup vocabulary for a given Key. If the key is not found, it returns a
/// zero vector.
Embedding lookupVocab(const std::string &Key) const;
@@ -100,25 +105,24 @@ class Embedder {
public:
virtual ~Embedder() = default;
- /// Top level function to compute embeddings. It generates embeddings for all
- /// the instructions and basic blocks in the function F. Logic of computing
- /// the embeddings is specific to the kind of embeddings being computed.
- virtual void computeEmbeddings() = 0;
-
/// Factory method to create an Embedder object.
static Expected<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
const Function &F,
const Vocab &Vocabulary,
unsigned Dimension);
- /// Returns a map containing instructions and the corresponding embeddings.
- const InstEmbeddingsMap &getInstVecMap() const { return InstVecMap; }
+ /// Returns a map containing instructions and the corresponding embeddings for
+ /// the function F if it has been computed. If not, it computes the embeddings
+ /// for the function and returns the map.
+ const InstEmbeddingsMap &getInstVecMap() const;
- /// Returns a map containing basic block and the corresponding embeddings.
- const BBEmbeddingsMap &getBBVecMap() const { return BBVecMap; }
+ /// Returns a map containing basic blocks and the corresponding embeddings for
+ /// the function F if it has been computed. If not, it computes the embeddings
+ /// for the function and returns the map.
+ const BBEmbeddingsMap &getBBVecMap() const;
- /// Returns the embedding for the current function.
- const Embedding &getFunctionVector() const { return FuncVector; }
+ /// Computes and returns the embedding for the current function.
+ const Embedding &getFunctionVector() const;
};
/// Class for computing the Symbolic embeddings of IR2Vec.
@@ -127,7 +131,7 @@ class Embedder {
class SymbolicEmbedder : public Embedder {
private:
/// Utility function to compute the embedding for a given basic block.
- Embedding computeBB2Vec(const BasicBlock &BB);
+ Embedding computeBB2Vec(const BasicBlock &BB) const;
/// Utility function to compute the embedding for a given type.
Embedding getTypeEmbedding(const Type *Ty) const;
@@ -135,13 +139,14 @@ class SymbolicEmbedder : public Embedder {
/// Utility function to compute the embedding for a given operand.
Embedding getOperandEmbedding(const Value *Op) const;
+ void computeEmbeddings() const override;
+
public:
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
unsigned Dimension)
: Embedder(F, Vocabulary, Dimension) {
FuncVector = Embedding(Dimension, 0);
}
- void computeEmbeddings() override;
};
} // namespace ir2vec
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index cc419c84e9881..5f3114dcdeeaa 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -103,6 +103,25 @@ Embedding Embedder::lookupVocab(const std::string &Key) const {
return Vec;
}
+const InstEmbeddingsMap &Embedder::getInstVecMap() const {
+ if (InstVecMap.empty())
+ computeEmbeddings();
+ return InstVecMap;
+}
+
+const BBEmbeddingsMap &Embedder::getBBVecMap() const {
+ if (BBVecMap.empty())
+ computeEmbeddings();
+ return BBVecMap;
+}
+
+const Embedding &Embedder::getFunctionVector() const {
+ // Currently, we always (re)compute the embeddings for the function.
+ // This is cheaper than caching the vector.
+ computeEmbeddings();
+ return FuncVector;
+}
+
#define RETURN_LOOKUP_IF(CONDITION, KEY_STR) \
if (CONDITION) \
return lookupVocab(KEY_STR);
@@ -132,7 +151,7 @@ Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
#undef RETURN_LOOKUP_IF
-void SymbolicEmbedder::computeEmbeddings() {
+void SymbolicEmbedder::computeEmbeddings() const {
if (F.isDeclaration())
return;
for (const auto &BB : F) {
@@ -142,7 +161,7 @@ void SymbolicEmbedder::computeEmbeddings() {
}
}
-Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) {
+Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
Embedding BBVector(Dimension, 0);
for (const auto &I : BB) {
@@ -271,7 +290,6 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
}
std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);
- Emb->computeEmbeddings();
OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
OS << "Function vector: ";
More information about the llvm-commits
mailing list