[llvm] [IR2Vec] Exposing Embedding as ADT wrapped around std::vector<double> (PR #143197)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 6 13:16:03 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlgo

Author: S. VenkataKeerthy (svkeerthy)

<details>
<summary>Changes</summary>

Currently `Embedding` is a type alias for `std::vector<double>`. This PR turns it into an ADT that wraps `std::vector<double>`, overloading basic arithmetic operators and exposing comparison operations. This _simplifies_ its use here and in the passes that perform operations on `Embedding`s.

---
Full diff: https://github.com/llvm/llvm-project/pull/143197.diff


3 Files Affected:

- (modified) llvm/include/llvm/Analysis/IR2Vec.h (+18-8) 
- (modified) llvm/lib/Analysis/IR2Vec.cpp (+39-19) 
- (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+54-14) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 9fd1b0ae8e248..930b13f079796 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -53,7 +53,24 @@ class raw_ostream;
 enum class IR2VecKind { Symbolic };
 
 namespace ir2vec {
-using Embedding = std::vector<double>;
+/// Embedding is an ADT that wraps std::vector<double>. It provides
+/// additional functionality for arithmetic and comparison operations.
+struct Embedding : public std::vector<double> {
+  using std::vector<double>::vector;
+
+  /// Arithmetic operators
+  Embedding &operator+=(const Embedding &RHS);
+  Embedding &operator-=(const Embedding &RHS);
+
+  /// Adds Src, scaled by Factor, to this Embedding in place:
+  /// *this += Src * Factor
+  void scaleAndAdd(const Embedding &Src, float Factor);
+
+  /// Returns true if the embedding is approximately equal to the RHS embedding
+  /// within the specified tolerance.
+  bool approximatelyEquals(const Embedding &RHS, double Tolerance = 1e-6) const;
+};
+
 using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
 using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
 // FIXME: Current the keys are strings. This can be changed to
@@ -99,13 +116,6 @@ class Embedder {
   /// zero vector.
   Embedding lookupVocab(const std::string &Key) const;
 
-  /// Adds two vectors: Dst += Src
-  static void addVectors(Embedding &Dst, const Embedding &Src);
-
-  /// Adds Src vector scaled by Factor to Dst vector: Dst += Src * Factor
-  static void addScaledVector(Embedding &Dst, const Embedding &Src,
-                              float Factor);
-
 public:
   virtual ~Embedder() = default;
 
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 490db5fdcdf99..8ee8e5b0ff74e 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -55,6 +55,40 @@ static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
 
 AnalysisKey IR2VecVocabAnalysis::Key;
 
+//===----------------------------------------------------------------------===//
+// Embedding
+//===----------------------------------------------------------------------===//
+
+Embedding &Embedding::operator+=(const Embedding &RHS) {
+  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+  std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
+                 std::plus<double>());
+  return *this;
+}
+
+Embedding &Embedding::operator-=(const Embedding &RHS) {
+  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+  std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
+                 std::minus<double>());
+  return *this;
+}
+
+void Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
+  assert(this->size() == Src.size() && "Vectors must have the same dimension");
+  for (size_t i = 0; i < this->size(); ++i) {
+    (*this)[i] += Src[i] * Factor;
+  }
+}
+
+bool Embedding::approximatelyEquals(const Embedding &RHS,
+                                    double Tolerance) const {
+  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+  for (size_t i = 0; i < this->size(); ++i)
+    if (std::abs((*this)[i] - RHS[i]) > Tolerance)
+      return false;
+  return true;
+}
+
 // ==----------------------------------------------------------------------===//
 // Embedder and its subclasses
 //===----------------------------------------------------------------------===//
@@ -73,20 +107,6 @@ Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
   return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
 }
 
-void Embedder::addVectors(Embedding &Dst, const Embedding &Src) {
-  assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
-  std::transform(Dst.begin(), Dst.end(), Src.begin(), Dst.begin(),
-                 std::plus<double>());
-}
-
-void Embedder::addScaledVector(Embedding &Dst, const Embedding &Src,
-                               float Factor) {
-  assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
-  for (size_t i = 0; i < Dst.size(); ++i) {
-    Dst[i] += Src[i] * Factor;
-  }
-}
-
 // FIXME: Currently lookups are string based. Use numeric Keys
 // for efficiency
 Embedding Embedder::lookupVocab(const std::string &Key) const {
@@ -164,20 +184,20 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
     Embedding InstVector(Dimension, 0);
 
     const auto OpcVec = lookupVocab(I.getOpcodeName());
-    addScaledVector(InstVector, OpcVec, OpcWeight);
+    InstVector.scaleAndAdd(OpcVec, OpcWeight);
 
     // FIXME: Currently lookups are string based. Use numeric Keys
     // for efficiency.
     const auto Type = I.getType();
     const auto TypeVec = getTypeEmbedding(Type);
-    addScaledVector(InstVector, TypeVec, TypeWeight);
+    InstVector.scaleAndAdd(TypeVec, TypeWeight);
 
     for (const auto &Op : I.operands()) {
       const auto OperandVec = getOperandEmbedding(Op.get());
-      addScaledVector(InstVector, OperandVec, ArgWeight);
+      InstVector.scaleAndAdd(OperandVec, ArgWeight);
     }
     InstVecMap[&I] = InstVector;
-    addVectors(BBVector, InstVector);
+    BBVector += InstVector;
   }
   BBVecMap[&BB] = BBVector;
 }
