[llvm] [IR2Vec] Restrict caching only to Flow-Aware computation (PR #162559)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 9 15:16:03 PDT 2025


https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/162559

>From e4c6990185c714c2538d90ea790eb1f6cedec7d5 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Wed, 8 Oct 2025 18:20:17 +0000
Subject: [PATCH 1/4] IR2Vec Flow-aware fix

---
 llvm/lib/Analysis/IR2Vec.cpp | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 688535161d4b9..1794a604b991d 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -239,10 +239,21 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
       // If the operand is defined elsewhere, we use its embedding
       if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
         auto DefIt = InstVecMap.find(DefInst);
-        assert(DefIt != InstVecMap.end() &&
-               "Instruction should have been processed before its operands");
-        ArgEmb += DefIt->second;
-        continue;
+        // Fixme (#159171): Ideally we should never miss an instruction
+        // embedding here.
+        // But when we have cyclic dependencies (e.g., phi
+        // nodes), we might miss the embedding. In such cases, we fall back to
+        // using the vocabulary embedding. This can be fixed by iterating to a
+        // fixed-point, or by using a simple solver for the set of simultaneous
+        // equations.
+        // Another case when we might miss an instruction embedding is when
+        // the operand instruction is in a different basic block that has not
+        // been processed yet. This can be fixed by processing the basic blocks
+        // in a topological order.
+        if (DefIt != InstVecMap.end())
+          ArgEmb += DefIt->second;
+        else
+          ArgEmb += Vocab[*Op];
       }
       // If the operand is not defined by an instruction, we use the vocabulary
       else {

>From 727a6b861eecc40d4b3797ea525879731ee8851c Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 9 Oct 2025 21:36:30 +0000
Subject: [PATCH 2/4] Restrict caching only to Flow-Aware computation

---
 llvm/docs/MLGO.rst                       |   2 +-
 llvm/include/llvm/Analysis/IR2Vec.h      |  53 ++++---
 llvm/lib/Analysis/IR2Vec.cpp             | 182 +++++++++--------------
 llvm/test/Analysis/IR2Vec/unreachable.ll |   8 +-
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp   |  16 +-
 llvm/unittests/Analysis/IR2VecTest.cpp   |  86 ++---------
 6 files changed, 123 insertions(+), 224 deletions(-)

diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index 965a21b8c84b8..bf3de11a2640e 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -508,7 +508,7 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
 
    .. code-block:: c++
 
-    const ir2vec::Embedding &FuncVector = Emb->getFunctionVector();
+    ir2vec::Embedding FuncVector = Emb->getFunctionVector();
 
    Currently, ``Embedder`` can generate embeddings at three levels: Instructions,
    Basic Blocks, and Functions. Appropriate getters are provided to access the
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 81409df7337c5..9be8899d1d4e5 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -533,21 +533,20 @@ class Embedder {
   /// in the IR instructions to generate the vector representation.
   const float OpcWeight, TypeWeight, ArgWeight;
 
-  // Utility maps - these are used to store the vector representations of
-  // instructions, basic blocks and functions.
-  mutable Embedding FuncVector;
-  mutable BBEmbeddingsMap BBVecMap;
-  mutable InstEmbeddingsMap InstVecMap;
-
-  LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
+  LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab)
+      : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
+        OpcWeight(ir2vec::OpcWeight), TypeWeight(ir2vec::TypeWeight),
+        ArgWeight(ir2vec::ArgWeight) {}
 
-  /// Function to compute embeddings. It generates embeddings for all
-  /// the instructions and basic blocks in the function F.
-  void computeEmbeddings() const;
+  /// Function to compute embeddings.
+  Embedding computeEmbeddings() const;
 
   /// Function to compute the embedding for a given basic block.
+  Embedding computeEmbeddings(const BasicBlock &BB) const;
+
+  /// Function to compute the embedding for a given instruction.
   /// Specific to the kind of embeddings being computed.
-  virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
+  virtual Embedding computeEmbeddings(const Instruction &I) const = 0;
 
 public:
   virtual ~Embedder() = default;
@@ -556,23 +555,20 @@ class Embedder {
   LLVM_ABI static std::unique_ptr<Embedder>
   create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab);
 
-  /// Returns a map containing instructions and the corresponding embeddings for
-  /// the function F if it has been computed. If not, it computes the embeddings
-  /// for the function and returns the map.
-  LLVM_ABI const InstEmbeddingsMap &getInstVecMap() const;
-
-  /// Returns a map containing basic block and the corresponding embeddings for
-  /// the function F if it has been computed. If not, it computes the embeddings
-  /// for the function and returns the map.
-  LLVM_ABI const BBEmbeddingsMap &getBBVecMap() const;
+  /// Computes and returns the embedding for a given instruction in the function
+  /// F.
+  LLVM_ABI Embedding getInstVector(const Instruction &I) const {
+    return computeEmbeddings(I);
+  }
 
-  /// Returns the embedding for a given basic block in the function F if it has
-  /// been computed. If not, it computes the embedding for the basic block and
-  /// returns it.
-  LLVM_ABI const Embedding &getBBVector(const BasicBlock &BB) const;
+  /// Computes and returns the embedding for a given basic block in the function
+  /// F.
+  LLVM_ABI Embedding getBBVector(const BasicBlock &BB) const {
+    return computeEmbeddings(BB);
+  }
 
   /// Computes and returns the embedding for the current function.
-  LLVM_ABI const Embedding &getFunctionVector() const;
+  LLVM_ABI Embedding getFunctionVector() const { return computeEmbeddings(); }
 };
 
 /// Class for computing the Symbolic embeddings of IR2Vec.
