[llvm] [IR2Vec] Exposing Embedding as an ADT wrapped around std::vector<double> (PR #143197)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 9 13:42:17 PDT 2025


https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/143197

>From 6817aa9606f68849515c92cb86693ca347acea55 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Fri, 6 Jun 2025 20:05:10 +0000
Subject: [PATCH] [IR2Vec] Expose Embedding as an ADT wrapping std::vector<double>

---
 llvm/include/llvm/Analysis/IR2Vec.h    |  67 +++++++--
 llvm/lib/Analysis/IR2Vec.cpp           |  69 ++++++---
 llvm/unittests/Analysis/IR2VecTest.cpp | 192 +++++++++++++++++++------
 3 files changed, 254 insertions(+), 74 deletions(-)

diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 9fd1b0ae8e248..14f28999b174c 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -53,7 +53,61 @@ class raw_ostream;
 enum class IR2VecKind { Symbolic };
 
 namespace ir2vec {
-using Embedding = std::vector<double>;
+/// Embedding is an ADT that wraps std::vector<double>. It provides
+/// additional functionality for arithmetic and comparison operations.
+/// It is meant to be used *like* std::vector<double> but is more restrictive
+/// in the sense that it does not allow the user to change the size of the
+/// embedding vector. The dimension of the embedding is fixed at the time of
+/// construction of the Embedding object; elements can be modified in place.
+struct Embedding {
+private:
+  std::vector<double> Data;
+
+public:
+  Embedding() = default;
+  Embedding(const std::vector<double> &V) : Data(V) {}
+  Embedding(std::vector<double> &&V) : Data(std::move(V)) {}
+  Embedding(std::initializer_list<double> IL) : Data(IL) {}
+  Embedding(size_t Size, double InitialValue) : Data(Size, InitialValue) {}
+
+  size_t size() const { return Data.size(); }
+  bool empty() const { return Data.empty(); }
+
+  double &operator[](size_t Itr) {
+    assert(Itr < Data.size() && "Index out of bounds");
+    return Data[Itr];
+  }
+
+  const double &operator[](size_t Itr) const {
+    assert(Itr < Data.size() && "Index out of bounds");
+    return Data[Itr];
+  }
+
+  using iterator = typename std::vector<double>::iterator;
+  using const_iterator = typename std::vector<double>::const_iterator;
+
+  iterator begin() { return Data.begin(); }
+  iterator end() { return Data.end(); }
+  const_iterator begin() const { return Data.begin(); }
+  const_iterator end() const { return Data.end(); }
+  const_iterator cbegin() const { return Data.cbegin(); }
+  const_iterator cend() const { return Data.cend(); }
+
+  const std::vector<double> &getData() const { return Data; }
+
+  /// Element-wise in-place arithmetic; RHS must have the same dimension.
+  Embedding &operator+=(const Embedding &RHS);
+  Embedding &operator-=(const Embedding &RHS);
+
+  /// Adds Src scaled by Factor to this Embedding in place and returns *this:
+  /// *this += Src * Factor (element-wise). Src must have the same dimension.
+  Embedding &scaleAndAdd(const Embedding &Src, float Factor);
+
+  /// Returns true if every element differs from the corresponding element
+  /// of RHS by at most Tolerance. RHS must have the same dimension.
+  bool approximatelyEquals(const Embedding &RHS, double Tolerance = 1e-6) const;
+};
+
 using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
 using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
 // FIXME: Current the keys are strings. This can be changed to
@@ -61,8 +115,8 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
 using Vocab = std::map<std::string, Embedding>;
 
 /// Embedder provides the interface to generate embeddings (vector
-/// representations) for instructions, basic blocks, and functions. The vector
-/// representations are generated using IR2Vec algorithms.
+/// representations) for instructions, basic blocks, and functions. The
+/// vector representations are generated using IR2Vec algorithms.
 ///
 /// The Embedder class is an abstract class and it is intended to be
 /// subclassed for different IR2Vec algorithms like Symbolic and Flow-aware.
@@ -99,13 +153,6 @@ class Embedder {
   /// zero vector.
   Embedding lookupVocab(const std::string &Key) const;
 
-  /// Adds two vectors: Dst += Src
-  static void addVectors(Embedding &Dst, const Embedding &Src);
-
-  /// Adds Src vector scaled by Factor to Dst vector: Dst += Src * Factor
-  static void addScaledVector(Embedding &Dst, const Embedding &Src,
-                              float Factor);
-
 public:
   virtual ~Embedder() = default;
 
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 490db5fdcdf99..25ce35d4ace37 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -55,6 +55,51 @@ static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
 
 AnalysisKey IR2VecVocabAnalysis::Key;
 
+namespace llvm::json {
+inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
+                     llvm::json::Path P) {
+  std::vector<double> TempOut;
+  if (!llvm::json::fromJSON(E, TempOut, P))
+    return false;
+  Out = Embedding(std::move(TempOut));
+  return true;
+}
+} // namespace llvm::json
+
+//===----------------------------------------------------------------------===//
+// Embedding
+//===----------------------------------------------------------------------===//
+
+Embedding &Embedding::operator+=(const Embedding &RHS) {
+  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+  std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
+                 std::plus<double>());
+  return *this;
+}
+
+Embedding &Embedding::operator-=(const Embedding &RHS) {
+  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+  std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
+                 std::minus<double>());
+  return *this;
+}
+
+Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
+  assert(this->size() == Src.size() && "Vectors must have the same dimension");
+  for (size_t Itr = 0; Itr < this->size(); ++Itr)
+    (*this)[Itr] += Src[Itr] * Factor;
+  return *this;
+}
+
+bool Embedding::approximatelyEquals(const Embedding &RHS,
+                                    double Tolerance) const {
+  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+  for (size_t Itr = 0; Itr < this->size(); ++Itr)
+    if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance)
+      return false;
+  return true;
+}
+
 // ==----------------------------------------------------------------------===//
 // Embedder and its subclasses
 //===----------------------------------------------------------------------===//
