[llvm-branch-commits] [llvm] Overloading operator+ for Embeddings (PR #145118)
S. VenkataKeerthy via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jun 20 16:30:39 PDT 2025
https://github.com/svkeerthy created https://github.com/llvm/llvm-project/pull/145118
None
>From cbd2c6e77eefb4ba7b8acbf6ea12f21486e7dbc8 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Fri, 20 Jun 2025 23:00:40 +0000
Subject: [PATCH] Overloading operator+ for Embeddings
---
llvm/include/llvm/Analysis/IR2Vec.h | 1 +
llvm/lib/Analysis/IR2Vec.cpp | 8 ++++++++
llvm/unittests/Analysis/IR2VecTest.cpp | 18 ++++++++++++++++++
3 files changed, 27 insertions(+)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 480b834077b86..f6c40d36f8026 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -106,6 +106,7 @@ struct Embedding {
const std::vector<double> &getData() const { return Data; }
/// Arithmetic operators
+ /// Returns a new Embedding holding the element-wise sum of this and RHS.
+ /// Neither operand is modified; both must have the same dimension
+ /// (checked by an assertion).
+ Embedding operator+(const Embedding &RHS) const;
Embedding &operator+=(const Embedding &RHS);
Embedding &operator-=(const Embedding &RHS);
Embedding &operator*=(double Factor);
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 27cc2a4109879..d5d27db8bd2bf 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -71,6 +71,14 @@ inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
// Embedding
//===----------------------------------------------------------------------===//
+// Out-of-place addition, expressed via operator+= so the dimension check and
+// the element-wise arithmetic are defined in exactly one place.
+Embedding Embedding::operator+(const Embedding &RHS) const {
+  Embedding Result(*this);
+  // operator+= asserts that both embeddings have the same dimension.
+  Result += RHS;
+  return Result;
+}
+
Embedding &Embedding::operator+=(const Embedding &RHS) {
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 33ac16828eb6c..50eb7f73c6f50 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -109,6 +109,18 @@ TEST(EmbeddingTest, ConstructorsAndAccessors) {
}
}
+TEST(EmbeddingTest, AddVectorsOutOfPlace) {
+  Embedding Lhs = {1.0, 2.0, 3.0};
+  Embedding Rhs = {0.5, 1.5, -1.0};
+
+  // operator+ yields a fresh embedding holding the element-wise sum...
+  Embedding Sum = Lhs + Rhs;
+  EXPECT_THAT(Sum, ElementsAre(1.5, 3.5, 2.0));
+
+  // ...and leaves both operands untouched.
+  EXPECT_THAT(Lhs, ElementsAre(1.0, 2.0, 3.0));
+  EXPECT_THAT(Rhs, ElementsAre(0.5, 1.5, -1.0));
+}
+
TEST(EmbeddingTest, AddVectors) {
Embedding E1 = {1.0, 2.0, 3.0};
Embedding E2 = {0.5, 1.5, -1.0};
@@ -180,6 +192,12 @@ TEST(EmbeddingTest, AccessOutOfBounds) {
EXPECT_DEATH(E[4] = 4.0, "Index out of bounds");
}
+TEST(EmbeddingTest, MismatchedDimensionsAddVectorsOutOfPlace) {
+  Embedding Lhs = {1.0, 2.0};
+  Embedding Rhs = {1.0};
+  // Adding embeddings of different sizes must trip the dimension assertion.
+  EXPECT_DEATH(Lhs + Rhs, "Vectors must have the same dimension");
+}
+
TEST(EmbeddingTest, MismatchedDimensionsAddVectors) {
Embedding E1 = {1.0, 2.0};
Embedding E2 = {1.0};
More information about the llvm-branch-commits
mailing list