This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 34a6571038 [SYSTEMDS-3803] DML-bodied Util Function for Transposing ABCD to ACBD
34a6571038 is described below

commit 34a6571038f88517b27d5f218a2618cb29e28c7a
Author: Maximilian.S <[email protected]>
AuthorDate: Thu Dec 5 15:29:22 2024 +0100

    [SYSTEMDS-3803] DML-bodied Util Function for Transposing ABCD to ACBD
    
    This patch adds a simple util function that swaps the 2nd and 3rd axes
    of a 4-D tensor stored as an (A, B*C*D) matrix (ABCD -> ACBD), which is
    required for the multi-head attention implementation.
    
    Closes #2151
---
 scripts/nn/util.dml                                | 25 +++++++
 .../test/applications/nn/NNComponentTest.java      |  5 ++
 .../nn/component/transpose_ABCD_to_ACBD.dml        | 82 ++++++++++++++++++++++
 3 files changed, 112 insertions(+)

diff --git a/scripts/nn/util.dml b/scripts/nn/util.dml
index 807d7baad9..a3d0f84c5c 100644
--- a/scripts/nn/util.dml
+++ b/scripts/nn/util.dml
@@ -380,3 +380,28 @@ top_k2d = function(matrix[double] X, int k, int C, int Hin, int Win)
   indices = transpose_NCHW_to_CNHW(indices_K_NHW, N)
 }
 
+transpose_ABCD_to_ACBD = function(matrix[double] X, int B, int C)
+    return (matrix[double] out) {
+  /*
+   * Reshape util for 4-D tensors stored row-major in ABCD layout.
+   * Transposes the 2nd and 3rd axes (ABCD -> ACBD).
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (A, B*C*D); ncol(X) must be a multiple of B*C.
+   *  - B: Dimension of 2nd axis.
+   *  - C: Dimension of 3rd axis (D is implied as ncol(X) / (B*C)).
+   *
+   * Outputs:
+   *  - out: Outputs with the 2nd and 3rd axes transposed, of
+   *      shape (A, C*B*D), i.e. the same matrix shape as X.
+   */
+  A = nrow(X)
+  BCD = ncol(X)  # column count is unchanged by the axis swap
+
+  # swap axes A and B: X (A, B*C*D) -> X_BACD (B, A*C*D)
+  X_BACD = transpose_NCHW_to_CNHW(X, B)
+  # swap B with the merged (A,C) axis: X_BACD (B, A*C*D) -> (A*C, B*D)
+  X_ACBD = transpose_NCHW_to_CNHW(X_BACD, A*C)
+  # row-major reshape (A*C, B*D) -> (A, C*B*D) preserves the ACBD order
+  out = matrix(X_ACBD, rows=A, cols=BCD)
+}
diff --git a/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java b/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
index 86b2f64bb7..3b002871d7 100644
--- a/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
@@ -108,6 +108,11 @@ public class NNComponentTest extends TestFolder {
                run("transpose_NCHW_to_CNHW.dml");
        }
 
+       @Test
+       public void transpose_ABCD_to_ACBD() {
+               run("transpose_ABCD_to_ACBD.dml"); // DML component test for util::transpose_ABCD_to_ACBD
+       }
+
        @Test 
        public void logcosh(){
                run("logcosh.dml");
diff --git a/src/test/scripts/applications/nn/component/transpose_ABCD_to_ACBD.dml b/src/test/scripts/applications/nn/component/transpose_ABCD_to_ACBD.dml
new file mode 100644
index 0000000000..1fa49c9b87
--- /dev/null
+++ b/src/test/scripts/applications/nn/component/transpose_ABCD_to_ACBD.dml
@@ -0,0 +1,82 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/applications/nn/util.dml") as test_util
+source("scripts/nn/util.dml") as util
+
+
+transpose_ABCD_to_ACBD = function() {
+  /*
+   * Test for `transpose_ABCD_to_ACBD` with distinct sizes A=2, B=3, C=4, D=5.
+   */
+  print("Testing transpose_ABCD_to_ACBD function.")
+
+  # Generate data: distinct values 1..A*B*C*D filled row-major in ABCD order
+  A = 2
+  B = 3
+  C = 4
+  D = 5
+  X = matrix(seq(1, A*B*C*D), rows=A, cols=B*C*D)
+
+  out = util::transpose_ABCD_to_ACBD(X, B, C)
+
+  target =
+    matrix("1 2 3 4 5 21 22 23 24 25 41 42 43 44 45
+            6 7 8 9 10 26 27 28 29 30 46 47 48 49 50
+            11 12 13 14 15 31 32 33 34 35 51 52 53 54 55
+            16 17 18 19 20 36 37 38 39 40 56 57 58 59 60
+
+            61 62 63 64 65 81 82 83 84 85 101 102 103 104 105
+            66 67 68 69 70 86 87 88 89 90 106 107 108 109 110
+            71 72 73 74 75 91 92 93 94 95 111 112 113 114 115
+            76 77 78 79 80 96 97 98 99 100 116 117 118 119 120",
+           rows=A, cols=C*B*D)
+
+  # Equivalence check against the hand-computed ACBD layout
+  test_util::check_all_close(out, target, 1e-10)
+}
+
+
+transpose_ABCD_to_ACBD_single_val = function() {
+  /*
+   * Test for `transpose_ABCD_to_ACBD` function,
+   * on a 1x1 matrix (A = B = C = D = 1).
+   */
+  print("Testing transpose_ABCD_to_ACBD function with single value.")
+
+  # Generate data: a single entry
+  A = 1
+  B = 1
+  C = 1
+  D = 1
+  X = matrix(seq(1, A*B*C*D), rows=A, cols=B*C*D)
+
+  out = util::transpose_ABCD_to_ACBD(X, B, C)
+
+  target = X  # with every dimension equal to 1 the transpose is the identity
+
+  # Equivalence check
+  test_util::check_all_close(out, target, 1e-10)
+}
+
+
+transpose_ABCD_to_ACBD()  # general case: A=2, B=3, C=4, D=5
+transpose_ABCD_to_ACBD_single_val()  # degenerate case: 1x1 input

Reply via email to