@@ -187,7 +207,7 @@ void SymbolicEmbedder::computeEmbeddings() const {
     return;
   for (const auto &BB : F) {
     computeEmbeddings(BB);
-    addVectors(FuncVector, BBVecMap[&BB]);
+    FuncVector += BBVecMap[&BB];
   }
 }
 
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 9e47b2cd8bedd..7259a8a2fe20a 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -32,13 +32,6 @@ class TestableEmbedder : public Embedder {
   void computeEmbeddings() const override {}
   void computeEmbeddings(const BasicBlock &BB) const override {}
   using Embedder::lookupVocab;
-  static void addVectors(Embedding &Dst, const Embedding &Src) {
-    Embedder::addVectors(Dst, Src);
-  }
-  static void addScaledVector(Embedding &Dst, const Embedding &Src,
-                              float Factor) {
-    Embedder::addScaledVector(Dst, Src, Factor);
-  }
 };
 
 TEST(IR2VecTest, CreateSymbolicEmbedder) {
@@ -79,37 +72,83 @@ TEST(IR2VecTest, AddVectors) {
   Embedding E1 = {1.0, 2.0, 3.0};
   Embedding E2 = {0.5, 1.5, -1.0};
 
-  TestableEmbedder::addVectors(E1, E2);
+  E1 += E2;
   EXPECT_THAT(E1, ElementsAre(1.5, 3.5, 2.0));
 
   // Check that E2 is unchanged
   EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
 }
 
+TEST(IR2VecTest, SubtractVectors) {
+  Embedding E1 = {1.0, 2.0, 3.0};
+  Embedding E2 = {0.5, 1.5, -1.0};
+
+  E1 -= E2;
+  EXPECT_THAT(E1, ElementsAre(0.5, 0.5, 4.0));
+
+  // Check that E2 is unchanged
+  EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
+}
+
 TEST(IR2VecTest, AddScaledVector) {
   Embedding E1 = {1.0, 2.0, 3.0};
   Embedding E2 = {2.0, 0.5, -1.0};
 
-  TestableEmbedder::addScaledVector(E1, E2, 0.5f);
+  E1.scaleAndAdd(E2, 0.5f);
   EXPECT_THAT(E1, ElementsAre(2.0, 2.25, 2.5));
 
   // Check that E2 is unchanged
   EXPECT_THAT(E2, ElementsAre(2.0, 0.5, -1.0));
 }
 
+TEST(IR2VecTest, ApproximatelyEqual) {
+  Embedding E1 = {1.0, 2.0, 3.0};
+  Embedding E2 = {1.0000001, 2.0000001, 3.0000001};
+  EXPECT_TRUE(E1.approximatelyEquals(E2)); // Diff = 1e-7
+
+  Embedding E3 = {1.00002, 2.00002, 3.00002}; // Diff = 2e-5
+  EXPECT_FALSE(E1.approximatelyEquals(E3));
+  EXPECT_TRUE(E1.approximatelyEquals(E3, 3e-5));
+
+  Embedding E_clearly_within = {1.0000005, 2.0000005, 3.0000005}; // Diff = 5e-7
+  EXPECT_TRUE(E1.approximatelyEquals(E_clearly_within));
+
+  Embedding E_clearly_outside = {1.00001, 2.00001, 3.00001}; // Diff = 1e-5
+  EXPECT_FALSE(E1.approximatelyEquals(E_clearly_outside));
+
+  Embedding E4 = {1.0, 2.0, 3.5}; // Large diff
+  EXPECT_FALSE(E1.approximatelyEquals(E4, 0.01));
+
+  Embedding E5 = {1.0, 2.0, 3.0};
+  EXPECT_TRUE(E1.approximatelyEquals(E5, 0.0));
+  EXPECT_TRUE(E1.approximatelyEquals(E5));
+}
+
 #if GTEST_HAS_DEATH_TEST
 #ifndef NDEBUG
 TEST(IR2VecTest, MismatchedDimensionsAddVectors) {
   Embedding E1 = {1.0, 2.0};
   Embedding E2 = {1.0};
-  EXPECT_DEATH(TestableEmbedder::addVectors(E1, E2),
-               "Vectors must have the same dimension");
+  EXPECT_DEATH(E1 += E2, "Vectors must have the same dimension");
+}
+
+TEST(IR2VecTest, MismatchedDimensionsSubtractVectors) {
+  Embedding E1 = {1.0, 2.0};
+  Embedding E2 = {1.0};
+  EXPECT_DEATH(E1 -= E2, "Vectors must have the same dimension");
 }
 
 TEST(IR2VecTest, MismatchedDimensionsAddScaledVector) {
   Embedding E1 = {1.0, 2.0};
   Embedding E2 = {1.0};
-  EXPECT_DEATH(TestableEmbedder::addScaledVector(E1, E2, 1.0f),
+  EXPECT_DEATH(E1.scaleAndAdd(E2, 1.0f),
+               "Vectors must have the same dimension");
+}
+
+TEST(IR2VecTest, MismatchedDimensionsApproximatelyEqual) {
+  Embedding E1 = {1.0, 2.0};
+  Embedding E2 = {1.010};
+  EXPECT_DEATH(E1.approximatelyEquals(E2),
                "Vectors must have the same dimension");
 }
 #endif // NDEBUG
@@ -136,8 +175,9 @@ TEST(IR2VecTest, ZeroDimensionEmbedding) {
   Embedding E1;
   Embedding E2;
   // Should be no-op, but not crash
-  TestableEmbedder::addVectors(E1, E2);
-  TestableEmbedder::addScaledVector(E1, E2, 1.0f);
+  E1 += E2;
+  E1 -= E2;
+  E1.scaleAndAdd(E2, 1.0f);
   EXPECT_TRUE(E1.empty());
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/143197


More information about the llvm-commits mailing list