[llvm] [IR2Vec] Restrict caching only to Flow-Aware computation (PR #162559)
S. VenkataKeerthy via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 9 15:16:03 PDT 2025
https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/162559
>From e4c6990185c714c2538d90ea790eb1f6cedec7d5 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Wed, 8 Oct 2025 18:20:17 +0000
Subject: [PATCH 1/4] IR2Vec Flow-aware fix
---
llvm/lib/Analysis/IR2Vec.cpp | 19 +++++++++++++++----
1 file changed, 15 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 688535161d4b9..1794a604b991d 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -239,10 +239,21 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
// If the operand is defined elsewhere, we use its embedding
if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
auto DefIt = InstVecMap.find(DefInst);
- assert(DefIt != InstVecMap.end() &&
- "Instruction should have been processed before its operands");
- ArgEmb += DefIt->second;
- continue;
+ // Fixme (#159171): Ideally we should never miss an instruction
+ // embedding here.
+ // But when we have cyclic dependencies (e.g., phi
+ // nodes), we might miss the embedding. In such cases, we fall back to
+ // using the vocabulary embedding. This can be fixed by iterating to a
+ // fixed-point, or by using a simple solver for the set of simultaneous
+ // equations.
+ // Another case when we might miss an instruction embedding is when
+ // the operand instruction is in a different basic block that has not
+ // been processed yet. This can be fixed by processing the basic blocks
+ // in a topological order.
+ if (DefIt != InstVecMap.end())
+ ArgEmb += DefIt->second;
+ else
+ ArgEmb += Vocab[*Op];
}
// If the operand is not defined by an instruction, we use the vocabulary
else {
>From 727a6b861eecc40d4b3797ea525879731ee8851c Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 9 Oct 2025 21:36:30 +0000
Subject: [PATCH 2/4] Restrict caching only to Flow-Aware computation
---
llvm/docs/MLGO.rst | 2 +-
llvm/include/llvm/Analysis/IR2Vec.h | 53 ++++---
llvm/lib/Analysis/IR2Vec.cpp | 182 +++++++++--------------
llvm/test/Analysis/IR2Vec/unreachable.ll | 8 +-
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 16 +-
llvm/unittests/Analysis/IR2VecTest.cpp | 86 ++---------
6 files changed, 123 insertions(+), 224 deletions(-)
diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index 965a21b8c84b8..bf3de11a2640e 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -508,7 +508,7 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
.. code-block:: c++
- const ir2vec::Embedding &FuncVector = Emb->getFunctionVector();
+ ir2vec::Embedding FuncVector = Emb->getFunctionVector();
Currently, ``Embedder`` can generate embeddings at three levels: Instructions,
Basic Blocks, and Functions. Appropriate getters are provided to access the
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 81409df7337c5..9be8899d1d4e5 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -533,21 +533,20 @@ class Embedder {
/// in the IR instructions to generate the vector representation.
const float OpcWeight, TypeWeight, ArgWeight;
- // Utility maps - these are used to store the vector representations of
- // instructions, basic blocks and functions.
- mutable Embedding FuncVector;
- mutable BBEmbeddingsMap BBVecMap;
- mutable InstEmbeddingsMap InstVecMap;
-
- LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
+ LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab)
+ : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
+ OpcWeight(ir2vec::OpcWeight), TypeWeight(ir2vec::TypeWeight),
+ ArgWeight(ir2vec::ArgWeight) {}
- /// Function to compute embeddings. It generates embeddings for all
- /// the instructions and basic blocks in the function F.
- void computeEmbeddings() const;
+ /// Function to compute embeddings.
+ Embedding computeEmbeddings() const;
/// Function to compute the embedding for a given basic block.
+ Embedding computeEmbeddings(const BasicBlock &BB) const;
+
+ /// Function to compute the embedding for a given instruction.
/// Specific to the kind of embeddings being computed.
- virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
+ virtual Embedding computeEmbeddings(const Instruction &I) const = 0;
public:
virtual ~Embedder() = default;
@@ -556,23 +555,20 @@ class Embedder {
LLVM_ABI static std::unique_ptr<Embedder>
create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab);
- /// Returns a map containing instructions and the corresponding embeddings for
- /// the function F if it has been computed. If not, it computes the embeddings
- /// for the function and returns the map.
- LLVM_ABI const InstEmbeddingsMap &getInstVecMap() const;
-
- /// Returns a map containing basic block and the corresponding embeddings for
- /// the function F if it has been computed. If not, it computes the embeddings
- /// for the function and returns the map.
- LLVM_ABI const BBEmbeddingsMap &getBBVecMap() const;
+ /// Computes and returns the embedding for a given instruction in the function
+ /// F
+ LLVM_ABI Embedding getInstVector(const Instruction &I) const {
+ return computeEmbeddings(I);
+ }
- /// Returns the embedding for a given basic block in the function F if it has
- /// been computed. If not, it computes the embedding for the basic block and
- /// returns it.
- LLVM_ABI const Embedding &getBBVector(const BasicBlock &BB) const;
+ /// Computes and returns the embedding for a given basic block in the function
+ /// F
+ LLVM_ABI Embedding getBBVector(const BasicBlock &BB) const {
+ return computeEmbeddings(BB);
+ }
/// Computes and returns the embedding for the current function.
- LLVM_ABI const Embedding &getFunctionVector() const;
+ LLVM_ABI Embedding getFunctionVector() const { return computeEmbeddings(); }
};
/// Class for computing the Symbolic embeddings of IR2Vec.
@@ -580,7 +576,7 @@ class Embedder {
/// representations obtained from the Vocabulary.
class LLVM_ABI SymbolicEmbedder : public Embedder {
private:
- void computeEmbeddings(const BasicBlock &BB) const override;
+ Embedding computeEmbeddings(const Instruction &I) const override;
public:
SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
@@ -592,7 +588,10 @@ class LLVM_ABI SymbolicEmbedder : public Embedder {
/// embeddings, and additionally capture the flow information in the IR.
class LLVM_ABI FlowAwareEmbedder : public Embedder {
private:
- void computeEmbeddings(const BasicBlock &BB) const override;
+ // FlowAware embeddings would benefit from caching instruction embeddings as
+ // they are reused while computing the embeddings of other instructions.
+ mutable InstEmbeddingsMap InstVecMap;
+ Embedding computeEmbeddings(const Instruction &I) const override;
public:
FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 1794a604b991d..6713cbe6c9a90 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -153,11 +153,6 @@ void Embedding::print(raw_ostream &OS) const {
// Embedder and its subclasses
//===----------------------------------------------------------------------===//
-Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
- : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
- OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
- FuncVector(Embedding(Dimension)) {}
-
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
const Vocabulary &Vocab) {
switch (Mode) {
@@ -169,110 +164,83 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
return nullptr;
}
-const InstEmbeddingsMap &Embedder::getInstVecMap() const {
- if (InstVecMap.empty())
- computeEmbeddings();
- return InstVecMap;
-}
-
-const BBEmbeddingsMap &Embedder::getBBVecMap() const {
- if (BBVecMap.empty())
- computeEmbeddings();
- return BBVecMap;
-}
-
-const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
- auto It = BBVecMap.find(&BB);
- if (It != BBVecMap.end())
- return It->second;
- computeEmbeddings(BB);
- return BBVecMap[&BB];
-}
+Embedding Embedder::computeEmbeddings() const {
+ Embedding FuncVector(Dimension, 0);
-const Embedding &Embedder::getFunctionVector() const {
- // Currently, we always (re)compute the embeddings for the function.
- // This is cheaper than caching the vector.
- computeEmbeddings();
- return FuncVector;
-}
-
-void Embedder::computeEmbeddings() const {
if (F.isDeclaration())
- return;
-
- FuncVector = Embedding(Dimension, 0.0);
+ return FuncVector;
// Consider only the basic blocks that are reachable from entry
- for (const BasicBlock *BB : depth_first(&F)) {
- computeEmbeddings(*BB);
- FuncVector += BBVecMap[BB];
- }
-}
-
-void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
- Embedding BBVector(Dimension, 0);
-
- // We consider only the non-debug and non-pseudo instructions
- for (const auto &I : BB.instructionsWithoutDebug()) {
- Embedding ArgEmb(Dimension, 0);
- for (const auto &Op : I.operands())
- ArgEmb += Vocab[*Op];
- auto InstVector =
- Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
- if (const auto *IC = dyn_cast<CmpInst>(&I))
- InstVector += Vocab[IC->getPredicate()];
- InstVecMap[&I] = InstVector;
- BBVector += InstVector;
- }
- BBVecMap[&BB] = BBVector;
+ for (const BasicBlock *BB : depth_first(&F))
+ FuncVector += computeEmbeddings(*BB);
+ return FuncVector;
}
-void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
+Embedding Embedder::computeEmbeddings(const BasicBlock &BB) const {
Embedding BBVector(Dimension, 0);
+ for (const Instruction &I : BB.instructionsWithoutDebug())
+ BBVector += computeEmbeddings(I);
+ return BBVector;
+}
+
+Embedding SymbolicEmbedder::computeEmbeddings(const Instruction &I) const {
+ // Currently, we always (re)compute the embeddings for symbolic embedder.
+ // This is cheaper than caching the vectors.
+ Embedding ArgEmb(Dimension, 0);
+ for (const auto &Op : I.operands())
+ ArgEmb += Vocab[*Op];
+ auto InstVector =
+ Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
+ InstVector += Vocab[IC->getPredicate()];
+ return InstVector;
+}
+
+Embedding FlowAwareEmbedder::computeEmbeddings(const Instruction &I) const {
+ // If we have already computed the embedding for this instruction, return it
+ auto It = InstVecMap.find(&I);
+ if (It != InstVecMap.end())
+ return It->second;
- // We consider only the non-debug and non-pseudo instructions
- for (const auto &I : BB.instructionsWithoutDebug()) {
- // TODO: Handle call instructions differently.
- // For now, we treat them like other instructions
- Embedding ArgEmb(Dimension, 0);
- for (const auto &Op : I.operands()) {
- // If the operand is defined elsewhere, we use its embedding
- if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
- auto DefIt = InstVecMap.find(DefInst);
- // Fixme (#159171): Ideally we should never miss an instruction
- // embedding here.
- // But when we have cyclic dependencies (e.g., phi
- // nodes), we might miss the embedding. In such cases, we fall back to
- // using the vocabulary embedding. This can be fixed by iterating to a
- // fixed-point, or by using a simple solver for the set of simultaneous
- // equations.
- // Another case when we might miss an instruction embedding is when
- // the operand instruction is in a different basic block that has not
- // been processed yet. This can be fixed by processing the basic blocks
- // in a topological order.
- if (DefIt != InstVecMap.end())
- ArgEmb += DefIt->second;
- else
- ArgEmb += Vocab[*Op];
- }
- // If the operand is not defined by an instruction, we use the vocabulary
- else {
- LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
- << *Op << "=" << Vocab[*Op][0] << "\n");
+ // TODO: Handle call instructions differently.
+ // For now, we treat them like other instructions
+ Embedding ArgEmb(Dimension, 0);
+ for (const auto &Op : I.operands()) {
+ // If the operand is defined elsewhere, we use its embedding
+ if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
+ auto DefIt = InstVecMap.find(DefInst);
+ // Fixme (#159171): Ideally we should never miss an instruction
+ // embedding here.
+ // But when we have cyclic dependencies (e.g., phi
+ // nodes), we might miss the embedding. In such cases, we fall back to
+ // using the vocabulary embedding. This can be fixed by iterating to a
+ // fixed-point, or by using a simple solver for the set of simultaneous
+ // equations.
+ // Another case when we might miss an instruction embedding is when
+ // the operand instruction is in a different basic block that has not
+ // been processed yet. This can be fixed by processing the basic blocks
+ // in a topological order.
+ if (DefIt != InstVecMap.end())
+ ArgEmb += DefIt->second;
+ else
ArgEmb += Vocab[*Op];
- }
}
- // Create the instruction vector by combining opcode, type, and arguments
- // embeddings
- auto InstVector =
- Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
- // Add compare predicate embedding as an additional operand if applicable
- if (const auto *IC = dyn_cast<CmpInst>(&I))
- InstVector += Vocab[IC->getPredicate()];
- InstVecMap[&I] = InstVector;
- BBVector += InstVector;
+ // If the operand is not defined by an instruction, we use the
+ // vocabulary
+ else {
+ LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
+ << *Op << "=" << Vocab[*Op][0] << "\n");
+ ArgEmb += Vocab[*Op];
+ }
}
- BBVecMap[&BB] = BBVector;
+ // Create the instruction vector by combining opcode, type, and arguments
+ // embeddings
+ auto InstVector =
+ Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
+ InstVector += Vocab[IC->getPredicate()];
+ InstVecMap[&I] = InstVector;
+ return InstVector;
}
// ==----------------------------------------------------------------------===//
@@ -695,25 +663,17 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
Emb->getFunctionVector().print(OS);
OS << "Basic block vectors:\n";
- const auto &BBMap = Emb->getBBVecMap();
for (const BasicBlock &BB : F) {
- auto It = BBMap.find(&BB);
- if (It != BBMap.end()) {
- OS << "Basic block: " << BB.getName() << ":\n";
- It->second.print(OS);
- }
+ OS << "Basic block: " << BB.getName() << ":\n";
+ Emb->getBBVector(BB).print(OS);
}
OS << "Instruction vectors:\n";
- const auto &InstMap = Emb->getInstVecMap();
for (const BasicBlock &BB : F) {
for (const Instruction &I : BB) {
- auto It = InstMap.find(&I);
- if (It != InstMap.end()) {
- OS << "Instruction: ";
- I.print(OS);
- It->second.print(OS);
- }
+ OS << "Instruction: ";
+ I.print(OS);
+ Emb->getInstVector(I).print(OS);
}
}
}
diff --git a/llvm/test/Analysis/IR2Vec/unreachable.ll b/llvm/test/Analysis/IR2Vec/unreachable.ll
index 9be0ee1c2de7a..627e2c9ac6b2d 100644
--- a/llvm/test/Analysis/IR2Vec/unreachable.ll
+++ b/llvm/test/Analysis/IR2Vec/unreachable.ll
@@ -30,13 +30,17 @@ return: ; preds = %if.else, %if.then
%4 = load i32, ptr %retval, align 4
ret i32 %4
}
-
-; CHECK: Basic block vectors:
+; We'll get individual basic block embeddings for all blocks in the function.
+; But unreachable blocks are not counted for computing the function embedding.
+; CHECK: Function vector: [ 1301.20 1318.20 1335.20 ]
+; CHECK-NEXT: Basic block vectors:
; CHECK-NEXT: Basic block: entry:
; CHECK-NEXT: [ 816.20 825.20 834.20 ]
; CHECK-NEXT: Basic block: if.then:
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
; CHECK-NEXT: Basic block: if.else:
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
+; CHECK-NEXT: Basic block: unreachable:
+; CHECK-NEXT: [ 101.00 103.00 105.00 ]
; CHECK-NEXT: Basic block: return:
; CHECK-NEXT: [ 95.00 97.00 99.00 ]
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 434449c7c5117..1031932116c1e 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -253,25 +253,17 @@ class IR2VecTool {
break;
}
case BasicBlockLevel: {
- const auto &BBVecMap = Emb->getBBVecMap();
for (const BasicBlock &BB : F) {
- auto It = BBVecMap.find(&BB);
- if (It != BBVecMap.end()) {
- OS << BB.getName() << ":";
- It->second.print(OS);
- }
+ OS << BB.getName() << ":";
+ Emb->getBBVector(BB).print(OS);
}
break;
}
case InstructionLevel: {
- const auto &InstMap = Emb->getInstVecMap();
for (const BasicBlock &BB : F) {
for (const Instruction &I : BB) {
- auto It = InstMap.find(&I);
- if (It != InstMap.end()) {
- I.print(OS);
- It->second.print(OS);
- }
+ I.print(OS);
+ Emb->getInstVector(I).print(OS);
}
}
break;
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 40b4aa21f2b46..24059b4fb9f69 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -30,7 +30,7 @@ namespace {
class TestableEmbedder : public Embedder {
public:
TestableEmbedder(const Function &F, const Vocabulary &V) : Embedder(F, V) {}
- void computeEmbeddings(const BasicBlock &BB) const override {}
+ Embedding computeEmbeddings(const Instruction &I) const override {}
};
TEST(EmbeddingTest, ConstructorsAndAccessors) {
@@ -321,18 +321,12 @@ class IR2VecTestFixture : public ::testing::Test {
}
};
-TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
+TEST_F(IR2VecTestFixture, GetInstVec_Symbolic) {
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
- const auto &InstMap = Emb->getInstVecMap();
-
- EXPECT_EQ(InstMap.size(), 2u);
- EXPECT_TRUE(InstMap.count(AddInst));
- EXPECT_TRUE(InstMap.count(RetInst));
-
- const auto &AddEmb = InstMap.at(AddInst);
- const auto &RetEmb = InstMap.at(RetInst);
+ const auto &AddEmb = Emb->getInstVector(*AddInst);
+ const auto &RetEmb = Emb->getInstVector(*RetInst);
EXPECT_EQ(AddEmb.size(), 2u);
EXPECT_EQ(RetEmb.size(), 2u);
@@ -340,51 +334,17 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 15.5)));
}
-TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
+TEST_F(IR2VecTestFixture, GetInstVec_FlowAware) {
auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
- const auto &InstMap = Emb->getInstVecMap();
-
- EXPECT_EQ(InstMap.size(), 2u);
- EXPECT_TRUE(InstMap.count(AddInst));
- EXPECT_TRUE(InstMap.count(RetInst));
-
- EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
- EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
-
- EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 25.5)));
- EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 32.6)));
-}
-
-TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
- auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V);
- ASSERT_TRUE(static_cast<bool>(Emb));
-
- const auto &BBMap = Emb->getBBVecMap();
-
- EXPECT_EQ(BBMap.size(), 1u);
- EXPECT_TRUE(BBMap.count(BB));
- EXPECT_EQ(BBMap.at(BB).size(), 2u);
-
- // BB vector should be sum of add and ret: {25.5, 25.5} + {15.5, 15.5} =
- // {41.0, 41.0}
- EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 41.0)));
-}
-
-TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
- auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V);
- ASSERT_TRUE(static_cast<bool>(Emb));
-
- const auto &BBMap = Emb->getBBVecMap();
-
- EXPECT_EQ(BBMap.size(), 1u);
- EXPECT_TRUE(BBMap.count(BB));
- EXPECT_EQ(BBMap.at(BB).size(), 2u);
+ const auto &AddEmb = Emb->getInstVector(*AddInst);
+ const auto &RetEmb = Emb->getInstVector(*RetInst);
+ EXPECT_EQ(AddEmb.size(), 2u);
+ EXPECT_EQ(RetEmb.size(), 2u);
- // BB vector should be sum of add and ret: {25.5, 25.5} + {32.6, 32.6} =
- // {58.1, 58.1}
- EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 58.1)));
+ EXPECT_TRUE(AddEmb.approximatelyEquals(Embedding(2, 25.5)));
+ EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 32.6)));
}
TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
@@ -394,6 +354,8 @@ TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
const auto &BBVec = Emb->getBBVector(*BB);
EXPECT_EQ(BBVec.size(), 2u);
+ // BB vector should be sum of add and ret: {25.5, 25.5} + {15.5, 15.5} =
+ // {41.0, 41.0}
EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 41.0)));
}
@@ -404,6 +366,8 @@ TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
const auto &BBVec = Emb->getBBVector(*BB);
EXPECT_EQ(BBVec.size(), 2u);
+ // BB vector should be sum of add and ret: {25.5, 25.5} + {32.6, 32.6} =
+ // {58.1, 58.1}
EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 58.1)));
}
@@ -445,16 +409,6 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_Symbolic) {
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2));
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3));
EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3));
-
- // Also check that instruction vectors remain consistent
- const auto &InstMap1 = Emb->getInstVecMap();
- const auto &InstMap2 = Emb->getInstVecMap();
-
- EXPECT_EQ(InstMap1.size(), InstMap2.size());
- for (const auto &[Inst, Vec1] : InstMap1) {
- ASSERT_TRUE(InstMap2.count(Inst));
- EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst)));
- }
}
TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) {
@@ -472,16 +426,6 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) {
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2));
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3));
EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3));
-
- // Also check that instruction vectors remain consistent
- const auto &InstMap1 = Emb->getInstVecMap();
- const auto &InstMap2 = Emb->getInstVecMap();
-
- EXPECT_EQ(InstMap1.size(), InstMap2.size());
- for (const auto &[Inst, Vec1] : InstMap1) {
- ASSERT_TRUE(InstMap2.count(Inst));
- EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst)));
- }
}
static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;
>From d3f8c842ca3cb2e924674d1e8f661f9ed96eaf0c Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 9 Oct 2025 21:51:25 +0000
Subject: [PATCH 3/4] Added invalidate()
---
llvm/include/llvm/Analysis/IR2Vec.h | 8 ++++++++
llvm/lib/Analysis/IR2Vec.cpp | 6 ++++--
llvm/unittests/Analysis/IR2VecTest.cpp | 8 ++++++++
3 files changed, 20 insertions(+), 2 deletions(-)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 9be8899d1d4e5..6bc51feb580d9 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -569,6 +569,13 @@ class Embedder {
/// Computes and returns the embedding for the current function.
LLVM_ABI Embedding getFunctionVector() const { return computeEmbeddings(); }
+
+ /// Invalidate embeddings if cached. The embeddings may not be relevant
+ /// anymore when the IR changes due to transformations. In such cases, the
+ /// cached embeddings should be invalidated to ensure
+ /// correctness/recomputation. This is a no-op for SymbolicEmbedder but
+ /// removes all the cached entries in FlowAwareEmbedder.
+ virtual void invalidateEmbeddings() { return; }
};
/// Class for computing the Symbolic embeddings of IR2Vec.
@@ -596,6 +603,7 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder {
public:
FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
: Embedder(F, Vocab) {}
+ void invalidateEmbeddings() override { InstVecMap.clear(); }
};
} // namespace ir2vec
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 6713cbe6c9a90..85b5372c961c1 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -165,7 +165,7 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
}
Embedding Embedder::computeEmbeddings() const {
- Embedding FuncVector(Dimension, 0);
+ Embedding FuncVector(Dimension, 0.0);
if (F.isDeclaration())
return FuncVector;
@@ -178,7 +178,9 @@ Embedding Embedder::computeEmbeddings() const {
Embedding Embedder::computeEmbeddings(const BasicBlock &BB) const {
Embedding BBVector(Dimension, 0);
- for (const Instruction &I : BB.instructionsWithoutDebug())
+
+ // We consider only the non-debug and non-pseudo instructions
+ for (const auto &I : BB.instructionsWithoutDebug())
BBVector += computeEmbeddings(I);
return BBVector;
}
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 24059b4fb9f69..0ec1c11faf50f 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -409,6 +409,10 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_Symbolic) {
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2));
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3));
EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3));
+
+ Emb->invalidateEmbeddings();
+ const auto &FuncVec4 = Emb->getFunctionVector();
+ EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec4));
}
TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) {
@@ -426,6 +430,10 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) {
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2));
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3));
EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3));
+
+ Emb->invalidateEmbeddings();
+ const auto &FuncVec4 = Emb->getFunctionVector();
+ EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec4));
}
static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;
>From eb0c96c4f88438e0530341a53e9c59361370b61e Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 9 Oct 2025 22:15:12 +0000
Subject: [PATCH 4/4] Fix error in test
---
llvm/unittests/Analysis/IR2VecTest.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 0ec1c11faf50f..8ffc5f61d5e55 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -30,7 +30,9 @@ namespace {
class TestableEmbedder : public Embedder {
public:
TestableEmbedder(const Function &F, const Vocabulary &V) : Embedder(F, V) {}
- Embedding computeEmbeddings(const Instruction &I) const override {}
+ Embedding computeEmbeddings(const Instruction &I) const override {
+ return Embedding();
+ }
};
TEST(EmbeddingTest, ConstructorsAndAccessors) {
More information about the llvm-commits mailing list