@@ -73,20 +118,6 @@ Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
   return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
 }
 
-void Embedder::addVectors(Embedding &Dst, const Embedding &Src) {
-  assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
-  std::transform(Dst.begin(), Dst.end(), Src.begin(), Dst.begin(),
-                 std::plus<double>());
-}
-
-void Embedder::addScaledVector(Embedding &Dst, const Embedding &Src,
-                               float Factor) {
-  assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
-  for (size_t i = 0; i < Dst.size(); ++i) {
-    Dst[i] += Src[i] * Factor;
-  }
-}
-
 // FIXME: Currently lookups are string based. Use numeric Keys
 // for efficiency
 Embedding Embedder::lookupVocab(const std::string &Key) const {
@@ -164,20 +195,20 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
     Embedding InstVector(Dimension, 0);
 
     const auto OpcVec = lookupVocab(I.getOpcodeName());
-    addScaledVector(InstVector, OpcVec, OpcWeight);
+    InstVector.scaleAndAdd(OpcVec, OpcWeight);
 
     // FIXME: Currently lookups are string based. Use numeric Keys
     // for efficiency.
     const auto Type = I.getType();
     const auto TypeVec = getTypeEmbedding(Type);
-    addScaledVector(InstVector, TypeVec, TypeWeight);
+    InstVector.scaleAndAdd(TypeVec, TypeWeight);
 
     for (const auto &Op : I.operands()) {
       const auto OperandVec = getOperandEmbedding(Op.get());
-      addScaledVector(InstVector, OperandVec, ArgWeight);
+      InstVector.scaleAndAdd(OperandVec, ArgWeight);
     }
     InstVecMap[&I] = InstVector;
-    addVectors(BBVector, InstVector);
+    BBVector += InstVector;
   }
   BBVecMap[&BB] = BBVector;
 }
@@ -187,7 +218,7 @@ void SymbolicEmbedder::computeEmbeddings() const {
     return;
   for (const auto &BB : F) {
     computeEmbeddings(BB);
-    addVectors(FuncVector, BBVecMap[&BB]);
+    FuncVector += BBVecMap[&BB];
   }
 }
 
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 9e47b2cd8bedd..46e9c71c58250 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -14,6 +14,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
 #include "llvm/Support/Error.h"
+#include "llvm/Support/JSON.h"
 
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
@@ -32,89 +33,189 @@ class TestableEmbedder : public Embedder {
   void computeEmbeddings() const override {}
   void computeEmbeddings(const BasicBlock &BB) const override {}
   using Embedder::lookupVocab;
-  static void addVectors(Embedding &Dst, const Embedding &Src) {
-    Embedder::addVectors(Dst, Src);
-  }
-  static void addScaledVector(Embedding &Dst, const Embedding &Src,
-                              float Factor) {
-    Embedder::addScaledVector(Dst, Src, Factor);
-  }
 };
 
