This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new e97f41085d [SYSTEMDS-3851] Builtin for embedding Layer
e97f41085d is described below
commit e97f41085dbb2be3e1ffb8fd259e3935feabf3ac
Author: MaximilianSchreff <[email protected]>
AuthorDate: Sun Apr 6 22:22:57 2025 +0200
[SYSTEMDS-3851] Builtin for embedding Layer
This patch adds the embedding layer as a built-in operator in our nn/layers
library.
The functionality is similar to torch.nn.Embedding
(https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html).
The layer receives a column vector of indices into an embedding
dictionary and returns an embedding matrix in which row i is the
embedding vector at position indices[i] of the dictionary.
This layer is used in every transformer architecture: the indices
usually come from a tokenizer, and the resulting embedding matrix is
the input to the actual transformer model.
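For illustration, a minimal usage sketch of the new layer (the sizes,
index values, and seed below are made up for the example):

    source("nn/layers/embedding.dml") as embedding

    # hypothetical setup: 4 tokens, dictionary of 10 vectors, dimension 8
    indices = matrix("1 3 10 3", rows=4, cols=1)
    embedding_dict = embedding::init(10, 8, 42)

    # forward: 4 x 8 matrix, row i equals embedding_dict[indices[i]]
    embeddings = embedding::forward(indices, embedding_dict)

    # backward: gradient wrt the dictionary; -1 means no padding index
    dout = matrix(1, rows=4, cols=8)
    dembedding_dict = embedding::backward(dout, indices, 10, -1)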
Closes #2237
---
scripts/nn/layers/embedding.dml | 99 +++++++++++++++
.../test/applications/nn/NNComponentTest.java | 5 +
.../applications/nn/component/embedding.dml | 138 +++++++++++++++++++++
3 files changed, 242 insertions(+)
diff --git a/scripts/nn/layers/embedding.dml b/scripts/nn/layers/embedding.dml
new file mode 100644
index 0000000000..22afb39a74
--- /dev/null
+++ b/scripts/nn/layers/embedding.dml
@@ -0,0 +1,99 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+forward = function(matrix[double] indices, matrix[double] embedding_dict)
+ return (matrix[double] embeddings) {
+ /*
+ * Forward pass of an embedding layer. An embedding matrix is constructed
+ * from indices and corresponding embedding vectors from the embedding
+ * dictionary.
+ *
+ * Inputs:
+ * - indices: Indices referring to embedding vectors of embedding dictionary
+ * of shape n x 1 with each value in {1, ..., v}.
+ * - embedding_dict: Dictionary of embedding vectors of shape v x d.
+ *
+ * Outputs:
+ * - embeddings: Embedding matrix of shape n x d where row i is equal to
+ * embedding_dict[indices[i]].
+ */
+ n = nrow(indices)
+ v = nrow(embedding_dict)
+
+ # Construct permutation-like matrix (one '1' per row, rest '0')
+ permutation = matrix(0, rows=n, cols=v)
+ for (i in 1:n) {
+ permutation[i, as.integer(as.scalar(indices[i]))] = 1
+ }
+
+ embeddings = permutation %*% embedding_dict
+}
+
+backward = function(matrix[double] dout, matrix[double] indices, int v,
+ int padding_idx = -1)
+ return (matrix[double] dembedding_dict) {
+ /*
+ * Backward pass of embedding layer computes the gradients of the embedding
+ * dictionary.
+ *
+ * Inputs:
+ * - dout: Gradient wrt the output, of shape n x d.
+ * - indices: Indices referring to embedding vectors of embedding dictionary
+ * of shape n x 1 with each value in {1, ..., v}.
+ * - v: Embedding dictionary size.
+ * - padding_idx: Index of embedding vector of embedding dictionary which
+ * should not be updated (i.e. gradients are 0). Use -1 if
+ * there is no padding vector.
+ *
+ * Outputs:
+ * - dembedding_dict: Gradients of the dictionary of embedding vectors of
+ * shape v x d.
+ */
+ n = nrow(indices)
+
+ # Construct permutation-like matrix (one '1' per row, rest '0')
+ permutation = matrix(0, rows=n, cols=v)
+ for (i in 1:n) {
+ permutation[i, as.integer(as.scalar(indices[i]))] = 1
+ }
+
+ dembedding_dict = t(permutation) %*% dout
+ if (padding_idx != -1) {
+ dembedding_dict[padding_idx] = matrix(0, rows=1, cols=ncol(dout))
+ }
+}
+
+init = function(int v, int d, int seed = -1)
+ return (matrix[double] embedding_dict) {
+ /*
+ * Initializes embedding dictionary matrix via N(0, 1).
+ *
+ * Inputs:
+ * - v: Embedding dictionary size.
+ * - d: Embedding vector dimension.
+ * - seed: Random generation seed.
+ *
+ * Output:
+ * - embedding_dict: Embedding dictionary matrix of shape v x d.
+ */
+ embedding_dict = rand(rows=v, cols=d, pdf="normal", seed=seed)
+}
+
diff --git a/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java b/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
index a9922cf35f..55d322a9b3 100644
--- a/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
@@ -129,6 +129,11 @@ public class NNComponentTest extends TestFolder {
run("gelu.dml");
}

+ @Test
+ public void embedding() {
+ run("embedding.dml");
+ }
+
@Override
protected void run(String name) {
super.run("component/" + name);
diff --git a/src/test/scripts/applications/nn/component/embedding.dml b/src/test/scripts/applications/nn/component/embedding.dml
new file mode 100644
index 0000000000..e5b7f82be5
--- /dev/null
+++ b/src/test/scripts/applications/nn/component/embedding.dml
@@ -0,0 +1,138 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("nn/layers/embedding.dml") as embedding
+source("src/test/scripts/applications/nn/util.dml") as test_util
+
+embedding_test_forward = function() {
+ print("Testing Embedding - Forward Test")
+ n = 4
+ v = 7
+ d = 3
+
+ embedding_dict = matrix("-0.78327566 -0.87246466 -0.80580276
+ -0.17845497 2.1740944 -1.2514428
+ -0.27202556 -1.3681601 -1.5384313
+ 1.4215976 -0.463162 1.2592019
+ -1.7417 -0.46109396 -0.06011621
+ -0.7803316 1.0802858 0.7465289
+ 0. 0. 0.", rows=v, cols=d)
+ indices = matrix("1 6 7 6", rows=n, cols=1)
+
+ embeddings = embedding::forward(indices, embedding_dict)
+
+ expected_embeddings = matrix("-0.78327566 -0.87246466 -0.80580276
+ -0.7803316 1.0802858 0.7465289
+ 0. 0. 0.
+ -0.7803316 1.0802858 0.7465289", rows=n, cols=d)
+
+ test_util::check_all_close(embeddings, expected_embeddings, 1e-05)
+}
+
+embedding_test_forward_backward_no_pad = function() {
+ print("Testing Embedding - Forward & Backward Test w/out Padding")
+ n = 2
+ v = 4
+ d = 3
+
+ embedding_dict = matrix("-0.15039968 0.56168836 -0.577436
+ 0.47334725 1.5215642 -0.1924941
+ 1.600819 -1.1331359 -2.58817
+ 0.9779929 -0.82212716 -1.5917081", rows=v, cols=d)
+ indices = matrix("2 3", rows=n, cols=1)
+
+ embeddings = embedding::forward(indices, embedding_dict)
+
+ expected_embeddings = matrix("0.47334725 1.5215642 -0.1924941
+ 1.600819 -1.1331359 -2.58817", rows=n, cols=d)
+
+ test_util::check_all_close(embeddings, expected_embeddings, 1e-05)
+
+ dout = matrix(seq(1, n*d, 1), rows=n, cols=d)
+ padding_idx = -1
+
+ dembedding_dict = embedding::backward(dout, indices, v, padding_idx)
+ expected_dembedding_dict = matrix("0. 0. 0.
+ 1. 2. 3.
+ 4. 5. 6.
+ 0. 0. 0.", rows=v, cols=d)
+ test_util::check_all_close(dembedding_dict, expected_dembedding_dict, 1e-05)
+}
+
+embedding_test_forward_backward_pad = function() {
+ print("Testing Embedding - Forward & Backward Test w/ Padding")
+ n = 5
+ v = 10
+ d = 6
+
+ embedding_dict = matrix("-1.24377859e+00 -1.10724878e+00 2.35533118e-01
6.65530920e-01
+ 9.80555452e-03 6.31030917e-01
+ 8.16493928e-01 -6.21011078e-01 -5.75569510e-01 -3.93419750e-02
+ -6.20878041e-01 1.37852756e-02
+ 7.43950903e-01 1.60437262e+00 -2.31788456e-01 1.15943216e-01
+ -8.83608997e-01 1.11547875e+00
+ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
+ 0.00000000e+00 0.00000000e+00
+ 1.70598769e+00 1.82770026e+00 1.30581510e+00 1.05738208e-01
+ 4.50116873e-01 3.48498315e-01
+ 1.40551448e+00 3.43091488e-02 1.84714049e-03 -5.52828193e-01
+ 3.65064174e-01 -9.31223869e-01
+ 1.33713937e+00 -3.43729639e+00 -1.22915792e+00 -1.12923630e-01
+ -1.16292477e+00 -2.16708351e-02
+ 6.63879395e-01 -2.76697308e-01 -9.02738094e-01 -6.85515344e-01
+ -6.43863618e-01 -2.30419707e+00
+ 1.44121364e-01 5.20578504e-01 -6.53087497e-01 6.62900746e-01
+ 3.82369667e-01 -2.25386508e-02
+ 2.20637798e+00 -6.86733365e-01 -1.27398467e+00 6.28316283e-01
+ 2.70236313e-01 2.20882833e-01", rows=v, cols=d)
+ indices = matrix("1 1 1 4 6", rows=n, cols=1)
+
+ embeddings = embedding::forward(indices, embedding_dict)
+
+ expected_embeddings = matrix("-1.2437786 -1.1072488 0.23553312 0.6655309 0.00980555 0.6310309
+ -1.2437786 -1.1072488 0.23553312 0.6655309 0.00980555 0.6310309
+ -1.2437786 -1.1072488 0.23553312 0.6655309 0.00980555 0.6310309
+ 0. 0. 0. 0. 0. 0.
+ 1.4055145 0.03430915 0.00184714 -0.5528282 0.36506417 -0.93122387", rows=n, cols=d)
+
+ test_util::check_all_close(embeddings, expected_embeddings, 1e-05)
+
+ dout = matrix(seq(1, n*d, 1), rows=n, cols=d)
+ padding_idx = 4
+
+ dembedding_dict = embedding::backward(dout, indices, v, padding_idx)
+ expected_dembedding_dict = matrix("21. 24. 27. 30. 33. 36.
+ 0. 0. 0. 0. 0. 0.
+ 0. 0. 0. 0. 0. 0.
+ 0. 0. 0. 0. 0. 0.
+ 0. 0. 0. 0. 0. 0.
+ 25. 26. 27. 28. 29. 30.
+ 0. 0. 0. 0. 0. 0.
+ 0. 0. 0. 0. 0. 0.
+ 0. 0. 0. 0. 0. 0.
+ 0. 0. 0. 0. 0. 0.", rows=v, cols=d)
+ test_util::check_all_close(dembedding_dict, expected_dembedding_dict, 1e-05)
+}
+
+embedding_test_forward()
+embedding_test_forward_backward_no_pad()
+embedding_test_forward_backward_pad()
+