[llvm-branch-commits] [llvm] [IR2Vec] Add out-of-place arithmetic operators to Embedding class (PR #145118)
@@ -71,20 +71,43 @@ inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
// Embedding
//===--===//
+Embedding Embedding::operator+(const Embedding &RHS) const {
+ assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+ Embedding Result(*this);
+ std::transform(this->begin(), this->end(), RHS.begin(), Result.begin(),
+ std::plus());
+ return Result;
+}
+
Embedding &Embedding::operator+=(const Embedding &RHS) {
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
svkeerthy wrote:
Implemented `+` in terms of `+=`, as this avoids copies in `+=`.
https://github.com/llvm/llvm-project/pull/145118
___
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [IR2Vec] Add out-of-place arithmetic operators to Embedding class (PR #145118)
https://github.com/svkeerthy updated
https://github.com/llvm/llvm-project/pull/145118
>From 10019cae162bb53e147797b655da75aac33b0a20 Mon Sep 17 00:00:00 2001
From: svkeerthy
Date: Fri, 20 Jun 2025 23:00:40 +
Subject: [PATCH] Overloading operator+ for Embeddings
---
llvm/include/llvm/Analysis/IR2Vec.h| 9 --
llvm/lib/Analysis/IR2Vec.cpp | 19 -
llvm/unittests/Analysis/IR2VecTest.cpp | 39 ++
3 files changed, 63 insertions(+), 4 deletions(-)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 040cb84ff27a1..ef8f630d7feb1 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -107,9 +107,12 @@ struct Embedding {
const std::vector &getData() const { return Data; }
/// Arithmetic operators
- Embedding &operator+=(const Embedding &RHS);
- Embedding &operator-=(const Embedding &RHS);
- Embedding &operator*=(double Factor);
+ LLVM_ABI Embedding &operator+=(const Embedding &RHS);
+ LLVM_ABI Embedding operator+(const Embedding &RHS) const;
+ LLVM_ABI Embedding &operator-=(const Embedding &RHS);
+ LLVM_ABI Embedding operator-(const Embedding &RHS) const;
+ LLVM_ABI Embedding &operator*=(double Factor);
+ LLVM_ABI Embedding operator*(double Factor) const;
/// Adds Src Embedding scaled by Factor with the called Embedding.
/// Called_Embedding += Src * Factor
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 895b3de58a54e..bf456102bb618 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -70,7 +70,6 @@ inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
//
==--===//
// Embedding
//===--===//
-
Embedding &Embedding::operator+=(const Embedding &RHS) {
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
@@ -78,6 +77,12 @@ Embedding &Embedding::operator+=(const Embedding &RHS) {
return *this;
}
+Embedding Embedding::operator+(const Embedding &RHS) const {
+ Embedding Result(*this);
+ Result += RHS;
+ return Result;
+}
+
Embedding &Embedding::operator-=(const Embedding &RHS) {
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
@@ -85,12 +90,24 @@ Embedding &Embedding::operator-=(const Embedding &RHS) {
return *this;
}
+Embedding Embedding::operator-(const Embedding &RHS) const {
+ Embedding Result(*this);
+ Result -= RHS;
+ return Result;
+}
+
Embedding &Embedding::operator*=(double Factor) {
std::transform(this->begin(), this->end(), this->begin(),
[Factor](double Elem) { return Elem * Factor; });
return *this;
}
+Embedding Embedding::operator*(double Factor) const {
+ Embedding Result(*this);
+ Result *= Factor;
+ return Result;
+}
+
Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
assert(this->size() == Src.size() && "Vectors must have the same dimension");
for (size_t Itr = 0; Itr < this->size(); ++Itr)
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 3c97c20ae72d5..70d4808dc6d54 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -109,6 +109,18 @@ TEST(EmbeddingTest, ConstructorsAndAccessors) {
}
}
+TEST(EmbeddingTest, AddVectorsOutOfPlace) {
+ Embedding E1 = {1.0, 2.0, 3.0};
+ Embedding E2 = {0.5, 1.5, -1.0};
+
+ Embedding E3 = E1 + E2;
+ EXPECT_THAT(E3, ElementsAre(1.5, 3.5, 2.0));
+
+ // Check that E1 and E2 are unchanged
+ EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
+ EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
+}
+
TEST(EmbeddingTest, AddVectors) {
Embedding E1 = {1.0, 2.0, 3.0};
Embedding E2 = {0.5, 1.5, -1.0};
@@ -120,6 +132,18 @@ TEST(EmbeddingTest, AddVectors) {
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
}
+TEST(EmbeddingTest, SubtractVectorsOutOfPlace) {
+ Embedding E1 = {1.0, 2.0, 3.0};
+ Embedding E2 = {0.5, 1.5, -1.0};
+
+ Embedding E3 = E1 - E2;
+ EXPECT_THAT(E3, ElementsAre(0.5, 0.5, 4.0));
+
+ // Check that E1 and E2 are unchanged
+ EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
+ EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
+}
+
TEST(EmbeddingTest, SubtractVectors) {
Embedding E1 = {1.0, 2.0, 3.0};
Embedding E2 = {0.5, 1.5, -1.0};
@@ -137,6 +161,15 @@ TEST(EmbeddingTest, ScaleVector) {
EXPECT_THAT(E1, ElementsAre(0.5, 1.0, 1.5));
}
+TEST(EmbeddingTest, ScaleVectorOutOfPlace) {
+ Embedding E1 = {1.0, 2.0, 3.0};
+ Embedding E2 = E1 * 0.5f;
+ EXPECT_THAT(E2, ElementsAre(0.5, 1.0, 1.5));
+
+ // Check that E1 is unchanged
+ EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
+}
+
TEST(EmbeddingTest, AddS
[llvm-branch-commits] [llvm] [IR2Vec] Add out-of-place arithmetic operators to Embedding class (PR #145118)
https://github.com/mtrofin approved this pull request. https://github.com/llvm/llvm-project/pull/145118 ___ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [IR2Vec] Add out-of-place arithmetic operators to Embedding class (PR #145118)
https://github.com/mtrofin edited https://github.com/llvm/llvm-project/pull/145118 ___ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [IR2Vec] Add out-of-place arithmetic operators to Embedding class (PR #145118)
https://github.com/svkeerthy edited https://github.com/llvm/llvm-project/pull/145118 ___ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [IR2Vec] Add out-of-place arithmetic operators to Embedding class (PR #145118)
https://github.com/svkeerthy updated
https://github.com/llvm/llvm-project/pull/145118
>From 10019cae162bb53e147797b655da75aac33b0a20 Mon Sep 17 00:00:00 2001
From: svkeerthy
Date: Fri, 20 Jun 2025 23:00:40 +
Subject: [PATCH] Overloading operator+ for Embeddings
---
llvm/include/llvm/Analysis/IR2Vec.h| 9 --
llvm/lib/Analysis/IR2Vec.cpp | 19 -
llvm/unittests/Analysis/IR2VecTest.cpp | 39 ++
3 files changed, 63 insertions(+), 4 deletions(-)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 040cb84ff27a1..ef8f630d7feb1 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -107,9 +107,12 @@ struct Embedding {
const std::vector &getData() const { return Data; }
/// Arithmetic operators
- Embedding &operator+=(const Embedding &RHS);
- Embedding &operator-=(const Embedding &RHS);
- Embedding &operator*=(double Factor);
+ LLVM_ABI Embedding &operator+=(const Embedding &RHS);
+ LLVM_ABI Embedding operator+(const Embedding &RHS) const;
+ LLVM_ABI Embedding &operator-=(const Embedding &RHS);
+ LLVM_ABI Embedding operator-(const Embedding &RHS) const;
+ LLVM_ABI Embedding &operator*=(double Factor);
+ LLVM_ABI Embedding operator*(double Factor) const;
/// Adds Src Embedding scaled by Factor with the called Embedding.
/// Called_Embedding += Src * Factor
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 895b3de58a54e..bf456102bb618 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -70,7 +70,6 @@ inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
//
==--===//
// Embedding
//===--===//
-
Embedding &Embedding::operator+=(const Embedding &RHS) {
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
@@ -78,6 +77,12 @@ Embedding &Embedding::operator+=(const Embedding &RHS) {
return *this;
}
+Embedding Embedding::operator+(const Embedding &RHS) const {
+ Embedding Result(*this);
+ Result += RHS;
+ return Result;
+}
+
Embedding &Embedding::operator-=(const Embedding &RHS) {
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
@@ -85,12 +90,24 @@ Embedding &Embedding::operator-=(const Embedding &RHS) {
return *this;
}
+Embedding Embedding::operator-(const Embedding &RHS) const {
+ Embedding Result(*this);
+ Result -= RHS;
+ return Result;
+}
+
Embedding &Embedding::operator*=(double Factor) {
std::transform(this->begin(), this->end(), this->begin(),
[Factor](double Elem) { return Elem * Factor; });
return *this;
}
+Embedding Embedding::operator*(double Factor) const {
+ Embedding Result(*this);
+ Result *= Factor;
+ return Result;
+}
+
Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
assert(this->size() == Src.size() && "Vectors must have the same dimension");
for (size_t Itr = 0; Itr < this->size(); ++Itr)
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 3c97c20ae72d5..70d4808dc6d54 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -109,6 +109,18 @@ TEST(EmbeddingTest, ConstructorsAndAccessors) {
}
}
+TEST(EmbeddingTest, AddVectorsOutOfPlace) {
+ Embedding E1 = {1.0, 2.0, 3.0};
+ Embedding E2 = {0.5, 1.5, -1.0};
+
+ Embedding E3 = E1 + E2;
+ EXPECT_THAT(E3, ElementsAre(1.5, 3.5, 2.0));
+
+ // Check that E1 and E2 are unchanged
+ EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
+ EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
+}
+
TEST(EmbeddingTest, AddVectors) {
Embedding E1 = {1.0, 2.0, 3.0};
Embedding E2 = {0.5, 1.5, -1.0};
@@ -120,6 +132,18 @@ TEST(EmbeddingTest, AddVectors) {
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
}
+TEST(EmbeddingTest, SubtractVectorsOutOfPlace) {
+ Embedding E1 = {1.0, 2.0, 3.0};
+ Embedding E2 = {0.5, 1.5, -1.0};
+
+ Embedding E3 = E1 - E2;
+ EXPECT_THAT(E3, ElementsAre(0.5, 0.5, 4.0));
+
+ // Check that E1 and E2 are unchanged
+ EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
+ EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
+}
+
TEST(EmbeddingTest, SubtractVectors) {
Embedding E1 = {1.0, 2.0, 3.0};
Embedding E2 = {0.5, 1.5, -1.0};
@@ -137,6 +161,15 @@ TEST(EmbeddingTest, ScaleVector) {
EXPECT_THAT(E1, ElementsAre(0.5, 1.0, 1.5));
}
+TEST(EmbeddingTest, ScaleVectorOutOfPlace) {
+ Embedding E1 = {1.0, 2.0, 3.0};
+ Embedding E2 = E1 * 0.5f;
+ EXPECT_THAT(E2, ElementsAre(0.5, 1.0, 1.5));
+
+ // Check that E1 is unchanged
+ EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
+}
+
TEST(EmbeddingTest, AddS
[llvm-branch-commits] [llvm] [IR2Vec] Add out-of-place arithmetic operators to Embedding class (PR #145118)
https://github.com/mtrofin edited https://github.com/llvm/llvm-project/pull/145118 ___ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [IR2Vec] Add out-of-place arithmetic operators to Embedding class (PR #145118)
@@ -71,20 +71,43 @@ inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
// Embedding
//===--===//
+Embedding Embedding::operator+(const Embedding &RHS) const {
+ assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+ Embedding Result(*this);
+ std::transform(this->begin(), this->end(), RHS.begin(), Result.begin(),
+ std::plus());
+ return Result;
+}
+
Embedding &Embedding::operator+=(const Embedding &RHS) {
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
mtrofin wrote:
can you implement the `+=` variants in terms of the `+`?
https://github.com/llvm/llvm-project/pull/145118
___
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [IR2Vec] Add out-of-place arithmetic operators to Embedding class (PR #145118)
https://github.com/svkeerthy edited https://github.com/llvm/llvm-project/pull/145118 ___ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