-TEST(IR2VecTest, CreateSymbolicEmbedder) {
-  Vocab V = {{"foo", {1.0, 2.0}}};
-
-  LLVMContext Ctx;
-  Module M("M", Ctx);
-  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
-  Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+TEST(EmbeddingTest, ConstructorsAndAccessors) {
+  // Default constructor
+  Embedding E1;
+  EXPECT_TRUE(E1.empty());
+  EXPECT_EQ(E1.size(), 0u);
+
+  // Constructor with const std::vector<double>&
+  std::vector<double> Data = {1.0, 2.0, 3.0};
+  Embedding E2(Data);
+  EXPECT_FALSE(E2.empty());
+  EXPECT_EQ(E2.size(), 3u);
+  EXPECT_THAT(E2.getData(), ElementsAre(1.0, 2.0, 3.0));
+  EXPECT_EQ(E2[0], 1.0);
+  EXPECT_EQ(E2[1], 2.0);
+  EXPECT_EQ(E2[2], 3.0);
+
+  // Constructor with std::vector<double>&&
+  Embedding E3(std::vector<double>({4.0, 5.0}));
+  EXPECT_EQ(E3.size(), 2u);
+  EXPECT_THAT(E3.getData(), ElementsAre(4.0, 5.0));
+
+  // Constructor with std::initializer_list<double>
+  Embedding E4({6.0, 7.0, 8.0, 9.0});
+  EXPECT_EQ(E4.size(), 4u);
+  EXPECT_THAT(E4.getData(), ElementsAre(6.0, 7.0, 8.0, 9.0));
+  EXPECT_EQ(E4[0], 6.0);
+  E4[0] = 6.5;
+  EXPECT_EQ(E4[0], 6.5);
+
+  // Constructor with size_t and double
+  Embedding E5(5, 1.5);
+  EXPECT_EQ(E5.size(), 5u);
+  EXPECT_THAT(E5.getData(), ElementsAre(1.5, 1.5, 1.5, 1.5, 1.5));
+
+  // Test iterators
+  std::vector<double> VecE4;
+  for (double Val : E4) {
+    VecE4.push_back(Val);
+  }
+  EXPECT_THAT(VecE4, ElementsAre(6.5, 7.0, 8.0, 9.0));
 
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  EXPECT_TRUE(static_cast<bool>(Result));
+  const Embedding CE4 = E4;
+  std::vector<double> VecCE4;
+  for (const double &Val : CE4) {
+    VecCE4.push_back(Val);
+  }
+  EXPECT_THAT(VecCE4, ElementsAre(6.5, 7.0, 8.0, 9.0));
 
-  auto *Emb = Result->get();
-  EXPECT_NE(Emb, nullptr);
+  EXPECT_EQ(*E4.begin(), 6.5);
+  EXPECT_EQ(*(E4.end() - 1), 9.0);
+  EXPECT_EQ(*CE4.cbegin(), 6.5);
+  EXPECT_EQ(*(CE4.cend() - 1), 9.0);
 }
 
-TEST(IR2VecTest, CreateInvalidMode) {
-  Vocab V = {{"foo", {1.0, 2.0}}};
-
-  LLVMContext Ctx;
-  Module M("M", Ctx);
-  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
-  Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+TEST(EmbeddingTest, AddVectors) {
+  Embedding E1 = {1.0, 2.0, 3.0};
+  Embedding E2 = {0.5, 1.5, -1.0};
 
-  // static_cast an invalid int to IR2VecKind
-  auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
-  EXPECT_FALSE(static_cast<bool>(Result));
+  E1 += E2;
+  EXPECT_THAT(E1, ElementsAre(1.5, 3.5, 2.0));
 
-  std::string ErrMsg;
-  llvm::handleAllErrors(
-      Result.takeError(),
-      [&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
-  EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
+  // Check that E2 is unchanged
+  EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
 }
 
-TEST(IR2VecTest, AddVectors) {
+TEST(EmbeddingTest, SubtractVectors) {
   Embedding E1 = {1.0, 2.0, 3.0};
   Embedding E2 = {0.5, 1.5, -1.0};
 
-  TestableEmbedder::addVectors(E1, E2);
-  EXPECT_THAT(E1, ElementsAre(1.5, 3.5, 2.0));
+  E1 -= E2;
+  EXPECT_THAT(E1, ElementsAre(0.5, 0.5, 4.0));
 
   // Check that E2 is unchanged
   EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
 }
 
-TEST(IR2VecTest, AddScaledVector) {
+TEST(EmbeddingTest, AddScaledVector) {
   Embedding E1 = {1.0, 2.0, 3.0};
   Embedding E2 = {2.0, 0.5, -1.0};
 
-  TestableEmbedder::addScaledVector(E1, E2, 0.5f);
+  E1.scaleAndAdd(E2, 0.5f);
   EXPECT_THAT(E1, ElementsAre(2.0, 2.25, 2.5));
 
   // Check that E2 is unchanged
   EXPECT_THAT(E2, ElementsAre(2.0, 0.5, -1.0));
 }
 
+TEST(EmbeddingTest, ApproximatelyEqual) {
+  Embedding E1 = {1.0, 2.0, 3.0};
+  Embedding E2 = {1.0000001, 2.0000001, 3.0000001};
+  EXPECT_TRUE(E1.approximatelyEquals(E2)); // Diff = 1e-7
+
+  Embedding E3 = {1.00002, 2.00002, 3.00002}; // Diff = 2e-5
+  EXPECT_FALSE(E1.approximatelyEquals(E3));
+  EXPECT_TRUE(E1.approximatelyEquals(E3, 3e-5));
+
+  Embedding E_clearly_within = {1.0000005, 2.0000005, 3.0000005}; // Diff = 5e-7
+  EXPECT_TRUE(E1.approximatelyEquals(E_clearly_within));
+
+  Embedding E_clearly_outside = {1.00001, 2.00001, 3.00001}; // Diff = 1e-5
+  EXPECT_FALSE(E1.approximatelyEquals(E_clearly_outside));
+
+  Embedding E4 = {1.0, 2.0, 3.5}; // Large diff
+  EXPECT_FALSE(E1.approximatelyEquals(E4, 0.01));
+
+  Embedding E5 = {1.0, 2.0, 3.0};
+  EXPECT_TRUE(E1.approximatelyEquals(E5, 0.0)); // Tolerance bound is inclusive.
+  EXPECT_TRUE(E1.approximatelyEquals(E5));
+}
+
 #if GTEST_HAS_DEATH_TEST
 #ifndef NDEBUG
-TEST(IR2VecTest, MismatchedDimensionsAddVectors) {
+TEST(EmbeddingTest, AccessOutOfBounds) {
+  Embedding E = {1.0, 2.0, 3.0};
+  EXPECT_DEATH(E[3], "Index out of bounds");
+  EXPECT_DEATH(E[-1], "Index out of bounds"); // -1 converts to SIZE_MAX.
+  EXPECT_DEATH(E[4] = 4.0, "Index out of bounds");
+}
+
+TEST(EmbeddingTest, MismatchedDimensionsAddVectors) {
   Embedding E1 = {1.0, 2.0};
   Embedding E2 = {1.0};
-  EXPECT_DEATH(TestableEmbedder::addVectors(E1, E2),
-               "Vectors must have the same dimension");
+  EXPECT_DEATH(E1 += E2, "Vectors must have the same dimension");
+}
+
+TEST(EmbeddingTest, MismatchedDimensionsSubtractVectors) {
+  Embedding E1 = {1.0, 2.0};
+  Embedding E2 = {1.0};
+  EXPECT_DEATH(E1 -= E2, "Vectors must have the same dimension");
 }
 
-TEST(IR2VecTest, MismatchedDimensionsAddScaledVector) {
+TEST(EmbeddingTest, MismatchedDimensionsAddScaledVector) {
   Embedding E1 = {1.0, 2.0};
   Embedding E2 = {1.0};
-  EXPECT_DEATH(TestableEmbedder::addScaledVector(E1, E2, 1.0f),
+  EXPECT_DEATH(E1.scaleAndAdd(E2, 1.0f),
+               "Vectors must have the same dimension");
+}
+
+TEST(EmbeddingTest, MismatchedDimensionsApproximatelyEqual) {
+  Embedding E1 = {1.0, 2.0};
+  Embedding E2 = {1.010};
+  EXPECT_DEATH(E1.approximatelyEquals(E2),
                "Vectors must have the same dimension");
 }
 #endif // NDEBUG
 #endif // GTEST_HAS_DEATH_TEST
 
+TEST(IR2VecTest, CreateSymbolicEmbedder) {
+  Vocab V = {{"foo", {1.0, 2.0}}};
+
+  LLVMContext Ctx;
+  Module M("M", Ctx);
+  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
+  Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+
+  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  EXPECT_TRUE(static_cast<bool>(Result));
+
+  auto *Emb = Result->get();
+  EXPECT_NE(Emb, nullptr);
+}
+
+TEST(IR2VecTest, CreateInvalidMode) {
+  Vocab V = {{"foo", {1.0, 2.0}}};
+
+  LLVMContext Ctx;
+  Module M("M", Ctx);
+  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
+  Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+
+  // static_cast an invalid int to IR2VecKind
+  auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
+  EXPECT_FALSE(static_cast<bool>(Result));
+
+  std::string ErrMsg;
+  llvm::handleAllErrors(
+      Result.takeError(),
+      [&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
+  EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
+}
+
 TEST(IR2VecTest, LookupVocab) {
   Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}};
   LLVMContext Ctx;
@@ -136,8 +237,9 @@ TEST(IR2VecTest, ZeroDimensionEmbedding) {
   Embedding E1;
   Embedding E2;
   // Should be no-op, but not crash
-  TestableEmbedder::addVectors(E1, E2);
-  TestableEmbedder::addScaledVector(E1, E2, 1.0f);
+  E1 += E2;
+  E1 -= E2;
+  E1.scaleAndAdd(E2, 1.0f);
   EXPECT_TRUE(E1.empty());
 }
 



More information about the llvm-commits mailing list