This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 1f613276d6a0a38661319bfeb0f177fcdeb2883a Author: e-strauss <[email protected]> AuthorDate: Sat Jul 27 19:28:05 2024 +0200 [SYSTEMDS-3538] BI-LSTM forward/backward passes with test cases Closes #2051. --- scripts/nn/layers/bilstm.dml | 233 +++++++++++++++++++++ .../apache/sysds/test/functions/dnn/LSTMTest.java | 115 ++++++++-- .../expected/BILSTM_OUT_10.0_5.0_2.0_6.0.csv | 10 + .../expected/BILSTM_OUT_2.0_5.0_2.0_2.0.csv | 2 + .../expected/BILSTM_OUT_3.0_5.0_2.0_2.0.csv | 3 + .../expected/BILSTM_OUT_6.0_5.0_6.0_4.0.csv | 6 + .../expected/BILSTM_back_dW_1.0_5.0_6.0_4.0.csv | 22 ++ .../expected/BILSTM_back_dW_10.0_5.0_2.0_6.0.csv | 18 ++ .../expected/BILSTM_back_dW_4.0_5.0_6.0_4.0.csv | 22 ++ .../expected/BILSTM_back_dW_5.0_5.0_6.0_4.0.csv | 22 ++ .../expected/BILSTM_back_dW_6.0_5.0_6.0_4.0.csv | 22 ++ .../expected/BILSTM_back_dX_1.0_5.0_6.0_4.0.csv | 1 + .../expected/BILSTM_back_dX_10.0_5.0_2.0_6.0.csv | 10 + .../expected/BILSTM_back_dX_4.0_5.0_6.0_4.0.csv | 4 + .../expected/BILSTM_back_dX_5.0_5.0_6.0_4.0.csv | 5 + .../expected/BILSTM_back_dX_6.0_5.0_6.0_4.0.csv | 6 + .../expected/BILSTM_back_dc_1.0_5.0_6.0_4.0.csv | 4 + .../expected/BILSTM_back_dc_10.0_5.0_2.0_6.0.csv | 40 ++++ .../expected/BILSTM_back_dc_4.0_5.0_6.0_4.0.csv | 16 ++ .../expected/BILSTM_back_dc_5.0_5.0_6.0_4.0.csv | 20 ++ .../expected/BILSTM_back_dc_6.0_5.0_6.0_4.0.csv | 24 +++ .../functions/tensor/BILSTMBackwardTest.dml | 92 ++++++++ .../scripts/functions/tensor/BILSTMForwardTest.dml | 60 ++++++ 23 files changed, 742 insertions(+), 15 deletions(-) diff --git a/scripts/nn/layers/bilstm.dml b/scripts/nn/layers/bilstm.dml new file mode 100644 index 0000000000..38c427bf09 --- /dev/null +++ b/scripts/nn/layers/bilstm.dml @@ -0,0 +1,233 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# uncomment to use the lstm dml script instead of the built-in operator +# source("nn/layers/lstm.dml") as lstm + +/* + * Bidirectional-LSTM layer. + */ + +forward = function(matrix[double] X, matrix[double] W, matrix[double] W_reverse, matrix[double] b, + matrix[double] b_reverse, int T, int D, boolean seq, matrix[double] out0, matrix[double] c0) + return (matrix[double] out, matrix[double] c, + matrix[double] cache_out, matrix[double] cache_c, matrix[double] cache_ifog) { + /* + * Computes the forward pass for a BI-LSTM layer. + * The input data has N sequences of T examples, each with D features. + * + * The BI-LSTM Layer processes the input tokens once in the normal order and once in the reverse order. + * The Layer uses different LSTM Cells for both passes. Therefore it contains 2*M neurons. The output of + * both passes is concatenated for each time step, which results in output of twice the size of a standard + * LSTM Layer. + * The API is similar to PyTorch BI-LSTM, with the only difference that the input bias and hidden bias is combined + * to a single bias. The weights and biases for both directions are given separately (similar to PyTorch's API). 
+ * It is possible to return only the output of the last cell (similar to the normal LSTM layer). In that case the + * final output of both directions is concatenated. + * + * Reference: + * - Framewise phoneme classification with bidirectional LSTM networks; A. Graves, J. Schmidhuber; 2005 + * - https://ieeexplore.ieee.org/document/1556215 + * - PyTorch's LSTM / BILSTM + * - https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html + * + * Inputs: + * - X: Inputs, of shape (N, T*D). + * - W: Weights, of shape (D+M, 4M). + * - W_reverse: Weights, of shape (D+M, 4M). + * - b: Biases, of shape (1, 4M). + * - b_reverse: Biases, of shape (1, 4M). + * - T: Length of example sequences (number of timesteps). + * - D: Dimensionality of the input features (number of features). + * - seq: Whether to return `out` at all timesteps, + * or just for the final timestep. + * - out0: Outputs from previous timestep for both passes, of shape (N*2, M). + * Note: This is *optional* and could just be an empty matrix. + * - c0: Initial cell state for both passes, of shape (N*2, M). + * Note: This is *optional* and could just be an empty matrix. + * + * Outputs: + * - out: If `seq` is True, outputs for all timesteps, + * of shape (N, T*(M*2)). Else, outputs for the final timestep, of + * shape (N, M*2). + * - c: Cell state for final timestep, of shape (N*2, M). + * - cache_out: Cache of outputs, of shape (T*2, N*M). + * Note: This is used for performance during training. + * - cache_c: Cache of cell state, of shape (T*2, N*M). + * Note: This is used for performance during training. + * - cache_ifog: Cache of intermediate values, of shape (T*2, N*4M). + * Note: This is used for performance during training. 
+ */ + + M = ncol(out0) + N = nrow(X) + + out0_forward = out0[1 : N,] + out0_reverse = out0[N + 1 : 2*N,] + c0_forward = c0[1 : N,] + c0_reverse = c0[N + 1 : 2*N,] + + # normal lstm pass + [out, c, cache_out, cache_c, cache_ifog] = lstm(X, W, b, out0_forward, c0_forward, seq) + #[out, c, cache_out, cache_c, cache_ifog] = lstm::forward(X, W, b, T,D, seq, out0_forward, c0_forward) + + # approach 1: reorder token by reversing X and weights of X + X_reverse = t(rev(t(X))) # reverse the elements inside a row + W_reverse[1:D,] = rev(W_reverse[1:D,]) # have to reverse the input weights as well + + # approach 2 (slower): reorder tokens by slicing + # X_reverse = matrix(0, rows=nrow(X), cols=ncol(X)) + # for (i in 1:T){X_reverse[,(T - i)*D+1:(T - i + 1)*D] = X[,(i-1)*D+1:i*D]} + + [out_reverse, c_reverse, cache_out_reverse, cache_c_reverse, cache_ifog_reverse] = lstm(X_reverse, W_reverse, b_reverse, out0_reverse, c0_reverse, seq) + #[out_reverse, c_reverse, cache_out_reverse, cache_c_reverse, cache_ifog_reverse] = lstm::forward(X_reverse, W_reverse, b_reverse, T,D, seq, out0_reverse, c0_reverse) + + # reorder the output of reverse lstm cell + if(seq){ + out_reversed = matrix(0, nrow(out_reverse), ncol(out_reverse)) + for (i in 1:T){out_reversed[,(T - i)*M+1:(T - i + 1)*M] = out_reverse[,(i-1)*M+1:i*M]} + } else { + out_reversed = out_reverse + } + + cache_out = rbind(cache_out, cache_out_reverse) + cache_c = rbind(cache_c, cache_c_reverse) + cache_ifog = rbind(cache_ifog, cache_ifog_reverse) + + c = rbind(c, c_reverse) + + if(seq == FALSE){T=1} + out = matrix(out, rows = N*T, cols=M) + out_reversed = matrix(out_reversed, rows = N*T, cols=M) + out = cbind(out, out_reversed) + out = matrix(out, rows = N, cols = T*2*M) +} + +backward = function(matrix[double] dout, matrix[double] dc, + matrix[double] X, matrix[double] W, matrix[double] W_reverse, matrix[double] b, + matrix[double] b_reverse, int T, int D, boolean seq, matrix[double] out0, matrix[double] c0, + matrix[double] 
cache_out, matrix[double] cache_c, matrix[double] cache_ifog) + return (matrix[double] dX, matrix[double] dW, matrix[double] db, matrix[double] dW_reverse, + matrix[double] db_reverse, matrix[double] dout0, matrix[double] dc0) { + /* + * Computes the backward pass for a BI-LSTM layer with M neurons. + * + * Inputs: + * - dout: Gradient wrt `out`. If `seq` is `True`, + * contains gradients on outputs for all timesteps, of + * shape (N, T*(2*M)). Else, contains the gradient on the output + * for the final timestep, of shape (N, 2*M). + * - dc: Gradient wrt `c` (from later in time), of shape (N*2, M). + * This would come from later in time if the cell state was used + * downstream as the initial cell state for another LSTM layer. + * Typically, this would be used when a sequence was cut at + * timestep `T` and then continued in the next batch. If `c` + * was not used downstream, then `dc` would be an empty matrix. + * - X: Inputs, of shape (N, T*D). + * - W: Weights, of shape (D+M, 4M). + * - W_reverse: Weights, of shape (D+M, 4M). + * - b: Biases, of shape (1, 4M). + * - b_reverse: Biases, of shape (1, 4M). + * - T: Length of example sequences (number of timesteps). + * - D: Dimensionality of the input features. + * - seq: Whether `dout` is for all timesteps, + * or just for the final timestep. This is based on whether + * `seq` was true in the forward pass. + * - out0: Outputs from previous timestep, of shape (N*2, M). + * Note: This is *optional* and could just be an empty matrix. + * - c0: Initial cell state, of shape (N*2, M). + * Note: This is *optional* and could just be an empty matrix. + * - cache_out: Cache of outputs, of shape (T*2, N*M). + * Note: This is used for performance during training. + * - cache_c: Cache of cell state, of shape (T*2, N*M). + * Note: This is used for performance during training. + * - cache_ifog: Cache of intermediate values, of shape (T*2, N*4*M). + * Note: This is used for performance during training. 
+ * + * Outputs: + * - dX: Gradient wrt `X`, of shape (N, T*D). + * - dW: Gradient wrt `W`, of shape (D+M, 4M). + * - dW_reverse: Gradient wrt `W_reverse`, of shape (D+M, 4M). + * - db: Gradient wrt `b`, of shape (1, 4M). + * - db_reverse: Gradient wrt `b_reverse`, of shape (1, 4M). + * - dout0: Gradient wrt `out0`, of shape (N*2, M). + * - dc0: Gradient wrt `c0`, of shape (N*2, M). + */ + M = ncol(out0) + N = nrow(X) + + if(seq){dout = matrix(dout, rows=N*T, cols=2*M)} + dout_forward = dout[,1:M] + dout_reverse = dout[,M + 1 : 2*M] + if(seq){ + dout_forward = matrix(dout_forward, rows=N, cols=T*M) + dout_reverse = matrix(dout_reverse, rows=N, cols=T*M) + dout_reversed = matrix(0, nrow(dout_reverse), ncol(dout_reverse)) + for (i in 1:T){ + dout_reversed[,(T - i)*M+1:(T - i + 1)*M] = dout_reverse[,(i-1)*M+1:i*M] + } + } + else{ + dout_reversed = dout_reverse + } + dc_forward = dc[1:N,] + cache_out_forward = cache_out[1:T,] + cache_c_forward = cache_c[1:T,] + cache_ifog_forward = cache_ifog[1:T,] + out0_forward = out0[1 : nrow(out0) /2, ] + c0_forward = c0[1 : N, ] + + [dx, dW, db, dout0, dc0] = lstm_backward(X, W, b, out0_forward, c0_forward, seq, dout_forward, dc_forward, cache_out_forward, cache_c_forward, cache_ifog_forward) + #[dx, dW, db, dout0, dc0] = lstm::backward(dout_forward, dc_forward, X, W,b,T,D,seq,out0_forward, c0_forward,cache_out_forward, cache_c_forward, cache_ifog_forward) + + # for approach1: + X_reverse = t(rev(t(X))) # reverse the elements inside a row + W_reverse[1:D,] = rev(W_reverse[1:D,]) # have to reverse the input weights as well + + # for approach2 (also slower for backward pass): + #X_reverse = X # matrix(0, rows=nrow(X), cols=ncol(X)) + #for (i in 1:T){X_reverse[,(T - i)*D+1:(T - i + 1)*D] = X[,(i-1)*D+1:i*D]} + + dc_reverse = dc[N+1 : 2*N,] + cache_out_reverse = cache_out[T+1 : 2*T,] + cache_c_reverse = cache_c[T+1 : 2*T,] + cache_ifog_reverse = cache_ifog[T+1 : 2*T,] + c0_reverse = c0[N + 1 : 2*N,] + out0_reverse = out0[N + 1 : 2*N,] 
+ + [dx_reverse, dW_reverse, db_reverse, dout0_reverse, dc0_reverse] = lstm_backward(X_reverse, W_reverse, b_reverse, out0_reverse, c0_reverse, seq, dout_reversed, dc_reverse, cache_out_reverse, cache_c_reverse, cache_ifog_reverse) + #[dx_reverse, dW_reverse, db_reverse, dout0_reverse, dc0_reverse] = lstm::backward(dout_reversed, dc_reverse, X_reverse, W_reverse,b_reverse,T,D,seq,out0_reverse, c0_reverse,cache_out_reverse, cache_c_reverse, cache_ifog_reverse) + + + + # for approach1: + dx_reverse = t(rev(t(dx_reverse))) + dW_reverse[1:D,] = rev(dW_reverse[1:D,]) + + # for approach2: + #dx_reverse2 = matrix(0, rows=nrow(dx_reverse), cols=ncol(dx_reverse)) + #for (i in 1:T){dx_reverse2[,(T - i)*D+1:(T - i + 1)*D] = dx_reverse[,(i-1)*D+1:i*D]} + #dx_reverse = dx_reverse2 + + dX = dx + dx_reverse + dout0 = rbind(dout0, dout0_reverse) + dc0 = rbind(dc0, dc0_reverse) #matrix(1, rows=1, cols=1) +} \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/functions/dnn/LSTMTest.java b/src/test/java/org/apache/sysds/test/functions/dnn/LSTMTest.java index bf19abd8db..481bdb1cc2 100644 --- a/src/test/java/org/apache/sysds/test/functions/dnn/LSTMTest.java +++ b/src/test/java/org/apache/sysds/test/functions/dnn/LSTMTest.java @@ -27,10 +27,13 @@ import org.junit.Ignore; import org.junit.Test; import java.util.HashMap; +import java.util.Objects; public class LSTMTest extends AutomatedTestBase { String TEST_NAME1 = "LSTMForwardTest"; String TEST_NAME2 = "LSTMBackwardTest"; + String TEST_NAME3 = "BILSTMForwardTest"; + String TEST_NAME4 = "BILSTMBackwardTest"; private final static String TEST_DIR = "functions/tensor/"; @Override @@ -38,6 +41,8 @@ public class LSTMTest extends AutomatedTestBase { TestUtils.clearAssertionInformation(); addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_DIR, TEST_NAME1)); addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_DIR, TEST_NAME2)); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_DIR, 
TEST_NAME3)); + addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_DIR, TEST_NAME4)); } @Test @@ -113,6 +118,60 @@ public class LSTMTest extends AutomatedTestBase { runLSTMTest(128, 128, 128,64, 0, 0, 1e-5, TEST_NAME2, true); } + // The BILSTM output is compared to output of pytorch's BI-LSTM Layer implementation with FP64. + // Expected results are saved at: "src/test/resources/expected/BILSTM_OUT_{batch_size}_{seq_length}_{num_features}_{hidden_size}.csv" + @Test + public void testBILSTMForwardLocal1(){ + runLSTMTest(3, 5, 2,2, 0, 1, 1e-5, TEST_NAME3,false); + } + + @Test + public void testBILSTMForwardLocal2(){ + runLSTMTest(6, 5, 6,4, 0, 1, 1e-5, TEST_NAME3,false); + } + + @Test + public void testBILSTMForwardLocal3(){ + runLSTMTest(10, 5, 2,6, 0, 1, 1e-5, TEST_NAME3,false); + } + + @Test + public void testBILSTMForwardLocal4(){ + runLSTMTest(3, 5, 2,2, 0, 0, 1e-5, TEST_NAME3,false); + } + + @Test + public void testBILSTMForwardLocal5(){ + runLSTMTest(2, 5, 2,2, 0, 1, 1e-5, TEST_NAME3,false); + } + + + @Test + public void testBILSTMBackwardLocal1(){ + runLSTMTest(10, 5, 2,6, 0, 1, 1e-5, TEST_NAME4,false); + } + + @Test + public void testBILSTMFBackwardLocal2(){ + runLSTMTest(1, 5, 6,4, 0, 1, 1e-5, TEST_NAME4,false); + } + + @Test + public void testBILSTMBackwardLocal3(){ + runLSTMTest(6, 5, 6,4, 0, 1, 1e-5, TEST_NAME4,false); + } + + @Test + public void testBILSTMBackwardLocal4(){ + runLSTMTest(5, 5, 6,4, 0, 1, 1e-5, TEST_NAME4,false); + } + + @Test + public void testBILSTMBackwardLocal5(){ + runLSTMTest(4, 5, 6,4, 0, 1, 1e-5, TEST_NAME4,false); + } + + private void runLSTMTest(double batch_size, double seq_length, double num_features, double hidden_size, String testname){ runLSTMTest(batch_size, seq_length, num_features, hidden_size,0, testname); } @@ -132,26 +191,52 @@ public class LSTMTest extends AutomatedTestBase { //run script //"-explain", "runtime", - programArgs = new String[]{"-stats","-args", String.valueOf(batch_size), 
String.valueOf(seq_length), - String.valueOf(num_features), String.valueOf(hidden_size), String.valueOf(debug), String.valueOf(seq), - output("1A"),output("1B"),output("2A"), output("2B"),output("3A"),output("3B"),"","","",""}; - int offset = 0; - if(backward){ - programArgs[14 + offset] = output("4A"); - programArgs[15 + offset] = output("4B"); - programArgs[16 + offset] = output("5A"); - programArgs[17 + offset] = output("5B"); + boolean bilstm = Objects.equals(testname, TEST_NAME3); + boolean bilstm_backwards = Objects.equals(testname, TEST_NAME4); + if(bilstm) + programArgs = new String[]{"-stats","-args", String.valueOf(batch_size), String.valueOf(seq_length), + String.valueOf(num_features), String.valueOf(hidden_size), String.valueOf(debug), String.valueOf(seq), + "src/test/resources/expected/BILSTM_OUT",output("1A")}; + else if(bilstm_backwards) + programArgs = new String[]{"-stats","-args", String.valueOf(batch_size), String.valueOf(seq_length), + String.valueOf(num_features), String.valueOf(hidden_size), String.valueOf(debug), String.valueOf(seq), + "src/test/resources/expected/BILSTM_back_dW", + "src/test/resources/expected/BILSTM_back_dc", + "src/test/resources/expected/BILSTM_back_dX", output("1A")}; + else{ + programArgs = new String[]{"-stats","-args", String.valueOf(batch_size), String.valueOf(seq_length), + String.valueOf(num_features), String.valueOf(hidden_size), String.valueOf(debug), String.valueOf(seq), + output("1A"),output("1B"),output("2A"), output("2B"),output("3A"),output("3B"),"","","",""}; + int offset = 0; + if(backward){ + programArgs[14 + offset] = output("4A"); + programArgs[15 + offset] = output("4B"); + programArgs[16 + offset] = output("5A"); + programArgs[17 + offset] = output("5B"); + } } + //output("4A"), output("4B"),output("5A"),output("5B") runTest(true, EXCEPTION_NOT_EXPECTED, null, -1); // Compare results - extracted(precision,"1"); - extracted(precision,"2"); - extracted(precision,"3"); - if(backward){ - 
extracted(precision,"4"); - extracted(precision,"5"); + if(bilstm){ + Double max_error = (Double) readDMLScalarFromOutputDir("1A").values().toArray()[0]; + assert max_error < precision; + } else if (bilstm_backwards) { + HashMap<MatrixValue.CellIndex, Double> errors = readDMLMatrixFromOutputDir("1A"); + double[][] errors_ = TestUtils.convertHashMapToDoubleArray(errors); + assert errors_[0][0] < precision; + assert errors_[0][1] < precision; + assert errors_[0][2] < precision; + } else{ + extracted(precision,"1"); + extracted(precision,"2"); + extracted(precision,"3"); + if(backward){ + extracted(precision,"4"); + extracted(precision,"5"); + } } } catch(Exception ex) { diff --git a/src/test/resources/expected/BILSTM_OUT_10.0_5.0_2.0_6.0.csv b/src/test/resources/expected/BILSTM_OUT_10.0_5.0_2.0_6.0.csv new file mode 100644 index 0000000000..8d56568dc3 --- /dev/null +++ b/src/test/resources/expected/BILSTM_OUT_10.0_5.0_2.0_6.0.csv @@ -0,0 +1,10 @@ +0.014710,0.014903,0.015096,0.015289,0.015483,0.015677,0.087737,0.089270,0.090809,0.092355,0.093906,0.095464,0.027159,0.027664,0.028171,0.028678,0.029187,0.029697,0.041489,0.042511,0.043537,0.044567,0.045602,0.046640,0.050083,0.051167,0.052256,0.053349,0.054446,0.055547,0.020602,0.021458,0.022317,0.023179,0.024045,0.024915,0.097820,0.100007,0.102206,0.104415,0.106636,0.108867,0.010823,0.011590,0.012360,0.013133,0.013910,0.014689,0.210314,0.214749,0.219197,0.223656,0.228126,0.232606,0.0068 [...] +0.004987,0.005669,0.006354,0.007043,0.007736,0.008431,-0.050936,-0.050570,-0.050201,-0.049831,-0.049458,-0.049084,0.002538,0.003577,0.004623,0.005676,0.006735,0.007802,-0.034237,-0.033481,-0.032721,-0.031954,-0.031182,-0.030405,-0.001725,-0.000465,0.000806,0.002087,0.003378,0.004679,-0.020800,-0.019727,-0.018644,-0.017551,-0.016449,-0.015338,-0.009135,-0.007781,-0.006415,-0.005036,-0.003644,-0.002239,-0.010483,-0.009274,-0.008053,-0.006821,-0.005578,-0.004324,-0.021087,-0.019789,-0.01847 [...] 
+-0.003669,-0.002590,-0.001501,-0.000401,0.000710,0.001832,-0.083440,-0.083146,-0.082850,-0.082551,-0.082249,-0.081944,-0.016963,-0.015568,-0.014156,-0.012727,-0.011281,-0.009817,-0.066650,-0.065978,-0.065298,-0.064608,-0.063910,-0.063203,-0.034851,-0.033548,-0.032225,-0.030885,-0.029525,-0.028147,-0.046618,-0.045466,-0.044297,-0.043111,-0.041906,-0.040684,-0.055925,-0.054972,-0.054005,-0.053022,-0.052024,-0.051011,-0.027059,-0.025574,-0.024067,-0.022536,-0.020983,-0.019406,-0.075258,-0.0 [...] +-0.011280,-0.009894,-0.008486,-0.007057,-0.005606,-0.004133,-0.090803,-0.090535,-0.090262,-0.089982,-0.089697,-0.089406,-0.031927,-0.030335,-0.028713,-0.027060,-0.025375,-0.023660,-0.079466,-0.078873,-0.078266,-0.077644,-0.077008,-0.076357,-0.054679,-0.053437,-0.052167,-0.050869,-0.049541,-0.048184,-0.061748,-0.060630,-0.059485,-0.058312,-0.057111,-0.055881,-0.074301,-0.073576,-0.072833,-0.072069,-0.071286,-0.070483,-0.039531,-0.037918,-0.036267,-0.034577,-0.032850,-0.031083,-0.086480,-0 [...] +-0.017879,-0.016270,-0.014626,-0.012946,-0.011230,-0.009478,-0.091233,-0.091038,-0.090833,-0.090618,-0.090392,-0.090155,-0.043006,-0.041352,-0.039651,-0.037903,-0.036106,-0.034262,-0.083699,-0.083218,-0.082716,-0.082193,-0.081648,-0.081081,-0.065732,-0.064626,-0.063481,-0.062297,-0.061072,-0.059806,-0.069926,-0.068921,-0.067877,-0.066795,-0.065672,-0.064510,-0.080649,-0.080113,-0.079551,-0.078963,-0.078349,-0.077708,-0.048549,-0.046927,-0.045249,-0.043517,-0.041728,-0.039883,-0.087736,-0 [...] +-0.023507,-0.021756,-0.019953,-0.018096,-0.016187,-0.014224,-0.089191,-0.089104,-0.089001,-0.088881,-0.088744,-0.088589,-0.050853,-0.049242,-0.047568,-0.045828,-0.044023,-0.042152,-0.083994,-0.083654,-0.083286,-0.082888,-0.082460,-0.082001,-0.071218,-0.070295,-0.069323,-0.068299,-0.067223,-0.066094,-0.073696,-0.072855,-0.071964,-0.071023,-0.070030,-0.068983,-0.081785,-0.081433,-0.081048,-0.080630,-0.080177,-0.079689,-0.054733,-0.053193,-0.051581,-0.049896,-0.048137,-0.046303,-0.085915,-0 [...] 
+-0.028217,-0.026396,-0.024503,-0.022539,-0.020503,-0.018393,-0.086095,-0.086136,-0.086153,-0.086146,-0.086114,-0.086056,-0.056071,-0.054584,-0.053017,-0.051367,-0.049633,-0.047812,-0.082358,-0.082178,-0.081961,-0.081706,-0.081411,-0.081076,-0.073236,-0.072523,-0.071750,-0.070915,-0.070016,-0.069051,-0.074702,-0.074055,-0.073348,-0.072578,-0.071745,-0.070846,-0.080510,-0.080344,-0.080137,-0.079888,-0.079597,-0.079261,-0.058636,-0.057247,-0.055770,-0.054202,-0.052541,-0.050786,-0.082896,-0 [...] +-0.032065,-0.030238,-0.028321,-0.026311,-0.024209,-0.022012,-0.082508,-0.082687,-0.082835,-0.082950,-0.083032,-0.083078,-0.059194,-0.057889,-0.056487,-0.054985,-0.053381,-0.051673,-0.079720,-0.079709,-0.079653,-0.079550,-0.079398,-0.079195,-0.073109,-0.072619,-0.072059,-0.071427,-0.070719,-0.069934,-0.073978,-0.073541,-0.073034,-0.072454,-0.071798,-0.071065,-0.078043,-0.078062,-0.078033,-0.077955,-0.077824,-0.077640,-0.060740,-0.059549,-0.058254,-0.056852,-0.055341,-0.053717,-0.079351,-0 [...] +-0.035114,-0.033338,-0.031452,-0.029452,-0.027338,-0.025108,-0.078695,-0.079016,-0.079300,-0.079542,-0.079742,-0.079897,-0.060680,-0.059594,-0.058396,-0.057084,-0.055652,-0.054099,-0.076546,-0.076707,-0.076815,-0.076867,-0.076862,-0.076795,-0.071659,-0.071395,-0.071052,-0.070627,-0.070117,-0.069518,-0.072176,-0.071953,-0.071653,-0.071270,-0.070801,-0.070241,-0.074961,-0.075162,-0.075308,-0.075397,-0.075425,-0.075390,-0.061448,-0.060484,-0.059404,-0.058202,-0.056875,-0.055417,-0.075577,-0 [...] +-0.037430,-0.035753,-0.033945,-0.032004,-0.029927,-0.027711,-0.074796,-0.075260,-0.075680,-0.076051,-0.076370,-0.076635,-0.060906,-0.060060,-0.059092,-0.057995,-0.056764,-0.055395,-0.073088,-0.073418,-0.073689,-0.073897,-0.074038,-0.074108,-0.069399,-0.069356,-0.069228,-0.069010,-0.068697,-0.068284,-0.069705,-0.069696,-0.069602,-0.069417,-0.069136,-0.068754,-0.071561,-0.071938,-0.072254,-0.072506,-0.072690,-0.072800,-0.061091,-0.060371,-0.059523,-0.058542,-0.057421,-0.056155,-0.071722,-0 [...] 
diff --git a/src/test/resources/expected/BILSTM_OUT_2.0_5.0_2.0_2.0.csv b/src/test/resources/expected/BILSTM_OUT_2.0_5.0_2.0_2.0.csv new file mode 100644 index 0000000000..f41d9d951d --- /dev/null +++ b/src/test/resources/expected/BILSTM_OUT_2.0_5.0_2.0_2.0.csv @@ -0,0 +1,2 @@ +-0.0008102472922156037,0.004475520796032085,0.013402370700926657,0.018963985931974403,0.00505769470045857,0.0103317410641519,0.01371376281187445,0.01946222037947412,0.00868029572955103,0.01412995103648799,0.013907246679382276,0.019817538526131495,0.01100126480807401,0.01667737845412945,0.013905051818258757,0.019931162308030353,0.012561928573754893,0.018478239501463187,0.013696937209290468,0.01983034187308233 +0.006984309848334325,0.012997351829468727,0.01644707946083668,0.023116036511879712,0.010572171341473383,0.016790164307334054,0.017198878412656035,0.02404833005593171,0.012904071043027234,0.01939843672945603,0.018150876232426,0.025132196447129197,0.014470853555467619,0.021232571094286924,0.019456166597086797,0.026464869949517387,0.015588629793489147,0.022605586056502838,0.02159823966183599,0.028510847479637776 diff --git a/src/test/resources/expected/BILSTM_OUT_3.0_5.0_2.0_2.0.csv b/src/test/resources/expected/BILSTM_OUT_3.0_5.0_2.0_2.0.csv new file mode 100644 index 0000000000..b21020e451 --- /dev/null +++ b/src/test/resources/expected/BILSTM_OUT_3.0_5.0_2.0_2.0.csv @@ -0,0 +1,3 @@ +0.003596,0.003673,0.006775,0.007063,0.005016,0.005202,0.006847,0.007296,0.006068,0.006420,0.006623,0.007190,0.006900,0.007446,0.005909,0.006513,0.007604,0.008355,0.004392,0.004881 +0.004590,0.005182,0.009138,0.010459,0.006625,0.007592,0.009063,0.010508,0.008054,0.009328,0.008604,0.010093,0.009116,0.010658,0.007510,0.008896,0.009957,0.011746,0.005380,0.006385 +0.005577,0.006685,0.011475,0.013838,0.008219,0.009971,0.011257,0.013705,0.010021,0.012222,0.010565,0.012983,0.011309,0.013854,0.009096,0.011269,0.012285,0.015119,0.006360,0.007882 diff --git a/src/test/resources/expected/BILSTM_OUT_6.0_5.0_6.0_4.0.csv 
b/src/test/resources/expected/BILSTM_OUT_6.0_5.0_6.0_4.0.csv new file mode 100644 index 0000000000..74a50ea584 --- /dev/null +++ b/src/test/resources/expected/BILSTM_OUT_6.0_5.0_6.0_4.0.csv @@ -0,0 +1,6 @@ +0.024943,0.025497,0.026052,0.026608,0.914151,0.917172,0.920082,0.922884,0.068794,0.071153,0.073530,0.075924,0.741737,0.748891,0.755851,0.762621,0.168919,0.174827,0.180772,0.186749,0.437013,0.447595,0.458066,0.468420,0.386854,0.397169,0.407401,0.417542,0.198658,0.207800,0.216999,0.226247,0.707591,0.716998,0.726116,0.734952,0.073417,0.078819,0.084306,0.089875 +0.085608,0.092330,0.099164,0.106103,0.964359,0.966456,0.968423,0.970267,0.263409,0.277537,0.291668,0.305780,0.918110,0.922998,0.927574,0.931856,0.585035,0.600644,0.615770,0.630410,0.725904,0.739089,0.751653,0.763613,0.865295,0.873093,0.880428,0.887327,0.385226,0.403933,0.422348,0.440435,0.957506,0.960482,0.963241,0.965799,0.133168,0.145166,0.157338,0.169651 +0.144558,0.157822,0.171257,0.184822,0.973806,0.975793,0.977623,0.979308,0.433433,0.455155,0.476355,0.496989,0.953600,0.957358,0.960785,0.963911,0.784918,0.798973,0.812108,0.824369,0.841472,0.852943,0.863533,0.873303,0.942730,0.947521,0.951885,0.955861,0.521229,0.544547,0.566947,0.588403,0.974501,0.976761,0.978813,0.980678,0.187298,0.205228,0.223242,0.241256 +0.197199,0.216187,0.235218,0.254202,0.978632,0.980601,0.982383,0.983995,0.554198,0.579318,0.603248,0.625968,0.966867,0.970131,0.973044,0.975646,0.868427,0.879957,0.890408,0.899877,0.893423,0.903253,0.912087,0.920023,0.963795,0.967509,0.970812,0.973752,0.614087,0.639358,0.663083,0.685281,0.980172,0.982253,0.984108,0.985764,0.233491,0.256278,0.278915,0.301275 
+0.241743,0.265383,0.288819,0.311917,0.982011,0.983958,0.985688,0.987226,0.637209,0.663629,0.688213,0.711002,0.973897,0.976889,0.979506,0.981798,0.908074,0.917802,0.926392,0.933976,0.920554,0.929219,0.936821,0.943492,0.972995,0.976208,0.979007,0.981447,0.678209,0.704178,0.728029,0.749855,0.983577,0.985560,0.987297,0.988820,0.271643,0.298398,0.324721,0.350450 +0.278388,0.305860,0.332843,0.359166,0.984653,0.986558,0.988220,0.989671,0.695022,0.721737,0.746045,0.768075,0.978421,0.981225,0.983631,0.985699,0.929645,0.938175,0.945531,0.951880,0.936529,0.944388,0.951133,0.956926,0.978276,0.981192,0.983682,0.985811,0.723594,0.749689,0.773164,0.794203,0.986086,0.987983,0.989614,0.991019,0.302754,0.332923,0.362373,0.390899 diff --git a/src/test/resources/expected/BILSTM_back_dW_1.0_5.0_6.0_4.0.csv b/src/test/resources/expected/BILSTM_back_dW_1.0_5.0_6.0_4.0.csv new file mode 100644 index 0000000000..1ed259da9a --- /dev/null +++ b/src/test/resources/expected/BILSTM_back_dW_1.0_5.0_6.0_4.0.csv @@ -0,0 +1,22 @@ +0.070529,0.078134,0.085978,0.094048,0.026878,0.029445,0.032092,0.034815,0.096973,0.107043,0.117484,0.128286,0.318642,0.360038,0.401099,0.441797 +0.078229,0.087092,0.096242,0.105666,0.028870,0.031691,0.034601,0.037596,0.105126,0.116481,0.128259,0.140452,0.394856,0.449983,0.504799,0.559268 +0.085928,0.096049,0.106506,0.117285,0.030862,0.033937,0.037111,0.040378,0.113279,0.125918,0.139034,0.152618,0.471069,0.539929,0.608498,0.676739 +0.093628,0.105006,0.116770,0.128904,0.032854,0.036184,0.039620,0.043159,0.121433,0.135356,0.149810,0.164783,0.547283,0.629874,0.712197,0.794209 +0.101328,0.113964,0.127034,0.140523,0.034846,0.038430,0.042129,0.045940,0.129586,0.144794,0.160585,0.176949,0.623497,0.719820,0.815896,0.911680 +0.109027,0.122921,0.137298,0.152142,0.036838,0.040676,0.044639,0.048721,0.137739,0.154232,0.171361,0.189114,0.699710,0.809765,0.919596,1.029151 
+0.058053,0.062447,0.066971,0.071617,0.024251,0.025724,0.027237,0.028788,0.096976,0.103386,0.109998,0.116805,0.226568,0.254681,0.282592,0.310286 +0.059830,0.064353,0.069008,0.073789,0.025007,0.026530,0.028094,0.029698,0.099869,0.106483,0.113304,0.120326,0.232253,0.260930,0.289395,0.317633 +0.061613,0.066266,0.071054,0.075970,0.025765,0.027338,0.028954,0.030610,0.102763,0.109582,0.116613,0.123852,0.237968,0.267212,0.296235,0.325021 +0.063401,0.068185,0.073107,0.078160,0.026524,0.028148,0.029815,0.031524,0.105656,0.112681,0.119924,0.127381,0.243713,0.273527,0.303111,0.332447 +0.135788,0.162056,0.189975,0.219483,0.022535,0.025877,0.029399,0.033090,0.123017,0.147810,0.174322,0.202523,0.752083,0.874535,0.995975,1.116117 +0.143100,0.170591,0.199803,0.230670,0.024227,0.027771,0.031506,0.035419,0.131337,0.157399,0.185258,0.214885,0.787554,0.915453,1.042278,1.167734 +0.150412,0.179127,0.209631,0.241856,0.025919,0.029666,0.033613,0.037748,0.139657,0.166988,0.196195,0.227246,0.823025,0.956370,1.088582,1.219351 +0.157723,0.187662,0.219459,0.253043,0.027612,0.031561,0.035720,0.040076,0.147977,0.176577,0.207132,0.239607,0.858497,0.997288,1.134886,1.270968 +0.165035,0.196198,0.229287,0.264229,0.029304,0.033455,0.037827,0.042405,0.156297,0.186166,0.218068,0.251968,0.893968,1.038206,1.181189,1.322585 +0.172347,0.204733,0.239116,0.275415,0.030996,0.035350,0.039935,0.044733,0.164617,0.195755,0.229005,0.264330,0.929440,1.079123,1.227493,1.374202 +0.068596,0.073275,0.077995,0.082740,0.033222,0.035268,0.037379,0.039545,0.161444,0.170071,0.178953,0.188075,0.163605,0.178001,0.191806,0.205004 +0.071672,0.076626,0.081621,0.086641,0.034529,0.036706,0.038953,0.041259,0.166417,0.175521,0.184896,0.194525,0.171931,0.187064,0.201568,0.215428 +0.074772,0.080006,0.085279,0.090578,0.035842,0.038152,0.040536,0.042984,0.171395,0.180982,0.190855,0.200997,0.180357,0.196237,0.211452,0.225983 
+0.077895,0.083411,0.088968,0.094549,0.037160,0.039604,0.042127,0.044720,0.176375,0.186450,0.196826,0.207487,0.188877,0.205516,0.221450,0.236664 +0.769965,0.895735,1.026405,1.161881,0.199212,0.224631,0.250939,0.278111,0.815333,0.943782,1.077542,1.216557,7.621365,8.994547,10.369926,11.747077 +0.731165,0.853555,0.982814,1.118641,0.169214,0.189468,0.210708,0.232859,0.832005,0.958900,1.093663,1.236131,3.547149,4.091761,4.630351,5.161702 diff --git a/src/test/resources/expected/BILSTM_back_dW_10.0_5.0_2.0_6.0.csv b/src/test/resources/expected/BILSTM_back_dW_10.0_5.0_2.0_6.0.csv new file mode 100644 index 0000000000..de100b35e3 --- /dev/null +++ b/src/test/resources/expected/BILSTM_back_dW_10.0_5.0_2.0_6.0.csv @@ -0,0 +1,18 @@ +-0.974498,-0.966921,-0.954301,-0.936537,-0.913530,-0.885186,-0.281082,-0.272333,-0.262528,-0.251656,-0.239709,-0.226675,-0.971205,-0.961223,-0.946260,-0.926230,-0.901050,-0.870640,4.775329,5.544087,6.323654,7.113853,7.914483,8.725320 +-0.987186,-0.977746,-0.963041,-0.942968,-0.917425,-0.886317,-0.284642,-0.275316,-0.264886,-0.253341,-0.240670,-0.226865,-0.983811,-0.971960,-0.954911,-0.932575,-0.904868,-0.871707,4.960390,5.811510,6.673917,7.547426,8.431831,9.326902 +0.083548,0.086380,0.089386,0.092567,0.095926,0.099465,0.031184,0.031820,0.032452,0.033078,0.033700,0.034318,0.094921,0.097927,0.101075,0.104367,0.107803,0.111386,-0.040848,0.045735,0.132688,0.220011,0.307702,0.395761 +0.082347,0.085215,0.088264,0.091498,0.094917,0.098524,0.030865,0.031536,0.032207,0.032877,0.033548,0.034219,0.093865,0.096924,0.100134,0.103500,0.107021,0.110700,-0.029946,0.058736,0.147805,0.237258,0.327096,0.417317 +0.081062,0.083967,0.087062,0.090349,0.093830,0.097507,0.030520,0.031226,0.031937,0.032653,0.033373,0.034099,0.092727,0.095838,0.099114,0.102555,0.106163,0.109940,-0.018757,0.072042,0.163244,0.254848,0.346852,0.439254 
+0.079693,0.082636,0.085778,0.089120,0.092664,0.096413,0.030147,0.030889,0.031642,0.032403,0.033175,0.033957,0.091502,0.094669,0.098011,0.101529,0.105226,0.109105,-0.007278,0.085658,0.179013,0.272786,0.366974,0.461577 +0.078236,0.081219,0.084409,0.087807,0.091417,0.095240,0.029746,0.030525,0.031320,0.032129,0.032953,0.033792,0.090190,0.093413,0.096823,0.100421,0.104210,0.108191,0.004497,0.099588,0.195115,0.291076,0.387469,0.484292 +0.076689,0.079713,0.082952,0.086409,0.090086,0.093984,0.029316,0.030134,0.030971,0.031829,0.032706,0.033603,0.088788,0.092070,0.095549,0.099229,0.103111,0.107197,0.016573,0.113837,0.211555,0.309723,0.408340,0.507403 +-1.029869,-1.029191,-1.022595,-1.009946,-0.991111,-0.965963,-0.296143,-0.289570,-0.281868,-0.273021,-0.263015,-0.251834,-1.008123,-1.006597,-0.999453,-0.986576,-0.967853,-0.943177,4.969865,5.883939,6.811237,7.751578,8.704751,9.670517 +-1.046185,-1.045174,-1.038028,-1.024606,-1.004775,-0.978402,-0.301007,-0.294202,-0.286217,-0.277033,-0.266636,-0.255010,-1.024354,-1.022479,-1.014772,-1.001116,-0.981396,-0.955500,5.132196,6.105564,7.092637,8.093226,9.107116,10.134063 +0.081927,0.082015,0.082208,0.082508,0.082916,0.083436,0.031775,0.031463,0.031107,0.030707,0.030262,0.029771,0.093241,0.093127,0.093052,0.093019,0.093029,0.093082,-0.163661,-0.127389,-0.090835,-0.054000,-0.016882,0.020521 +0.080494,0.080571,0.080760,0.081062,0.081481,0.082018,0.031339,0.031044,0.030710,0.030337,0.029925,0.029473,0.091908,0.091786,0.091714,0.091695,0.091729,0.091818,-0.153025,-0.114522,-0.075722,-0.036623,0.002776,0.042475 +0.078968,0.079034,0.079220,0.079526,0.079957,0.080513,0.030873,0.030595,0.030283,0.029939,0.029561,0.029149,0.090482,0.090354,0.090286,0.090282,0.090342,0.090468,-0.142072,-0.101315,-0.060243,-0.018855,0.022849,0.064869 
+0.077346,0.077403,0.077586,0.077898,0.078341,0.078919,0.030376,0.030115,0.029827,0.029512,0.029170,0.028799,0.088962,0.088827,0.088766,0.088778,0.088866,0.089031,-0.130795,-0.087760,-0.044392,-0.000691,0.043342,0.087709 +0.075625,0.075673,0.075855,0.076174,0.076632,0.077231,0.029846,0.029604,0.029341,0.029056,0.028750,0.028422,0.087343,0.087205,0.087150,0.087180,0.087297,0.087504,-0.119190,-0.073852,-0.028163,0.017875,0.064263,0.111002 +0.073804,0.073844,0.074026,0.074352,0.074826,0.075450,0.029283,0.029060,0.028823,0.028570,0.028301,0.028017,0.085625,0.085483,0.085436,0.085486,0.085634,0.085884,-0.107250,-0.059584,-0.011551,0.036849,0.085617,0.134753 +-1.268892,-1.082524,-0.874005,-0.643091,-0.389547,-0.113142,-0.355980,-0.298311,-0.235823,-0.168465,-0.096184,-0.018931,-1.260560,-1.073681,-0.865068,-0.634502,-0.381769,-0.106664,18.506062,26.742341,35.026258,43.357272,51.734803,60.158231 +-1.631585,-1.598323,-1.543224,-1.465990,-1.366333,-1.243969,-0.486387,-0.463261,-0.434887,-0.401197,-0.362124,-0.317602,-1.623169,-1.588220,-1.531923,-1.454011,-1.354220,-1.232293,16.233144,22.162589,28.139980,34.164819,40.236562,46.354617 diff --git a/src/test/resources/expected/BILSTM_back_dW_4.0_5.0_6.0_4.0.csv b/src/test/resources/expected/BILSTM_back_dW_4.0_5.0_6.0_4.0.csv new file mode 100644 index 0000000000..128efabd67 --- /dev/null +++ b/src/test/resources/expected/BILSTM_back_dW_4.0_5.0_6.0_4.0.csv @@ -0,0 +1,22 @@ +12.936883602135635,14.996597145849965,17.036169100819556,19.0437351822756,5.617465973302987,6.2317238666466315,6.814256084468765,7.363050375933502,4.13615718959323,5.322089181803516,6.535496662833028,7.767044539965967,15.887144216556935,18.5686047367474,21.009431468096636,23.209475923937248 
+13.117319864139198,15.2132758516673,17.289288032202396,19.333311447413628,5.686438930068359,6.311415468966765,6.904445215154332,7.463477240457396,4.205857872858232,5.41196633457022,6.646238436199083,7.899210146206036,16.242605828490543,19.00280100640935,21.518983942612753,23.7909319131795 +13.297756126142762,15.42995455748464,17.542406963585226,19.622887712551655,5.7554118868337305,6.391107071286898,6.9946343458398985,7.563904104981289,4.275558556123234,5.501843487336924,6.75698020956514,8.031375752446104,16.59806744042416,19.43699727607128,22.02853641712888,24.37238790242176 +13.478192388146331,15.646633263301977,17.79552589496805,19.912463977689683,5.824384843599103,6.4707986736070335,7.0848234765254645,7.664330969505183,4.345259239388235,5.5917206401036275,6.8677219829311955,8.163541358686174,16.953529052357762,19.871193545733227,22.538088891644996,24.953843891664008 +13.658628650149891,15.863311969119318,18.048644826350895,20.202040242827707,5.893357800364475,6.550490275927165,7.175012607211031,7.764757834029076,4.414959922653238,5.681597792870333,6.978463756297251,8.295706964926243,17.30899066429137,20.305389815395166,23.047641366161127,25.535299880906273 +13.839064912153455,16.079990674936653,18.301763757733717,20.491616507965734,5.962330757129847,6.630181878247298,7.265201737896599,7.86518469855297,4.484660605918239,5.771474945637036,7.089205529663306,8.427872571166311,17.66445227622498,20.739586085057102,23.557193840677243,26.116755870148527 +4.491549072659619,4.941675901365776,5.346030331234405,5.707873580828774,3.881604857080315,4.264889499698043,4.617372611782375,4.939671325258881,0.9052100865545737,1.1192460848599635,1.3225322645509994,1.514966463276804,1.6843170012525002,2.0021254970206375,2.288985672118411,2.5478092513421515 
+4.667917059037986,5.1399998110398535,5.564559664320334,5.9448982445187015,3.9764303759112254,4.3738815347384445,4.740128505740977,5.075708943703384,0.9555151913143274,1.182088274085417,1.3974719478211692,1.6015108681173886,1.782885179796651,2.1143287309914034,2.4129124432745632,2.681759626250272 +4.841012659795874,5.334981380222909,5.779715497803135,6.178551925299636,4.0679771499989075,4.4793736038062075,4.859196756766566,5.207903479846765,1.005516618821403,1.2445936709736036,1.4720588364292773,1.6877017511397943,1.8818995277785113,2.227107285340091,2.537534568777468,2.816515885378705 +5.0105346028512034,5.52626915499285,5.9910994624540255,6.40839280753547,4.156245056234984,4.581344206647357,4.974533163755848,5.33618713980193,1.0550946957720662,1.3066107930434483,1.546110321598384,1.7733263676772004,1.9810966396711513,2.3401786374235374,2.6625556075645997,2.9517721168033146 +40.90870134060394,45.970813770583796,51.016294540872636,55.996440641646586,15.12927991627741,15.9018087074381,16.616332402995702,17.26904690317329,13.343463190261284,17.207468263839836,21.16902751271969,25.192309383099158,54.20186854642668,61.41634897938442,67.91460334870204,73.71322804766548 +41.55591757961296,46.70558429974983,51.83960339203193,56.908672606345675,15.395222584606191,16.186429912222692,16.91920427722267,17.589662307735804,13.59479710493191,17.53152127661213,21.568383611269425,25.66910864160666,55.37821297381305,62.78436388134744,69.46400605934778,75.43363688139101 +42.20313381862196,47.44035482891586,52.662912243191236,57.820904571044764,15.661165252934971,16.471051117007285,17.222076151449635,17.91027771229832,13.846131019602534,17.85557428938443,21.967739709819156,26.145907900114153,56.55455740119941,64.15237878331045,71.01340876999352,77.15404571511652 
+42.85035005763096,48.17512535808189,53.48622109435055,58.73313653574384,15.92710792126375,16.755672321791874,17.5249480256766,18.230893116860834,14.097464934273155,18.17962730215672,22.367095808368887,26.62270715862165,57.73090182858578,65.52039368527346,72.56281148063925,78.87445454884205 +43.497566296639974,48.90989588724792,54.30952994550986,59.64536850044293,16.19305058959253,17.040293526576466,17.82781989990357,18.55150852142335,14.34879884894378,18.503680314929017,22.76645190691862,27.099506417129156,58.90724625597214,66.88840858723648,74.11221419128498,80.59486338256757 +44.144782535648986,49.644666416413955,55.13283879666916,60.55760046514201,16.458993257921307,17.324914731361055,18.130691774130533,18.872123925985868,14.600132763614404,18.827733327701313,23.165808005468357,27.57630567563665,60.0835906833585,68.2564234891995,75.66161690193071,82.31527221629311 +16.306124380827892,16.726574686464154,17.0865100118092,17.391054547714763,15.012898481602297,15.453965219648776,15.849449770909983,16.200370804397974,3.531756833692445,4.388776910429992,5.214932721606371,6.00963959735659,7.429683216204809,8.009387453296622,8.5279070214638,8.990759161926682 +16.916190137570716,17.37012097712369,17.75966921308232,18.090186058117144,15.352184612463917,15.818868788884696,16.23879082443568,16.612840756068834,3.7162405335358497,4.61981067274067,5.4913673309767885,6.330162649994514,7.846285163279279,8.455867819883418,9.000207266410552,9.485297875847042 +17.515030019795375,18.00280397510521,18.422370271265628,18.77929783026212,15.680402495176644,16.172581541543607,16.61687734636932,17.014047258251512,3.899558122906833,4.84952087261453,5.766380239266191,6.649220578991063,8.264606609785506,8.904599050937504,9.47525414869919,9.983043175805927 
+18.101466449230152,18.623329986145247,19.073217577648304,19.456905963062095,15.997467965586697,16.51494662847186,16.98348107885388,17.403692794581847,4.08128011765337,5.077364972130296,6.039318351572684,6.966053205002741,8.683661298829808,9.354566162834415,9.952015457626622,10.482955050585977 +18.04362620035643,21.6678705817338,25.311893138283175,28.957626513802722,6.89729567653721,7.969160232013336,9.018913068556692,10.042686452389326,6.970068326500181,8.98771527667038,11.074177336605532,13.216560624006892,35.546161193360845,43.41962696619416,50.955247451612195,58.145598924225574 +64.72162390090068,73.47705291660326,82.3308851159307,91.22319646990839,26.59426683287795,28.46212047845905,30.287187422696668,32.061540456251656,25.13339146706239,32.40530127722957,39.93560985497329,47.67992585074989,117.63444273863638,136.80149019630164,154.9402710645735,172.04088337255266 diff --git a/src/test/resources/expected/BILSTM_back_dW_5.0_5.0_6.0_4.0.csv b/src/test/resources/expected/BILSTM_back_dW_5.0_5.0_6.0_4.0.csv new file mode 100644 index 0000000000..19dd73af3c --- /dev/null +++ b/src/test/resources/expected/BILSTM_back_dW_5.0_5.0_6.0_4.0.csv @@ -0,0 +1,22 @@ +158.126965,168.813970,179.516336,190.087946,24.146276,24.078055,23.951793,23.775922,164.642916,174.609275,184.794886,195.072743,230.357922,240.485051,248.982859,255.932549 +160.135964,171.013423,181.910024,192.677937,24.475919,24.420321,24.306355,24.142478,166.699824,176.851298,187.228396,197.702671,235.483266,246.229353,255.325141,262.851943 +162.144964,173.212877,184.303713,195.267929,24.805562,24.762587,24.660917,24.509034,168.756732,179.093321,189.661907,200.332599,240.608611,251.973656,261.667423,269.771338 +164.153963,175.412331,186.697401,197.857920,25.135206,25.104853,25.015478,24.875590,170.813641,181.335344,192.095417,202.962526,245.733956,257.717959,268.009705,276.690732 
+166.162962,177.611785,189.091090,200.447911,25.464849,25.447119,25.370040,25.242146,172.870549,183.577368,194.528927,205.592454,250.859301,263.462261,274.351987,283.610126 +168.171961,179.811238,191.484778,203.037903,25.794492,25.789384,25.724602,25.608702,174.927457,185.819391,196.962438,208.222381,255.984646,269.206564,280.694269,290.529520 +20.065691,19.495811,19.000532,18.571341,8.625845,8.417489,8.224307,8.044844,52.312153,51.324577,50.409493,49.561645,18.276246,18.990773,19.701446,20.404787 +21.400569,20.786313,20.247919,19.777136,9.138390,8.929353,8.734040,8.551222,54.474584,53.514837,52.623407,51.795231,19.229419,19.928257,20.625086,21.316212 +22.727420,22.070641,21.490759,20.979811,9.644478,9.435653,9.239034,9.053618,56.591177,55.662493,54.797863,53.992407,20.188467,20.872196,21.555659,22.234959 +24.041526,23.344258,22.724705,22.175214,10.142601,9.934849,9.737730,9.550470,58.657776,57.763090,56.928145,56.148233,21.150880,21.820287,22.491052,23.159086 +209.696346,222.831777,235.490837,247.489134,23.750950,23.350016,22.860667,22.305285,190.748583,202.789163,214.662051,226.206364,247.693436,251.517412,253.591148,254.121108 +211.795026,225.088948,237.904962,250.056838,24.025922,23.627889,23.140675,22.586806,192.797468,204.990508,217.017300,228.715392,250.958662,254.954110,257.178679,257.840080 +213.893706,227.346119,240.319087,252.624543,24.300893,23.905762,23.420684,22.868327,194.846353,207.191853,219.372548,231.224420,254.223887,258.390807,260.766210,261.559052 +215.992386,229.603290,242.733211,255.192247,24.575864,24.183635,23.700693,23.149848,196.895237,209.393198,221.727796,233.733448,257.489113,261.827504,264.353741,265.278024 +218.091067,231.860461,245.147336,257.759951,24.850835,24.461508,23.980702,23.431370,198.944122,211.594543,224.083045,236.242476,260.754339,265.264201,267.941272,268.996995 
+220.189747,234.117632,247.561461,260.327656,25.125807,24.739381,24.260711,23.712891,200.993007,213.795889,226.438293,238.751504,264.019564,268.700898,271.528803,272.715967 +19.176470,18.301613,17.506796,16.785689,8.068953,7.810065,7.556151,7.309009,51.321707,50.249678,49.200566,48.175763,12.146408,11.946547,11.768291,11.606968 +20.576595,19.630206,18.766872,17.980687,8.607228,8.342886,8.081426,7.825134,53.623189,52.561229,51.516417,50.490718,12.926300,12.689274,12.477871,12.286930 +21.962171,20.946495,20.016636,19.167132,9.136797,8.867936,8.599811,8.335191,55.868716,54.819984,53.782586,52.759044,13.708976,13.435469,13.191479,12.971377 +23.327167,22.244769,21.250693,20.339946,9.655690,9.383227,9.109318,8.837209,58.052502,57.019888,55.992798,54.974291,14.491691,14.182639,13.906849,13.658242 +200.899908,219.945375,239.368849,258.999134,32.964320,34.226581,35.456180,36.655597,205.690835,224.202315,243.351031,262.992756,512.534484,574.430266,634.228199,691.939404 +209.868024,225.717107,241.412486,256.770445,27.497130,27.787300,28.000891,28.152110,204.888464,220.134518,235.524834,250.902788,326.522570,343.669719,358.753112,371.897191 diff --git a/src/test/resources/expected/BILSTM_back_dW_6.0_5.0_6.0_4.0.csv b/src/test/resources/expected/BILSTM_back_dW_6.0_5.0_6.0_4.0.csv new file mode 100644 index 0000000000..810a56edbb --- /dev/null +++ b/src/test/resources/expected/BILSTM_back_dW_6.0_5.0_6.0_4.0.csv @@ -0,0 +1,22 @@ +1.851675,1.956805,2.061872,2.165451,0.278356,0.276834,0.275171,0.273479,1.862333,1.961998,2.063933,2.166931,2.812464,2.943940,3.062348,3.168445 +1.880127,1.988154,2.096203,2.202833,0.283464,0.282241,0.280884,0.279505,1.891188,1.993719,2.098635,2.204715,2.930453,3.078765,3.213840,3.336429 +1.908578,2.019502,2.130534,2.240215,0.288572,0.287648,0.286597,0.285530,1.920043,2.025440,2.133337,2.242499,3.048442,3.213590,3.365332,3.504413 
+1.937029,2.050851,2.164866,2.277598,0.293680,0.293054,0.292310,0.291555,1.948899,2.057160,2.168038,2.280282,3.166432,3.348415,3.516824,3.672398 +1.965481,2.082200,2.199197,2.314980,0.298787,0.298461,0.298023,0.297581,1.977754,2.088881,2.202740,2.318066,3.284421,3.483240,3.668316,3.840382 +1.993932,2.113548,2.233528,2.352362,0.303895,0.303868,0.303735,0.303606,2.006609,2.120602,2.237442,2.355850,3.402410,3.618065,3.819808,4.008366 +0.260278,0.255907,0.252770,0.250722,0.109457,0.107699,0.106233,0.105039,0.608792,0.600205,0.592987,0.587050,0.383954,0.411103,0.438183,0.465109 +0.275947,0.271083,0.267499,0.265048,0.115482,0.113703,0.112211,0.110986,0.633848,0.625510,0.618523,0.612800,0.399458,0.426847,0.454181,0.481373 +0.291530,0.286196,0.282185,0.279351,0.121440,0.119652,0.118143,0.116895,0.658439,0.650392,0.643674,0.638204,0.415070,0.442708,0.470306,0.497770 +0.306975,0.301198,0.296786,0.293588,0.127314,0.125528,0.124012,0.122752,0.682521,0.674803,0.668391,0.663208,0.430766,0.458665,0.486536,0.514280 +2.206753,2.338535,2.467515,2.592021,0.258274,0.254782,0.251100,0.247410,1.980847,2.100873,2.221331,2.340841,2.999460,3.143682,3.274015,3.391570 +2.233652,2.367915,2.499433,2.626513,0.262588,0.259282,0.255791,0.252297,2.008482,2.130959,2.253963,2.376099,3.061863,3.212475,3.348965,3.472438 +2.260550,2.397295,2.531351,2.661005,0.266903,0.263783,0.260483,0.257184,2.036116,2.161045,2.286596,2.411357,3.124267,3.281269,3.423915,3.553306 +2.287449,2.426676,2.563269,2.695496,0.271218,0.268284,0.265174,0.262072,2.063750,2.191130,2.319228,2.446615,3.186671,3.350062,3.498865,3.634175 +2.314348,2.456056,2.595186,2.729988,0.275532,0.272784,0.269866,0.266959,2.091384,2.221216,2.351860,2.481873,3.249074,3.418855,3.573814,3.715043 +2.341246,2.485436,2.627104,2.764480,0.279847,0.277285,0.274557,0.271846,2.119018,2.251302,2.384492,2.517132,3.311478,3.487648,3.648764,3.795911 
+0.256402,0.250037,0.244891,0.240815,0.114050,0.112328,0.110874,0.109678,0.691338,0.682229,0.674274,0.667423,0.268812,0.278563,0.288046,0.297182 +0.273158,0.266232,0.260583,0.256061,0.120648,0.118941,0.117496,0.116305,0.718776,0.710038,0.702434,0.695917,0.284328,0.294327,0.304060,0.313446 +0.289783,0.282326,0.276201,0.271256,0.127160,0.125481,0.124056,0.122882,0.745638,0.737316,0.730107,0.723966,0.299972,0.310234,0.320233,0.329880 +0.306214,0.298261,0.291691,0.286350,0.133566,0.131928,0.130535,0.129388,0.771870,0.764006,0.757231,0.751506,0.315711,0.326254,0.336534,0.346455 +2.845127,3.134858,3.433135,3.738234,0.510783,0.540698,0.571276,0.602528,2.885524,3.172076,3.470166,3.778397,11.798924,13.482492,15.149191,16.798430 +2.689869,2.938024,3.191776,3.449165,0.431455,0.450072,0.469150,0.488719,2.763423,3.008577,3.263219,3.525809,6.240362,6.879329,7.494987,8.086832 diff --git a/src/test/resources/expected/BILSTM_back_dX_1.0_5.0_6.0_4.0.csv b/src/test/resources/expected/BILSTM_back_dX_1.0_5.0_6.0_4.0.csv new file mode 100644 index 0000000000..db45e205e6 --- /dev/null +++ b/src/test/resources/expected/BILSTM_back_dX_1.0_5.0_6.0_4.0.csv @@ -0,0 +1 @@ +-6.184786,-2.517676,1.149434,4.816544,8.483654,12.150764,-4.142743,-1.774661,0.59342,2.961501,5.329583,7.697664,-3.066285,-1.405285,0.255715,1.916714,3.577714,5.238714,-3.13755,-1.448026,0.241499,1.931024,3.620549,5.310074,-4.22913,-1.833551,0.562028,2.957607,5.353186,7.748765 diff --git a/src/test/resources/expected/BILSTM_back_dX_10.0_5.0_2.0_6.0.csv b/src/test/resources/expected/BILSTM_back_dX_10.0_5.0_2.0_6.0.csv new file mode 100644 index 0000000000..1d2a1248b6 --- /dev/null +++ b/src/test/resources/expected/BILSTM_back_dX_10.0_5.0_2.0_6.0.csv @@ -0,0 +1,10 @@ +-17.314463,-2.134628,-10.771288,-1.483716,-8.212245,-1.204358,-8.698973,-1.199260,-12.463537,-1.519624 +-7.055268,-0.764809,-4.305086,-0.442786,-3.233885,-0.315695,-3.424976,-0.332942,-4.979341,-0.518065 
+-3.060885,-0.264587,-1.768603,-0.090104,-1.302262,-0.027054,-1.475745,-0.050561,-2.370300,-0.177971 +-1.683596,-0.095104,-0.876676,0.027760,-0.596005,0.071093,-0.738573,0.047979,-1.360171,-0.051837 +-1.051295,-0.018189,-0.465382,0.080352,-0.263679,0.115287,-0.383707,0.092841,-0.861081,0.008350 +-0.697419,0.024189,-0.237346,0.108116,-0.079386,0.138076,-0.182923,0.116487,-0.570314,0.042283 +-0.474049,0.050315,-0.096307,0.123918,0.032565,0.150079,-0.057626,0.129692,-0.382380,0.063385 +-0.321777,0.067515,-0.003032,0.133012,0.104030,0.155846,0.025466,0.136980,-0.252364,0.077256 +-0.212433,0.079261,0.061265,0.137940,0.150773,0.157721,0.082668,0.140576,-0.158145,0.086625 +-0.131004,0.087401,0.106661,0.140085,0.181519,0.157056,0.122866,0.141705,-0.087616,0.092973 diff --git a/src/test/resources/expected/BILSTM_back_dX_4.0_5.0_6.0_4.0.csv b/src/test/resources/expected/BILSTM_back_dX_4.0_5.0_6.0_4.0.csv new file mode 100644 index 0000000000..6a027dbb83 --- /dev/null +++ b/src/test/resources/expected/BILSTM_back_dX_4.0_5.0_6.0_4.0.csv @@ -0,0 +1,4 @@ +-12.162430917772921,-5.540870051836107,1.0806908141007046,7.702251680037516,14.323812545974327,20.94537341191114,-14.709464972930116,-7.1474182887656,0.41462839539891605,7.976675079563432,15.538721763727946,23.10076844789246,-24.17801089550484,-11.546386918816069,1.0852370578727006,13.716861034561468,26.348485011250236,38.980108987939005,-42.74673686491701,-19.397052469239377,3.95263192643827,27.302316322115914,50.652000717793555,74.00168511347121,-67.41692710745932,-28.96590064663489,9. [...] 
+-17.210226748992667,-7.890504579114665,1.4292175907633387,10.74893976064134,20.06866193051934,29.388384100397342,-13.367563324657157,-6.660404007599363,0.046755309458430705,6.753914626516224,13.461073943574018,20.16823326063181,-15.002930477522591,-7.880554190147432,-0.758177902772272,6.364198384602888,13.486574671978047,20.608950959353205,-25.446026015836406,-12.559626533458221,0.3267729489199703,13.213172431298158,26.099571913676346,38.985971396054545,-41.839277769953064,-19.0296784293 [...] +-18.90411574232276,-9.017895363445724,0.8683250154313117,10.75454539430835,20.640765773185386,30.526986152062417,-12.882417284567019,-6.690496730368308,-0.49857617616959726,5.6933443780291135,11.885264932227823,18.077185486426533,-12.035902299666835,-6.575279871764288,-1.1146574438617445,4.3459649840407995,9.806587411943346,15.26720983984589,-20.330598704364498,-10.602834918312633,-0.8750711322607708,8.852692653791092,18.58045643984295,28.30822022589482,-34.5301522081814,-16.494506599578 [...] +-20.074247583320112,-9.934899589542606,0.2044484042348993,10.343796398012405,20.483144391789914,30.622492385567416,-12.946254567688195,-6.947519337939333,-0.9487841081904715,5.04995112155839,11.048686351307252,17.047421581056113,-10.71767139933982,-5.940040737447426,-1.1624100755550322,3.6152205863373617,8.392851248229755,13.17048191012215,-18.315860665043907,-9.888344686874264,-1.4608287087046243,6.966687269465018,15.394203247634657,23.821719225804298,-31.53093031835856,-15.724985137582 [...] 
diff --git a/src/test/resources/expected/BILSTM_back_dX_5.0_5.0_6.0_4.0.csv b/src/test/resources/expected/BILSTM_back_dX_5.0_5.0_6.0_4.0.csv new file mode 100644 index 0000000000..fdaba6f603 --- /dev/null +++ b/src/test/resources/expected/BILSTM_back_dX_5.0_5.0_6.0_4.0.csv @@ -0,0 +1,5 @@ +-146.566052,-59.476765,27.612522,114.701809,201.791096,288.880383,-100.209715,-42.592081,15.025552,72.643185,130.260819,187.878452,-72.124033,-32.764497,6.595038,45.954573,85.314108,124.673643,-68.746741,-31.885426,4.975889,41.837204,78.698519,115.559834,-90.536145,-39.475953,11.584239,62.644430,113.704622,164.764814 +-146.491915,-63.741314,19.009287,101.759887,184.510488,267.261089,-83.556927,-39.336713,4.883502,49.103717,93.323931,137.544146,-47.339836,-23.817603,-0.295371,23.226862,46.749094,70.271327,-68.389138,-33.507367,1.374404,36.256175,71.137946,106.019717,-134.583449,-61.181555,12.220338,85.622231,159.024125,232.426018 +-151.859355,-69.654046,12.551263,94.756573,176.961882,259.167191,-73.233071,-36.513966,0.205138,36.924243,73.643348,110.362453,-33.894321,-17.177599,-0.460876,16.255847,32.972570,49.689292,-64.224503,-32.569704,-0.914905,30.739895,62.394694,94.049493,-149.061364,-70.883641,7.294082,85.471804,163.649527,241.827249 +-157.115197,-75.331861,6.451475,88.234811,170.018146,251.801482,-66.670289,-34.215554,-1.760820,30.693914,63.148649,95.603383,-26.407100,-13.329840,-0.252579,12.824681,25.901942,38.979202,-60.777310,-31.384717,-1.992123,27.400471,56.793065,86.185658,-157.416881,-77.802436,1.812009,81.426453,161.040898,240.655343 +-162.057220,-80.671217,0.714786,82.100790,163.486793,244.872796,-62.098337,-32.313680,-2.529022,27.255635,57.040292,86.824950,-21.837080,-10.998426,-0.159773,10.678881,21.517534,32.356188,-57.762438,-30.125651,-2.488864,25.147923,52.784710,80.421497,-163.901225,-83.646601,-3.391977,76.862646,157.117270,237.371894 diff --git a/src/test/resources/expected/BILSTM_back_dX_6.0_5.0_6.0_4.0.csv 
b/src/test/resources/expected/BILSTM_back_dX_6.0_5.0_6.0_4.0.csv new file mode 100644 index 0000000000..f71a49d35c --- /dev/null +++ b/src/test/resources/expected/BILSTM_back_dX_6.0_5.0_6.0_4.0.csvdiff --git a/src/test/resources/expected/BILSTM_back_dc_1.0_5.0_6.0_4.0.csv b/src/test/resources/expected/BILSTM_back_dc_1.0_5.0_6.0_4.0.csv new file mode 100644 index 0000000000..2e387546d6 --- /dev/null +++ b/src/test/resources/expected/BILSTM_back_dc_1.0_5.0_6.0_4.0.csv @@ -0,0 +1,4 @@ +15.656030,19.283222,22.910415,26.537608 +9.656310,11.929968,14.203627,16.477285 +4.061276,4.897442,5.738115,6.583245 +2.459186,2.940801,3.431978,3.932355 diff --git a/src/test/resources/expected/BILSTM_back_dc_10.0_5.0_2.0_6.0.csv b/src/test/resources/expected/BILSTM_back_dc_10.0_5.0_2.0_6.0.csv new file mode 100644 index 0000000000..47a037c93d --- /dev/null +++ b/src/test/resources/expected/BILSTM_back_dc_10.0_5.0_2.0_6.0.csv @@ -0,0 +1,40 @@ +12.657448,27.319347,41.981247,56.643146,71.305045,85.966944 +5.310812,11.392199,17.473586,23.554973,29.636360,35.717747 +2.412403,5.143665,7.874927,10.606190,13.337452,16.068714 +1.415714,2.997471,4.579228,6.160985,7.742741,9.324498 +0.958175,2.011801,3.065428,4.119054,5.172680,6.226306 +0.701361,1.458070,2.214780,2.971489,3.728198,4.484907 +0.538454,1.106518,1.674581,2.242645,2.810708,3.378772 +0.426611,0.865010,1.303409,1.741808,2.180207,2.618606 +0.345501,0.689827,1.034154,1.378481,1.722807,2.067134 +0.284275,0.557640,0.831004,1.104369,1.377734,1.651098 +9.025747,19.372598,29.719449,40.066300,50.413150,60.760001 +3.679332,7.847436,12.015540,16.183644,20.351748,24.519852 +1.882143,3.987340,6.092537,8.197734,10.302931,12.408128 +1.178239,2.476358,3.774477,5.072596,6.370715,7.668833 +0.824086,1.715768,2.607451,3.499133,4.390815,5.282497 +0.614535,1.265400,1.916266,2.567132,3.217998,3.868864 +0.477221,0.970102,1.462984,1.955866,2.448748,2.941629 +0.380886,0.762871,1.144855,1.526840,1.908825,2.290810 
+0.309946,0.610300,0.910654,1.211008,1.511362,1.811716 +0.255793,0.493935,0.732077,0.970218,1.208360,1.446502 +2.831496,5.592951,8.359919,11.132397,13.910381,16.693867 +1.413871,2.501787,3.596573,4.698232,5.806768,6.922184 +0.814929,1.253037,1.695907,2.143550,2.595974,3.053187 +0.587623,0.810515,1.036873,1.266711,1.500044,1.736882 +0.470062,0.599641,0.731841,0.866678,1.004168,1.144326 +0.395309,0.476197,0.559084,0.643988,0.730924,0.819908 +0.341726,0.394309,0.448405,0.504031,0.561202,0.619934 +0.300407,0.335432,0.371581,0.408869,0.447308,0.486911 +0.267003,0.290692,0.315192,0.340512,0.366662,0.393651 +0.239114,0.255308,0.272058,0.289371,0.307252,0.325705 +2.152371,4.091381,6.039898,7.997922,9.965453,11.942490 +1.084800,1.809535,2.541151,3.279655,4.025055,4.777360 +0.696462,1.022416,1.353034,1.688331,2.028320,2.373015 +0.527120,0.703715,0.883697,1.067085,1.253898,1.444153 +0.430877,0.536899,0.645461,0.756583,0.870287,0.986590 +0.366400,0.433800,0.503116,0.574368,0.647578,0.722764 +0.318747,0.363070,0.408826,0.456034,0.504712,0.554877 +0.281289,0.311060,0.341880,0.373766,0.406733,0.440795 +0.250618,0.270901,0.291927,0.313710,0.336259,0.359584 +0.224785,0.238760,0.253236,0.268219,0.283715,0.299728 diff --git a/src/test/resources/expected/BILSTM_back_dc_4.0_5.0_6.0_4.0.csv b/src/test/resources/expected/BILSTM_back_dc_4.0_5.0_6.0_4.0.csv new file mode 100644 index 0000000000..894d5cbbb5 --- /dev/null +++ b/src/test/resources/expected/BILSTM_back_dc_4.0_5.0_6.0_4.0.csv @@ -0,0 +1,16 @@ +20.36934151390274,25.085959116469013,29.802576719035294,34.51919432160157 +32.63346021571337,40.339397512606304,48.045334809499245,55.75127210639217 +33.75772600790142,41.874843467788835,49.991960927676246,58.10907838756367 +33.49366863268975,41.699399273579964,49.90512991447017,58.110860555360404 +162.03792624378303,200.168922070143,238.29991789650296,276.43091372286284 +93.506048495375,115.91352100653513,138.3209935176953,160.72846602885548 
+71.41399794763137,88.84557159517253,106.27714524271374,123.70871889025491 +60.25018022565458,75.24252452349502,90.23486882133547,105.22721311917589 +4.987250166599223,6.270461320040244,7.564180272585769,8.868348948783296 +7.760796239648255,9.818780713867707,11.931008050806975,14.09521846608693 +8.963288468664414,11.192882610687583,13.495403935348238,15.865124250406879 +10.683747887926046,13.108265732382582,15.617740788867172,18.203010964096276 +39.154996152864655,48.727873383281604,58.4937322111716,68.44511251094345 +24.61362655767161,30.30708409001326,36.150238252902305,42.12873100922723 +22.268613156650268,26.998978353834126,31.864387313229287,36.84684948890825 +22.95932431906764,27.476475895888868,32.1260472498092,36.885915401679945 diff --git a/src/test/resources/expected/BILSTM_back_dc_5.0_5.0_6.0_4.0.csv b/src/test/resources/expected/BILSTM_back_dc_5.0_5.0_6.0_4.0.csv new file mode 100644 index 0000000000..3f8b588359 --- /dev/null +++ b/src/test/resources/expected/BILSTM_back_dc_5.0_5.0_6.0_4.0.csv @@ -0,0 +1,20 @@ +375.082775,461.954106,548.825436,635.696766 +346.646236,428.571811,510.497386,592.422962 +336.627179,417.669957,498.712735,579.755513 +327.937595,408.337457,488.737320,569.137183 +320.089567,399.963919,479.838271,559.712623 +199.484930,246.465443,293.445957,340.426471 +300.053376,372.035019,444.016663,515.998306 +314.246712,391.011993,467.777273,544.542553 +314.133586,392.243629,470.353672,548.463715 +311.271507,389.968059,468.664611,547.361163 +94.226079,116.323971,138.544432,160.885803 +89.256040,106.328531,123.757244,141.521060 +98.158094,112.811389,127.799029,143.079952 +112.744662,126.203529,139.933781,153.886333 +130.958102,143.954305,157.152270,170.500888 +51.401631,60.854358,70.492563,80.309254 +84.909530,97.953935,111.299507,124.913519 +103.876574,116.812421,130.012710,143.430921 +123.340306,136.164690,149.189838,162.364395 +144.683995,157.630195,170.679572,183.774896 diff --git 
a/src/test/resources/expected/BILSTM_back_dc_6.0_5.0_6.0_4.0.csv b/src/test/resources/expected/BILSTM_back_dc_6.0_5.0_6.0_4.0.csv new file mode 100644 index 0000000000..06384c015d --- /dev/null +++ b/src/test/resources/expected/BILSTM_back_dc_6.0_5.0_6.0_4.0.csv @@ -0,0 +1,24 @@ +15.656030,19.283222,22.910415,26.537608 +6.327869,7.823737,9.319605,10.815473 +3.703454,4.595199,5.486943,6.378688 +2.546207,3.170511,3.794814,4.419117 +1.911000,2.387879,2.864759,3.341638 +1.519824,1.905342,2.290860,2.676378 +9.656310,11.929968,14.203627,16.477285 +4.502903,5.583090,6.663278,7.743465 +2.873793,3.575829,4.277866,4.979902 +2.081636,2.599276,3.116915,3.634555 +1.621965,2.032078,2.442191,2.852304 +1.330356,1.671781,2.013205,2.354630 +4.061276,4.897442,5.738115,6.583245 +1.663567,1.950958,2.244407,2.543572 +1.094313,1.244407,1.397937,1.554493 +0.883124,0.981019,1.080855,1.182284 +0.786662,0.859704,0.933815,1.008719 +0.738264,0.797399,0.856889,0.916488 +2.459186,2.940801,3.431978,3.932355 +1.277041,1.473094,1.673768,1.878572 +0.954827,1.071526,1.190689,1.311898 +0.822200,0.905308,0.989775,1.075271 +0.758475,0.824079,0.890242,0.956679 +0.726989,0.781930,0.836596,0.890721 diff --git a/src/test/scripts/functions/tensor/BILSTMBackwardTest.dml b/src/test/scripts/functions/tensor/BILSTMBackwardTest.dml new file mode 100644 index 0000000000..5ef3ac1c66 --- /dev/null +++ b/src/test/scripts/functions/tensor/BILSTMBackwardTest.dml @@ -0,0 +1,92 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +source("scripts/nn/layers/bilstm.dml") as bilstm +source("scripts/nn/layers/lstm.dml") as lstm + +batch_size = as.integer($1) +seq_length = as.integer($2) +num_features = as.integer($3) +hidden_size = as.integer($4) +debug = as.logical(as.integer($5)) +seq = as.logical(as.integer($6)) + +factor = 0.01 +input_range = matrix("0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 [...] 
+input = input_range*factor +lstmIn = matrix(input[,1:batch_size*seq_length*num_features], rows=batch_size,cols=(seq_length*num_features)) + +input = input - (num_features + hidden_size)*hidden_size*factor +W = matrix(input[,1:(num_features + hidden_size)*hidden_size*4],rows=num_features + hidden_size, cols=hidden_size*4) +b = matrix(1,rows=1, cols=4*hidden_size)*factor +out0 = matrix(1,rows=batch_size, cols=hidden_size)*factor +c0 = matrix(0,rows=batch_size, cols=hidden_size)*factor + +out0 = rbind(out0, out0) +c0 = rbind(c0, c0) + +[out, c, cache_out, cache_c, cache_ifog] = bilstm::forward(lstmIn, W, W, b, b,seq_length,num_features,seq,out0, c0) + +dc = matrix(0,rows=batch_size*2,cols=hidden_size) +if(batch_size == 5){ + dout = matrix(input_range[,1:batch_size*hidden_size*seq_length*2], rows=batch_size, cols=hidden_size*seq_length*2) +} else if(batch_size == 4) { + dout = matrix(0, rows=batch_size, cols=hidden_size*2) + dc = matrix(input_range[,1:batch_size*hidden_size*2], rows=batch_size*2, cols=hidden_size) +} else if(batch_size == 3) { + +} else { + dout = matrix(1, rows=batch_size, cols=hidden_size*2) + if(seq){ + dout = matrix(1, rows=batch_size, cols=hidden_size*seq_length*2) + } +} + +[dx, dw, db, dw_reverse, db_reverse, dout0, dc0] = bilstm::backward(dout, dc, lstmIn, W, W, b, b, seq_length,num_features, seq, out0, c0, cache_out, cache_c, cache_ifog) +#print(toString(dx)) +#print(toString(dout0)) +#print(toString(dc0)) + +#print(toString(dw)) +#print(toString(dw_reverse)) +#print(toString(db)) +#print(toString(db_reverse)) + +errors = matrix(0, rows=1, cols=3) +expected = read($7 + "_" + $1 +"_" + $2 +"_" + $3 +"_" + $4 + ".csv", format="csv"); +tmp = rbind(dw, dw_reverse, db, db_reverse) +error = expected - tmp +error = max(abs(error)) +errors[1,1] = error + +expected2 = read($8 + "_" + $1 +"_" + $2 +"_" + $3 +"_" + $4 + ".csv", format="csv"); +tmp = rbind(dout0, dc0) +error = expected2 - tmp +error = max(abs(error)) +errors[1,2] = error + +expected3 = 
read($9 + "_" + $1 +"_" + $2 +"_" + $3 +"_" + $4 + ".csv", format="csv"); +error = expected3 - dx +error = max(abs(error)) +errors[1,3] = error + +write(errors, $10, format="text"); + + diff --git a/src/test/scripts/functions/tensor/BILSTMForwardTest.dml b/src/test/scripts/functions/tensor/BILSTMForwardTest.dml new file mode 100644 index 0000000000..772c3bf1a0 --- /dev/null +++ b/src/test/scripts/functions/tensor/BILSTMForwardTest.dml @@ -0,0 +1,60 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- +source("scripts/nn/layers/bilstm.dml") as bilstm +source("scripts/nn/layers/lstm.dml") as lstm + +batch_size = as.integer($1) +seq_length = as.integer($2) +num_features = as.integer($3) +hidden_size = as.integer($4) +debug = as.logical(as.integer($5)) +seq = as.logical(as.integer($6)) + +factor = 0.01 +input_range = matrix("0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 [...] +input = input_range*factor +lstmIn = matrix(input[,1:batch_size*seq_length*num_features], rows=batch_size,cols=(seq_length*num_features)) + +input = input - (num_features + hidden_size)*hidden_size*factor +W = matrix(input[,1:(num_features + hidden_size)*hidden_size*4],rows=num_features + hidden_size, cols=hidden_size*4) +if(batch_size == 2){ + b = (matrix(input_range[,1:4*hidden_size], rows=1, cols=4*hidden_size) - 2*hidden_size)*factor + c0 = (matrix(input_range[,1:2*batch_size*hidden_size], rows=batch_size*2, cols=hidden_size) - 2*hidden_size)*factor + out0 = (matrix(input_range[,1:2*batch_size*hidden_size], rows=batch_size*2, cols=hidden_size) + 2*hidden_size)*factor +} else { + b = matrix(1,rows=1, cols=4*hidden_size)*factor + out0 = matrix(1,rows=batch_size, cols=hidden_size)*factor + c0 = matrix(0,rows=batch_size, cols=hidden_size)*factor + c0 = rbind(c0, c0) + out0 = rbind(out0, out0) +} + +[out2, c2, cache_out2, cache_c2, cache_ifog2] = bilstm::forward(lstmIn, W, W, b, b,seq_length,num_features,seq,out0, c0) +expected = read($7 + "_" + $1 +"_" + $2 +"_" + $3 +"_" + $4 + ".csv", format="csv"); +if(seq 
== FALSE){ + expectedA = expected[,(seq_length-1)*hidden_size*2 + 1 : (seq_length-1)*hidden_size*2 + hidden_size] + expectedB = expected[, hidden_size + 1 : hidden_size*2] + expected = cbind(expectedA, expectedB) +} +error = expected - out2 +error = max(abs(error)) +#print(error) +write(error, $8, format="text");