@@ -580,7 +576,7 @@ class Embedder {
 /// representations obtained from the Vocabulary.
 class LLVM_ABI SymbolicEmbedder : public Embedder {
 private:
-  void computeEmbeddings(const BasicBlock &BB) const override;
+  Embedding computeEmbeddings(const Instruction &I) const override;
 
 public:
   SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
@@ -592,7 +588,10 @@ class LLVM_ABI SymbolicEmbedder : public Embedder {
 /// embeddings, and additionally capture the flow information in the IR.
 class LLVM_ABI FlowAwareEmbedder : public Embedder {
 private:
-  void computeEmbeddings(const BasicBlock &BB) const override;
+  // FlowAware embeddings would benefit from caching instruction embeddings as
+  // they are reused while computing the embeddings of other instructions.
+  mutable InstEmbeddingsMap InstVecMap;
+  Embedding computeEmbeddings(const Instruction &I) const override;
 
 public:
   FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 1794a604b991d..6713cbe6c9a90 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -153,11 +153,6 @@ void Embedding::print(raw_ostream &OS) const {
 // Embedder and its subclasses
 //===----------------------------------------------------------------------===//
 
-Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
-    : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
-      OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
-      FuncVector(Embedding(Dimension)) {}
-
 std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
                                            const Vocabulary &Vocab) {
   switch (Mode) {
@@ -169,110 +164,83 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
   return nullptr;
 }
 
-const InstEmbeddingsMap &Embedder::getInstVecMap() const {
-  if (InstVecMap.empty())
-    computeEmbeddings();
-  return InstVecMap;
-}
-
-const BBEmbeddingsMap &Embedder::getBBVecMap() const {
-  if (BBVecMap.empty())
-    computeEmbeddings();
-  return BBVecMap;
-}
-
-const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
-  auto It = BBVecMap.find(&BB);
-  if (It != BBVecMap.end())
-    return It->second;
-  computeEmbeddings(BB);
-  return BBVecMap[&BB];
-}
+Embedding Embedder::computeEmbeddings() const {
+  Embedding FuncVector(Dimension, 0);
 
-const Embedding &Embedder::getFunctionVector() const {
-  // Currently, we always (re)compute the embeddings for the function.
-  // This is cheaper than caching the vector.
-  computeEmbeddings();
-  return FuncVector;
-}
-
-void Embedder::computeEmbeddings() const {
   if (F.isDeclaration())
-    return;
-
-  FuncVector = Embedding(Dimension, 0.0);
+    return FuncVector;
 
   // Consider only the basic blocks that are reachable from entry
-  for (const BasicBlock *BB : depth_first(&F)) {
-    computeEmbeddings(*BB);
-    FuncVector += BBVecMap[BB];
-  }
-}
-
-void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
-  Embedding BBVector(Dimension, 0);
-
-  // We consider only the non-debug and non-pseudo instructions
-  for (const auto &I : BB.instructionsWithoutDebug()) {
-    Embedding ArgEmb(Dimension, 0);
-    for (const auto &Op : I.operands())
-      ArgEmb += Vocab[*Op];
-    auto InstVector =
-        Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
-    if (const auto *IC = dyn_cast<CmpInst>(&I))
-      InstVector += Vocab[IC->getPredicate()];
-    InstVecMap[&I] = InstVector;
-    BBVector += InstVector;
-  }
-  BBVecMap[&BB] = BBVector;
+  for (const BasicBlock *BB : depth_first(&F))
+    FuncVector += computeEmbeddings(*BB);
+  return FuncVector;
 }
 
-void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
+Embedding Embedder::computeEmbeddings(const BasicBlock &BB) const {
   Embedding BBVector(Dimension, 0);
+  for (const Instruction &I : BB.instructionsWithoutDebug())
+    BBVector += computeEmbeddings(I);
+  return BBVector;
+}
+
+Embedding SymbolicEmbedder::computeEmbeddings(const Instruction &I) const {
+  // Currently, we always (re)compute the embeddings for symbolic embedder.
+  // This is cheaper than caching the vectors.
+  Embedding ArgEmb(Dimension, 0);
+  for (const auto &Op : I.operands())
+    ArgEmb += Vocab[*Op];
+  auto InstVector =
+      Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+  if (const auto *IC = dyn_cast<CmpInst>(&I))
+    InstVector += Vocab[IC->getPredicate()];
+  return InstVector;
+}
+
+Embedding FlowAwareEmbedder::computeEmbeddings(const Instruction &I) const {
+  // If we have already computed the embedding for this instruction, return it
+  auto It = InstVecMap.find(&I);
+  if (It != InstVecMap.end())
+    return It->second;
 
-  // We consider only the non-debug and non-pseudo instructions
-  for (const auto &I : BB.instructionsWithoutDebug()) {
-    // TODO: Handle call instructions differently.
-    // For now, we treat them like other instructions
-    Embedding ArgEmb(Dimension, 0);
-    for (const auto &Op : I.operands()) {
-      // If the operand is defined elsewhere, we use its embedding
-      if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
-        auto DefIt = InstVecMap.find(DefInst);
-        // Fixme (#159171): Ideally we should never miss an instruction
-        // embedding here.
-        // But when we have cyclic dependencies (e.g., phi
-        // nodes), we might miss the embedding. In such cases, we fall back to
-        // using the vocabulary embedding. This can be fixed by iterating to a
-        // fixed-point, or by using a simple solver for the set of simultaneous
-        // equations.
-        // Another case when we might miss an instruction embedding is when
-        // the operand instruction is in a different basic block that has not
-        // been processed yet. This can be fixed by processing the basic blocks
-        // in a topological order.
-        if (DefIt != InstVecMap.end())
-          ArgEmb += DefIt->second;
-        else
-          ArgEmb += Vocab[*Op];
-      }
-      // If the operand is not defined by an instruction, we use the vocabulary
-      else {
-        LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
-                          << *Op << "=" << Vocab[*Op][0] << "\n");
+  // TODO: Handle call instructions differently.
+  // For now, we treat them like other instructions
+  Embedding ArgEmb(Dimension, 0);
+  for (const auto &Op : I.operands()) {
+    // If the operand is defined elsewhere, we use its embedding
+    if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
+      auto DefIt = InstVecMap.find(DefInst);
+      // Fixme (#159171): Ideally we should never miss an instruction
+      // embedding here.
+      // But when we have cyclic dependencies (e.g., phi
+      // nodes), we might miss the embedding. In such cases, we fall back to
+      // using the vocabulary embedding. This can be fixed by iterating to a
+      // fixed-point, or by using a simple solver for the set of simultaneous
+      // equations.
+      // Another case when we might miss an instruction embedding is when
+      // the operand instruction is in a different basic block that has not
+      // been processed yet. This can be fixed by processing the basic blocks
+      // in a topological order.
+      if (DefIt != InstVecMap.end())
+        ArgEmb += DefIt->second;
+      else
         ArgEmb += Vocab[*Op];
-      }
     }
-    // Create the instruction vector by combining opcode, type, and arguments
-    // embeddings
-    auto InstVector =
-        Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
-    // Add compare predicate embedding as an additional operand if applicable
-    if (const auto *IC = dyn_cast<CmpInst>(&I))
-      InstVector += Vocab[IC->getPredicate()];
-    InstVecMap[&I] = InstVector;
-    BBVector += InstVector;
+    // If the operand is not defined by an instruction, we use the
+    // vocabulary
+    else {
+      LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
+                        << *Op << "=" << Vocab[*Op][0] << "\n");
+      ArgEmb += Vocab[*Op];
+    }
   }
-  BBVecMap[&BB] = BBVector;
+  // Create the instruction vector by combining opcode, type, and arguments
+  // embeddings
+  auto InstVector =
+      Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+  if (const auto *IC = dyn_cast<CmpInst>(&I))
+    InstVector += Vocab[IC->getPredicate()];
+  InstVecMap[&I] = InstVector;
+  return InstVector;
 }
 
 // ==----------------------------------------------------------------------===//
@@ -695,25 +663,17 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
     Emb->getFunctionVector().print(OS);
 
     OS << "Basic block vectors:\n";
-    const auto &BBMap = Emb->getBBVecMap();
     for (const BasicBlock &BB : F) {
-      auto It = BBMap.find(&BB);
-      if (It != BBMap.end()) {
-        OS << "Basic block: " << BB.getName() << ":\n";
-        It->second.print(OS);
-      }
+      OS << "Basic block: " << BB.getName() << ":\n";
+      Emb->getBBVector(BB).print(OS);
     }
 
     OS << "Instruction vectors:\n";
-    const auto &InstMap = Emb->getInstVecMap();
     for (const BasicBlock &BB : F) {
       for (const Instruction &I : BB) {
-        auto It = InstMap.find(&I);
-        if (It != InstMap.end()) {
-          OS << "Instruction: ";
-          I.print(OS);
-          It->second.print(OS);
-        }
+        OS << "Instruction: ";
+        I.print(OS);
+        Emb->getInstVector(I).print(OS);
       }
     }
   }
diff --git a/llvm/test/Analysis/IR2Vec/unreachable.ll b/llvm/test/Analysis/IR2Vec/unreachable.ll
index 9be0ee1c2de7a..627e2c9ac6b2d 100644
--- a/llvm/test/Analysis/IR2Vec/unreachable.ll
+++ b/llvm/test/Analysis/IR2Vec/unreachable.ll
@@ -30,13 +30,17 @@ return:                                           ; preds = %if.else, %if.then
   %4 = load i32, ptr %retval, align 4
   ret i32 %4
 }
