This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new e97f41085d [SYSTEMDS-3851] Builtin for embedding layer
e97f41085d is described below

commit e97f41085dbb2be3e1ffb8fd259e3935feabf3ac
Author: MaximilianSchreff <[email protected]>
AuthorDate: Sun Apr 6 22:22:57 2025 +0200

    [SYSTEMDS-3851] Builtin for embedding layer
    
    This patch adds the embedding layer as a built-in operator in our
    nn/layers library.
    The functionality is similar to torch.nn.Embedding
    (https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html).
    The layer receives a column vector of indices that refer to rows of an
    embedding dictionary and returns an embedding matrix whose row i is the
    embedding vector indices[i] of the dictionary.
    This layer is used in every transformer architecture: there, the indices
    usually come from a tokenizer, and the embedding matrix is the input
    to the actual transformer model.
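
    For illustration, a minimal usage sketch in DML (the variable names and
    sizes below are placeholders, not part of this patch):

        source("nn/layers/embedding.dml") as embedding
        E = embedding::init(v=100, d=16, seed=42)       # 100 vectors of dim 16
        X = embedding::forward(indices, E)              # indices: n x 1 in {1, ..., 100}
        dE = embedding::backward(dX, indices, 100, -1)  # -1: no padding vector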
    
    Closes #2237
---
 scripts/nn/layers/embedding.dml                    |  99 +++++++++++++++
 .../test/applications/nn/NNComponentTest.java      |   5 +
 .../applications/nn/component/embedding.dml        | 138 +++++++++++++++++++++
 3 files changed, 242 insertions(+)

diff --git a/scripts/nn/layers/embedding.dml b/scripts/nn/layers/embedding.dml
new file mode 100644
index 0000000000..22afb39a74
--- /dev/null
+++ b/scripts/nn/layers/embedding.dml
@@ -0,0 +1,99 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+forward = function(matrix[double] indices, matrix[double] embedding_dict)
+    return (matrix[double] embeddings) {
+  /*
+   * Forward pass of an embedding layer. An embedding matrix is constructed
+   * from the given indices and the corresponding embedding vectors of the
+   * embedding dictionary.
+   *
+   * Inputs:
+   * - indices: Indices of embedding vectors in the embedding dictionary;
+   *            matrix of shape n x 1 with each value in {1, ..., v}.
+   * - embedding_dict: Dictionary of embedding vectors of shape v x d.
+   *
+   * Outputs:
+   * - embeddings: Embedding matrix of shape n x d where row i is equal to
+   *               embedding_dict[indices[i]].
+   */
+  n = nrow(indices)
+  v = nrow(embedding_dict)
+
+  # Construct permutation-like matrix (one '1' per row, rest '0')
+  permutation = matrix(0, rows=n, cols=v)
+  for (i in 1:n) {
+    permutation[i, as.integer(as.scalar(indices[i]))] = 1
+  }
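+  # Note: assuming availability of the table() builtin, an equivalent
+  # vectorized construction would be:
+  #   permutation = table(seq(1, n, 1), indices, n, v)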
+
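+  # The product gathers rows: embeddings[i,] equals embedding_dict[indices[i],]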
+  embeddings = permutation %*% embedding_dict
+}
+
+backward = function(matrix[double] dout, matrix[double] indices, int v,
+      int padding_idx = -1)
+    return (matrix[double] dembedding_dict) {
+  /*
+   * Backward pass of the embedding layer, computing the gradient of the
+   * embedding dictionary.
+   *
+   * Inputs:
+   * - dout: Gradient of the output, of shape n x d.
+   * - indices: Indices of embedding vectors in the embedding dictionary;
+   *            matrix of shape n x 1 with each value in {1, ..., v}.
+   * - v: Embedding dictionary size.
+   * - padding_idx: Index of the embedding vector in the dictionary that
+   *                should not be updated (i.e., its gradient is 0). Use -1
+   *                if there is no padding vector.
+   *
+   * Outputs:
+   * - dembedding_dict: Gradient of the dictionary of embedding vectors,
+   *                    of shape v x d.
+   */
+  n = nrow(indices)
+
+  # Construct permutation-like matrix (one '1' per row, rest '0')
+  permutation = matrix(0, rows=n, cols=v)
+  for (i in 1:n) {
+    permutation[i, as.integer(as.scalar(indices[i]))] = 1
+  }
+
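+  # t(permutation) %*% dout sums, for each dictionary index, the dout rows
+  # of all its occurrences in indices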
+  dembedding_dict = t(permutation) %*% dout
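+  # Zero the gradient of the padding vector so that it is never updated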
+  if (padding_idx != -1) {
+    dembedding_dict[padding_idx] = matrix(0, rows=1, cols=ncol(dout))
+  }
+}
+
+init = function(int v, int d, int seed = -1)
+    return (matrix[double] embedding_dict) {
+  /*
+   * Initializes the embedding dictionary matrix with samples from N(0, 1).
+   *
+   * Inputs:
+   * - v: Embedding dictionary size.
+   * - d: Embedding vector dimension.
+   * - seed: Seed for the random number generation (use -1 for a random seed).
+   *
+   * Outputs:
+   * - embedding_dict: Embedding dictionary matrix of shape v x d.
+   */
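+  # Drawing from N(0, 1) matches the default initialization of torch.nn.Embedding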
+  embedding_dict = rand(rows=v, cols=d, pdf="normal", seed=seed)
+}
+
diff --git a/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java b/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
index a9922cf35f..55d322a9b3 100644
--- a/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
@@ -129,6 +129,11 @@ public class NNComponentTest extends TestFolder {
                run("gelu.dml");
        }
 
+       @Test
+       public void embedding() {
+               run("embedding.dml");
+       }
+
        @Override
        protected void run(String name) {
                super.run("component/" + name);
diff --git a/src/test/scripts/applications/nn/component/embedding.dml b/src/test/scripts/applications/nn/component/embedding.dml
new file mode 100644
index 0000000000..e5b7f82be5
--- /dev/null
+++ b/src/test/scripts/applications/nn/component/embedding.dml
@@ -0,0 +1,138 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("nn/layers/embedding.dml") as embedding
+source("src/test/scripts/applications/nn/util.dml") as test_util
+
+embedding_test_forward = function() {
+  print("Testing Embedding - Forward Test")
+  n = 4
+  v = 7
+  d = 3
+
+  embedding_dict = matrix("-0.78327566 -0.87246466 -0.80580276
+    -0.17845497  2.1740944  -1.2514428
+    -0.27202556 -1.3681601  -1.5384313
+    1.4215976  -0.463162    1.2592019
+    -1.7417     -0.46109396 -0.06011621
+    -0.7803316   1.0802858   0.7465289
+    0.          0.          0.", rows=v, cols=d)
+  indices = matrix("1 6 7 6", rows=n, cols=1)
+
+  embeddings = embedding::forward(indices, embedding_dict)
+  
+  expected_embeddings = matrix("-0.78327566 -0.87246466 -0.80580276
+    -0.7803316   1.0802858   0.7465289
+    0.          0.          0.
+    -0.7803316   1.0802858   0.7465289", rows=n, cols=d)
+
+  test_util::check_all_close(embeddings, expected_embeddings, 1e-05)
+}
+
+embedding_test_forward_backward_no_pad = function() {
+  print("Testing Embedding - Forward & Backward Test w/out Padding")
+  n = 2
+  v = 4
+  d = 3
+
+  embedding_dict = matrix("-0.15039968  0.56168836 -0.577436
+   0.47334725  1.5215642  -0.1924941
+   1.600819   -1.1331359  -2.58817
+   0.9779929  -0.82212716 -1.5917081", rows=v, cols=d)
+  indices = matrix("2 3", rows=n, cols=1)
+
+  embeddings = embedding::forward(indices, embedding_dict)
+  
+  expected_embeddings = matrix("0.47334725  1.5215642  -0.1924941
+   1.600819   -1.1331359  -2.58817", rows=n, cols=d)
+
+  test_util::check_all_close(embeddings, expected_embeddings, 1e-05)
+
+  dout = matrix(seq(1, n*d, 1), rows=n, cols=d)
+  padding_idx = -1
+
+  dembedding_dict = embedding::backward(dout, indices, v, padding_idx)
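+  # Index 2 receives dout row 1, index 3 receives dout row 2; rows 1 and 4
+  # of the gradient stay zero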
+  expected_dembedding_dict = matrix("0. 0. 0.
+  1. 2. 3.
+  4. 5. 6.
+  0. 0. 0.", rows=v, cols=d)
+  test_util::check_all_close(dembedding_dict, expected_dembedding_dict, 1e-05)
+}
+
+embedding_test_forward_backward_pad = function() {
+  print("Testing Embedding - Forward & Backward Test w/ Padding")
+  n = 5
+  v = 10
+  d = 6
+
+  embedding_dict = matrix("-1.24377859e+00 -1.10724878e+00  2.35533118e-01  6.65530920e-01
+   9.80555452e-03  6.31030917e-01
+   8.16493928e-01 -6.21011078e-01 -5.75569510e-01 -3.93419750e-02
+  -6.20878041e-01  1.37852756e-02
+   7.43950903e-01  1.60437262e+00 -2.31788456e-01  1.15943216e-01
+  -8.83608997e-01  1.11547875e+00
+   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
+   0.00000000e+00  0.00000000e+00
+   1.70598769e+00  1.82770026e+00  1.30581510e+00  1.05738208e-01
+   4.50116873e-01  3.48498315e-01
+   1.40551448e+00  3.43091488e-02  1.84714049e-03 -5.52828193e-01
+   3.65064174e-01 -9.31223869e-01
+   1.33713937e+00 -3.43729639e+00 -1.22915792e+00 -1.12923630e-01
+  -1.16292477e+00 -2.16708351e-02
+   6.63879395e-01 -2.76697308e-01 -9.02738094e-01 -6.85515344e-01
+  -6.43863618e-01 -2.30419707e+00
+   1.44121364e-01  5.20578504e-01 -6.53087497e-01  6.62900746e-01
+   3.82369667e-01 -2.25386508e-02
+   2.20637798e+00 -6.86733365e-01 -1.27398467e+00  6.28316283e-01
+   2.70236313e-01  2.20882833e-01", rows=v, cols=d)
+  indices = matrix("1 1 1 4 6", rows=n, cols=1)
+
+  embeddings = embedding::forward(indices, embedding_dict)
+  
+  expected_embeddings = matrix("-1.2437786  -1.1072488   0.23553312  0.6655309   0.00980555  0.6310309
+  -1.2437786  -1.1072488   0.23553312  0.6655309   0.00980555  0.6310309
+  -1.2437786  -1.1072488   0.23553312  0.6655309   0.00980555  0.6310309
+   0.          0.          0.          0.          0.          0.
+   1.4055145   0.03430915  0.00184714 -0.5528282   0.36506417 -0.93122387", rows=n, cols=d)
+
+  test_util::check_all_close(embeddings, expected_embeddings, 1e-05)
+
+  dout = matrix(seq(1, n*d, 1), rows=n, cols=d)
+  padding_idx = 4
+
+  dembedding_dict = embedding::backward(dout, indices, v, padding_idx)
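+  # Row 1 accumulates dout rows 1-3 (index 1 occurs three times), row 4 is
+  # zeroed by padding_idx, and row 6 receives dout row 5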
+  expected_dembedding_dict = matrix("21. 24. 27. 30. 33. 36.
+   0.  0.  0.  0.  0.  0.
+   0.  0.  0.  0.  0.  0.
+   0.  0.  0.  0.  0.  0.
+   0.  0.  0.  0.  0.  0.
+  25. 26. 27. 28. 29. 30.
+   0.  0.  0.  0.  0.  0.
+   0.  0.  0.  0.  0.  0.
+   0.  0.  0.  0.  0.  0.
+   0.  0.  0.  0.  0.  0.", rows=v, cols=d)
+  test_util::check_all_close(dembedding_dict, expected_dembedding_dict, 1e-05)
+}
+
+embedding_test_forward()
+embedding_test_forward_backward_no_pad()
+embedding_test_forward_backward_pad()
+