-
-; CHECK: Basic block vectors:
+; We'll get individual basic block embeddings for all blocks in the function.
+; But unreachable blocks are not counted for computing the function embedding.
+; CHECK: Function vector:  [ 1301.20  1318.20  1335.20 ]
+; CHECK-NEXT: Basic block vectors:
 ; CHECK-NEXT: Basic block: entry:
 ; CHECK-NEXT: [ 816.20  825.20  834.20 ]
 ; CHECK-NEXT: Basic block: if.then:
 ; CHECK-NEXT: [ 195.00  198.00  201.00 ]
 ; CHECK-NEXT: Basic block: if.else:
 ; CHECK-NEXT: [ 195.00  198.00  201.00 ]
+; CHECK-NEXT: Basic block: unreachable:
+; CHECK-NEXT:  [ 101.00  103.00  105.00 ]
 ; CHECK-NEXT: Basic block: return:
 ; CHECK-NEXT: [ 95.00  97.00  99.00 ]
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 434449c7c5117..1031932116c1e 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -253,25 +253,17 @@ class IR2VecTool {
       break;
     }
     case BasicBlockLevel: {
-      const auto &BBVecMap = Emb->getBBVecMap();
       for (const BasicBlock &BB : F) {
-        auto It = BBVecMap.find(&BB);
-        if (It != BBVecMap.end()) {
-          OS << BB.getName() << ":";
-          It->second.print(OS);
-        }
+        OS << BB.getName() << ":";
+        Emb->getBBVector(BB).print(OS);
       }
       break;
     }
     case InstructionLevel: {
-      const auto &InstMap = Emb->getInstVecMap();
       for (const BasicBlock &BB : F) {
         for (const Instruction &I : BB) {
-          auto It = InstMap.find(&I);
-          if (It != InstMap.end()) {
-            I.print(OS);
-            It->second.print(OS);
-          }
+          I.print(OS);
+          Emb->getInstVector(I).print(OS);
         }
       }
       break;
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 40b4aa21f2b46..24059b4fb9f69 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -30,7 +30,7 @@ namespace {
 class TestableEmbedder : public Embedder {
 public:
   TestableEmbedder(const Function &F, const Vocabulary &V) : Embedder(F, V) {}
-  void computeEmbeddings(const BasicBlock &BB) const override {}
+  Embedding computeEmbeddings(const Instruction &I) const override {}
 };
 
 TEST(EmbeddingTest, ConstructorsAndAccessors) {
@@ -321,18 +321,12 @@ class IR2VecTestFixture : public ::testing::Test {
   }
 };
 
-TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
+TEST_F(IR2VecTestFixture, GetInstVec_Symbolic) {
   auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V);
   ASSERT_TRUE(static_cast<bool>(Emb));
 
-  const auto &InstMap = Emb->getInstVecMap();
-
-  EXPECT_EQ(InstMap.size(), 2u);
-  EXPECT_TRUE(InstMap.count(AddInst));
-  EXPECT_TRUE(InstMap.count(RetInst));
-
-  const auto &AddEmb = InstMap.at(AddInst);
-  const auto &RetEmb = InstMap.at(RetInst);
+  const auto &AddEmb = Emb->getInstVector(*AddInst);
+  const auto &RetEmb = Emb->getInstVector(*RetInst);
   EXPECT_EQ(AddEmb.size(), 2u);
   EXPECT_EQ(RetEmb.size(), 2u);
 
@@ -340,51 +334,17 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
   EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 15.5)));
 }
 
-TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
+TEST_F(IR2VecTestFixture, GetInstVec_FlowAware) {
   auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V);
   ASSERT_TRUE(static_cast<bool>(Emb));
 
-  const auto &InstMap = Emb->getInstVecMap();
-
-  EXPECT_EQ(InstMap.size(), 2u);
-  EXPECT_TRUE(InstMap.count(AddInst));
-  EXPECT_TRUE(InstMap.count(RetInst));
-
-  EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
-  EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
-
-  EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 25.5)));
-  EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 32.6)));
-}
-
-TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
-  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V);
-  ASSERT_TRUE(static_cast<bool>(Emb));
-
-  const auto &BBMap = Emb->getBBVecMap();
-
-  EXPECT_EQ(BBMap.size(), 1u);
-  EXPECT_TRUE(BBMap.count(BB));
-  EXPECT_EQ(BBMap.at(BB).size(), 2u);
-
-  // BB vector should be sum of add and ret: {25.5, 25.5} + {15.5, 15.5} =
-  // {41.0, 41.0}
-  EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 41.0)));
-}
-
-TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
-  auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V);
-  ASSERT_TRUE(static_cast<bool>(Emb));
-
-  const auto &BBMap = Emb->getBBVecMap();
-
-  EXPECT_EQ(BBMap.size(), 1u);
-  EXPECT_TRUE(BBMap.count(BB));
-  EXPECT_EQ(BBMap.at(BB).size(), 2u);
+  const auto &AddEmb = Emb->getInstVector(*AddInst);
+  const auto &RetEmb = Emb->getInstVector(*RetInst);
+  EXPECT_EQ(AddEmb.size(), 2u);
+  EXPECT_EQ(RetEmb.size(), 2u);
 
-  // BB vector should be sum of add and ret: {25.5, 25.5} + {32.6, 32.6} =
-  // {58.1, 58.1}
-  EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 58.1)));
+  EXPECT_TRUE(AddEmb.approximatelyEquals(Embedding(2, 25.5)));
+  EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 32.6)));
 }
 
 TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
@@ -394,6 +354,8 @@ TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
   const auto &BBVec = Emb->getBBVector(*BB);
 
   EXPECT_EQ(BBVec.size(), 2u);
+  // BB vector should be sum of add and ret: {25.5, 25.5} + {15.5, 15.5} =
+  // {41.0, 41.0}
   EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 41.0)));
 }
 
@@ -404,6 +366,8 @@ TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
   const auto &BBVec = Emb->getBBVector(*BB);
 
   EXPECT_EQ(BBVec.size(), 2u);
+  // BB vector should be sum of add and ret: {25.5, 25.5} + {32.6, 32.6} =
+  // {58.1, 58.1}
   EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 58.1)));
 }
 
@@ -445,16 +409,6 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_Symbolic) {
   EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2));
   EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3));
   EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3));
-
-  // Also check that instruction vectors remain consistent
-  const auto &InstMap1 = Emb->getInstVecMap();
-  const auto &InstMap2 = Emb->getInstVecMap();
-
-  EXPECT_EQ(InstMap1.size(), InstMap2.size());
-  for (const auto &[Inst, Vec1] : InstMap1) {
-    ASSERT_TRUE(InstMap2.count(Inst));
-    EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst)));
-  }
 }
 
 TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) {
@@ -472,16 +426,6 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) {
   EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2));
   EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3));
   EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3));
-
-  // Also check that instruction vectors remain consistent
-  const auto &InstMap1 = Emb->getInstVecMap();
-  const auto &InstMap2 = Emb->getInstVecMap();
-
-  EXPECT_EQ(InstMap1.size(), InstMap2.size());
-  for (const auto &[Inst, Vec1] : InstMap1) {
-    ASSERT_TRUE(InstMap2.count(Inst));
-    EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst)));
-  }
 }
 
 static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;

>From d3f8c842ca3cb2e924674d1e8f661f9ed96eaf0c Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 9 Oct 2025 21:51:25 +0000
Subject: [PATCH 3/4] Added invalidate()

---
 llvm/include/llvm/Analysis/IR2Vec.h    | 8 ++++++++
 llvm/lib/Analysis/IR2Vec.cpp           | 6 ++++--
 llvm/unittests/Analysis/IR2VecTest.cpp | 8 ++++++++
 3 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 9be8899d1d4e5..6bc51feb580d9 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -569,6 +569,13 @@ class Embedder {
 
   /// Computes and returns the embedding for the current function.
   LLVM_ABI Embedding getFunctionVector() const { return computeEmbeddings(); }
+
+  /// Invalidate embeddings if cached. The embeddings may not be relevant
+  /// anymore when the IR changes due to transformations. In such cases, the
+  /// cached embeddings should be invalidated to ensure
+  /// correctness/recomputation. This is a no-op for SymbolicEmbedder but
+  /// removes all the cached entries in FlowAwareEmbedder.
+  virtual void invalidateEmbeddings() { return; }
 };
 
 /// Class for computing the Symbolic embeddings of IR2Vec.
@@ -596,6 +603,7 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder {
 public:
   FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
       : Embedder(F, Vocab) {}
+  void invalidateEmbeddings() override { InstVecMap.clear(); }
 };
 
 } // namespace ir2vec
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 6713cbe6c9a90..85b5372c961c1 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -165,7 +165,7 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
 }
 
 Embedding Embedder::computeEmbeddings() const {
-  Embedding FuncVector(Dimension, 0);
+  Embedding FuncVector(Dimension, 0.0);
 
   if (F.isDeclaration())
     return FuncVector;
@@ -178,7 +178,9 @@ Embedding Embedder::computeEmbeddings() const {
 
 Embedding Embedder::computeEmbeddings(const BasicBlock &BB) const {
   Embedding BBVector(Dimension, 0);
-  for (const Instruction &I : BB.instructionsWithoutDebug())
+
+  // We consider only the non-debug and non-pseudo instructions
+  for (const auto &I : BB.instructionsWithoutDebug())
     BBVector += computeEmbeddings(I);
   return BBVector;
 }
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 24059b4fb9f69..0ec1c11faf50f 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -409,6 +409,10 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_Symbolic) {
   EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2));
   EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3));
   EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3));
+
+  Emb->invalidateEmbeddings();
+  const auto &FuncVec4 = Emb->getFunctionVector();
+  EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec4));
 }
 
 TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) {
@@ -426,6 +430,10 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) {
   EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2));
   EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3));
   EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3));
+
+  Emb->invalidateEmbeddings();
+  const auto &FuncVec4 = Emb->getFunctionVector();
+  EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec4));
 }
 
 static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;

>From eb0c96c4f88438e0530341a53e9c59361370b61e Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 9 Oct 2025 22:15:12 +0000
Subject: [PATCH 4/4] Fix error in test

---
 llvm/unittests/Analysis/IR2VecTest.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 0ec1c11faf50f..8ffc5f61d5e55 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -30,7 +30,9 @@ namespace {
 class TestableEmbedder : public Embedder {
 public:
   TestableEmbedder(const Function &F, const Vocabulary &V) : Embedder(F, V) {}
-  Embedding computeEmbeddings(const Instruction &I) const override {}
+  Embedding computeEmbeddings(const Instruction &I) const override {
+    return Embedding();
+  }
 };
 
 TEST(EmbeddingTest, ConstructorsAndAccessors) {



More information about the llvm-commits mailing list