This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 9de62d16f9 [SYSTEMDS-3651] ResNet residual block forward path
9de62d16f9 is described below
commit 9de62d16f9f72c8952a4bafaba4c51aba793309c
Author: MaximilianTUB <[email protected]>
AuthorDate: Thu Nov 9 00:18:16 2023 +0100
[SYSTEMDS-3651] ResNet residual block forward path
This patch implements the forward path for the basic residual block
of ResNet and applies differential testing comparing results with PyTorch.
The testing is done by, first, testing for the correct handling of
dimensions throughout the residual block. Secondly, we test for correct
computation of the actual values of the output by using the
torchvision.models.resnet.BasicBlock of PyTorch, which is their
implementation of the basic block for the ResNets 18 and 34. We let
PyTorch randomly initialize all weights and then hard-code these
weights into the component test with the expected values as well.
Closes #1940
---
scripts/nn/networks/resnet.dml | 364 ++++++++++++
.../test/applications/nn/NNComponentTest.java | 5 +
.../scripts/applications/nn/component/resnet.dml | 629 +++++++++++++++++++++
src/test/scripts/applications/nn/util.dml | 44 ++
4 files changed, 1042 insertions(+)
diff --git a/scripts/nn/networks/resnet.dml b/scripts/nn/networks/resnet.dml
new file mode 100644
index 0000000000..ae3c043b91
--- /dev/null
+++ b/scripts/nn/networks/resnet.dml
@@ -0,0 +1,364 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("scripts/nn/layers/batch_norm2d_old.dml") as bn2d
+source("scripts/nn/layers/conv2d_builtin.dml") as conv2d
+source("scripts/nn/layers/relu.dml") as relu
+source("scripts/nn/layers/max_pool2d_builtin.dml") as mp2d
+source("scripts/nn/layers/global_avg_pool2d.dml") as ap2d
+source("scripts/nn/layers/affine.dml") as fc
+
+conv3x3_forward = function(matrix[double] X, matrix[double] filter,
+    int C_in, int C_out, int Hin, int Win,
+    int strideh, int stridew)
+    return(matrix[double] out, int Hout, int Wout) {
+  /*
+   * 2D convolution with a 3x3 filter, padding 1 in both dimensions
+   * and an all-zero bias.
+   *
+   * Inputs:
+   *  - X: Inputs, forwarded to conv2d::forward.
+   *  - filter: Convolution weights, forwarded to conv2d::forward.
+   *  - C_in: Number of input channels.
+   *  - C_out: Number of output channels (number of rows of the zero bias).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *
+   * Outputs:
+   *  - out, Hout, Wout: The three results of conv2d::forward, unchanged.
+   */
+  # the conv layer is called with a zero bias so that no bias is applied
+  zero_bias = matrix(0, rows=C_out, cols=1)
+  [out, Hout, Wout] = conv2d::forward(X=X, W=filter, b=zero_bias, C=C_in,
+      Hin=Hin, Win=Win, Hf=3, Wf=3, strideh=strideh, stridew=stridew,
+      padh=1, padw=1)
+}
+
+conv1x1_forward = function(matrix[double] X, matrix[double] filter,
+    int C_in, int C_out, int Hin, int Win,
+    int strideh, int stridew)
+    return(matrix[double] out, int Hout, int Wout) {
+  /*
+   * 2D convolution with a 1x1 filter, no padding and an all-zero bias.
+   *
+   * Inputs:
+   *  - X: Inputs, forwarded to conv2d::forward.
+   *  - filter: Convolution weights, forwarded to conv2d::forward.
+   *  - C_in: Number of input channels.
+   *  - C_out: Number of output channels (number of rows of the zero bias).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *
+   * Outputs:
+   *  - out, Hout, Wout: The three results of conv2d::forward, unchanged.
+   */
+  # the conv layer is called with a zero bias so that no bias is applied
+  zero_bias = matrix(0, rows=C_out, cols=1)
+  [out, Hout, Wout] = conv2d::forward(X=X, W=filter, b=zero_bias, C=C_in,
+      Hin=Hin, Win=Win, Hf=1, Wf=1, strideh=strideh, stridew=stridew,
+      padh=0, padw=0)
+}
+
+basic_block_forward = function(matrix[double] X, list[unknown] weights,
+    int C_in, int C_base, int Hin, int Win,
+    int strideh, int stridew, string mode,
+    list[unknown] ema_means_vars)
+    return (matrix[double] out, int Hout, int Wout,
+      list[unknown] ema_means_vars_upd) {
+  /*
+   * Computes the forward pass for a basic residual block.
+   * This basic residual block (with 2 3x3 conv layers of
+   * same channel size) is used in the smaller ResNets 18
+   * and 34.
+   *
+   * The identity path is downsampled (1x1 conv followed by batch
+   * norm) if strideh > 1, stridew > 1 or C_in != C_base. In that
+   * case the additional entries 7-9 of weights and 5-6 of
+   * ema_means_vars are required.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, C_in*Hin*Win).
+   *  - weights: list of weights for all layers of res block
+   *      with the following order/content:
+   *        -> 1: Weights of conv 1, of shape (C_base, C_in*3*3).
+   *        -> 2: Weights of batch norm 1, of shape (C_base, 1).
+   *        -> 3: Bias of batch norm 1, of shape (C_base, 1).
+   *        -> 4: Weights of conv 2, of shape (C_base, C_base*3*3).
+   *        -> 5: Weights of batch norm 2, of shape (C_base, 1).
+   *        -> 6: Bias of batch norm 2, of shape (C_base, 1).
+   *      If the block should downsample X:
+   *        -> 7: Weights of downsample conv, of shape (C_base, C_in*1*1).
+   *        -> 8: Weights of downsample batch norm, of shape (C_base, 1).
+   *        -> 9: Bias of downsample batch norm, of shape (C_base, 1).
+   *  - C_in: Number of input channels.
+   *  - C_base: Number of base channels for this block.
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - strideh: Stride over height (usually 1 or 2).
+   *  - stridew: Stride over width (usually same as strideh).
+   *  - mode: 'train' or 'test' to indicate if the model is currently
+   *      being trained or tested for batch normalization layers.
+   *      See the docs of the batch norm layer (batch_norm2d_old.dml)
+   *      for more info.
+   *  - ema_means_vars: List of exponential moving averages for mean
+   *      and variance for batch normalization layers.
+   *        -> 1: EMA for mean of batch norm 1, of shape (C_base, 1).
+   *        -> 2: EMA for variance of batch norm 1, of shape (C_base, 1).
+   *        -> 3: EMA for mean of batch norm 2, of shape (C_base, 1).
+   *        -> 4: EMA for variance of batch norm 2, of shape (C_base, 1).
+   *      If the block should downsample X:
+   *        -> 5: EMA for mean of downs. batch norm, of shape (C_base, 1).
+   *        -> 6: EMA for variance of downs. batch norm, of shape (C_base, 1).
+   *
+   * Outputs:
+   *  - out: Output, of shape (N, C_base*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   *  - ema_means_vars_upd: List of updated exponential moving averages
+   *      for mean and variance of batch normalization layers, in the
+   *      same order as ema_means_vars (4 entries, or 6 if downsampling).
+   */
+  downsample = strideh > 1 | stridew > 1 | C_in != C_base
+  # default values
+  mu_bn = 0.1
+  epsilon_bn = 1e-05
+
+  # get all params
+  W_conv1 = as.matrix(weights[1])
+  gamma_bn1 = as.matrix(weights[2])
+  beta_bn1 = as.matrix(weights[3])
+  W_conv2 = as.matrix(weights[4])
+  gamma_bn2 = as.matrix(weights[5])
+  beta_bn2 = as.matrix(weights[6])
+
+  ema_mean_bn1 = as.matrix(ema_means_vars[1])
+  ema_var_bn1 = as.matrix(ema_means_vars[2])
+  ema_mean_bn2 = as.matrix(ema_means_vars[3])
+  ema_var_bn2 = as.matrix(ema_means_vars[4])
+
+  if (downsample) {
+    # gather params for downsampling
+    W_conv3 = as.matrix(weights[7])
+    gamma_bn3 = as.matrix(weights[8])
+    beta_bn3 = as.matrix(weights[9])
+    ema_mean_bn3 = as.matrix(ema_means_vars[5])
+    ema_var_bn3 = as.matrix(ema_means_vars[6])
+  }
+
+  # RESIDUAL PATH
+  # First convolutional layer (carries the block's stride)
+  [out, Hout, Wout] = conv3x3_forward(X, W_conv1, C_in, C_base, Hin, Win,
+      strideh, stridew)
+  [out, ema_mean_bn1_upd, ema_var_bn1_upd, c_m, c_v] = bn2d::forward(out,
+      gamma_bn1, beta_bn1, C_base, Hout, Wout, mode, ema_mean_bn1,
+      ema_var_bn1, mu_bn, epsilon_bn)
+  out = relu::forward(out)
+
+  # Second convolutional layer (always stride 1)
+  [out, Hout, Wout] = conv3x3_forward(out, W_conv2, C_base, C_base, Hout,
+      Wout, 1, 1)
+  [out, ema_mean_bn2_upd, ema_var_bn2_upd, c_m, c_v] = bn2d::forward(out,
+      gamma_bn2, beta_bn2, C_base, Hout, Wout, mode, ema_mean_bn2,
+      ema_var_bn2, mu_bn, epsilon_bn)
+
+  # IDENTITY PATH
+  identity = X
+  if (downsample) {
+    # downsample input so that it can be added to the residual path
+    [identity, Hout, Wout] = conv1x1_forward(X, W_conv3, C_in, C_base,
+        Hin, Win, strideh, stridew)
+    [identity, ema_mean_bn3_upd, ema_var_bn3_upd, c_m, c_v] =
+        bn2d::forward(identity, gamma_bn3, beta_bn3, C_base, Hout, Wout,
+        mode, ema_mean_bn3, ema_var_bn3, mu_bn, epsilon_bn)
+  }
+
+  out = relu::forward(out + identity)
+
+  ema_means_vars_upd = list(ema_mean_bn1_upd, ema_var_bn1_upd,
+      ema_mean_bn2_upd, ema_var_bn2_upd)
+  if (downsample) {
+    # extend the updated list (not the input list), so that the result
+    # holds all six updated EMAs in the documented order
+    ema_means_vars_upd = append(ema_means_vars_upd, ema_mean_bn3_upd)
+    ema_means_vars_upd = append(ema_means_vars_upd, ema_var_bn3_upd)
+  }
+}
+
+basic_reslayer_forward = function(matrix[double] X, int Hin, int Win,
+    int blocks, int strideh, int stridew, int C_in, int C_base,
+    list[unknown] blocks_weights, string mode,
+    list[unknown] ema_means_vars)
+    return (matrix[double] out, int Hout, int Wout,
+      list[unknown] ema_means_vars_upd) {
+  /*
+   * Executes the forward pass for a sequence of residual blocks
+   * with the same number of base channels, i.e. residual layer.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, C_in*Hin*Win)
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - blocks: Number of residual blocks (bigger than 0).
+   *  - strideh: Stride height for first conv layer of first block.
+   *  - stridew: Stride width for first conv layer of first block.
+   *  - C_in: Number of input channels.
+   *  - C_base: Number of base channels of res layer.
+   *  - blocks_weights: List of weights of each block.
+   *      -> i: List of weights of block i with the content
+   *            defined in the docs of basic_block_forward().
+   *      -> length == blocks
+   *  - mode: 'train' or 'test' to indicate if the model is currently
+   *      being trained or tested for batch normalization layers.
+   *      See the docs of the batch norm layer (batch_norm2d_old.dml)
+   *      for more info.
+   *  - ema_means_vars: List of exponential moving averages for mean
+   *      and variance for batch normalization layers of each block.
+   *      -> i: List of EMAs of block i with the content defined
+   *            in the docs of basic_block_forward().
+   *      -> length == blocks
+   *
+   * Outputs:
+   *  - out: Output of the last block.
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   *  - ema_means_vars_upd: Updated EMAs, built from the updated EMA
+   *      list returned by each block, in block order.
+   */
+  # first block with provided stride
+  [out, Hout, Wout, emas1_upd] = basic_block_forward(X,
+      as.list(blocks_weights[1]), C_in, C_base, Hin, Win, strideh, stridew,
+      mode, as.list(ema_means_vars[1]))
+  ema_means_vars_upd = list(emas1_upd)
+
+  # other blocks: stride 1 and C_base input channels
+  # The guard is required because 2:blocks is a descending range for
+  # blocks == 1 and would otherwise execute the loop body.
+  if (blocks > 1) {
+    for (i in 2:blocks) {
+      current_weights = as.list(blocks_weights[i])
+      current_emas = as.list(ema_means_vars[i])
+      [out, Hout, Wout, current_emas_upd] = basic_block_forward(X=out,
+          weights=current_weights, C_in=C_base, C_base=C_base,
+          Hin=Hout, Win=Wout, strideh=1, stridew=1, mode=mode,
+          ema_means_vars=current_emas)
+      ema_means_vars_upd = append(ema_means_vars_upd, current_emas_upd)
+    }
+  }
+}
+
+resnet18_forward = function(matrix[double] X, int Hin, int Win,
+    list[unknown] model, string mode,
+    list[unknown] ema_means_vars)
+    return (matrix[double] out, list[unknown] ema_means_vars_upd) {
+  /*
+   * Forward pass of the ResNet 18 model as introduced in
+   * "Deep Residual Learning for Image Recognition" by
+   * Kaiming He et al. and inspired by the PyTorch
+   * implementation.
+   *
+   * Pipeline: 7x7 conv -> batch norm -> ReLU -> 3x3 max pooling ->
+   * four residual layers (64, 128, 256, 512 base channels, two
+   * blocks each) -> global average pooling -> affine layer.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, C_in*Hin*Win).
+   *      C_in = 3 is expected.
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - model: Weights and bias matrices of the model
+   *      with the following order/content:
+   *        -> 1: Weights of conv 1 7x7, of shape (64, 3*7*7)
+   *        -> 2: Weights of batch norm 1, of shape (64, 1).
+   *        -> 3: Bias of batch norm 1, of shape (64, 1).
+   *        -> 4: List of weights for first residual layer
+   *              with 64 base channels.
+   *        -> 5: List of weights for second residual layer
+   *              with 128 base channels.
+   *        -> 6: List of weights for third residual layer
+   *              with 256 base channels.
+   *        -> 7: List of weights for fourth residual layer
+   *              with 512 base channels.
+   *           Each residual layer list (4-7) contains:
+   *             -> 1: List of weights for first residual block.
+   *             -> 2: List of weights for second residual block.
+   *           Each list of weights for a residual block
+   *           must follow the same order as defined in
+   *           the documentation of basic_block_forward().
+   *        -> 8: Weights of fully connected layer, of shape (512, 1000)
+   *        -> 9: Bias of fully connected layer, of shape (1, 1000)
+   *  - mode: 'train' or 'test' to indicate if the model is currently
+   *      being trained or tested for batch normalization layers.
+   *      See the docs of the batch norm layer (batch_norm2d_old.dml)
+   *      for more info.
+   *  - ema_means_vars: List of exponential moving averages for mean
+   *      and variance for batch normalization layers.
+   *        -> 1: EMA for mean of batch norm 1, of shape (64, 1).
+   *        -> 2: EMA for variance of batch norm 1, of shape (64, 1).
+   *        -> 3: List of EMA means and vars for residual layer 1.
+   *        -> 4: List of EMA means and vars for residual layer 2.
+   *        -> 5: List of EMA means and vars for residual layer 3.
+   *        -> 6: List of EMA means and vars for residual layer 4.
+   *           Each residual layer list (3-6) contains:
+   *             -> 1: List of EMA means and vars for residual block 1.
+   *             -> 2: List of EMA means and vars for residual block 2.
+   *           Each list of EMAs for a residual block
+   *           must follow the same order as defined in
+   *           the documentation of basic_block_forward().
+   *  - NOTICE: The lists of the first blocks for layer 2, 3 and 4
+   *      must include weights and EMAs for 1 extra conv layer
+   *      and a batch norm layer for the downsampling on the
+   *      identity path.
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, 1000)
+   *  - ema_means_vars_upd: List of updated exponential moving averages
+   *      for mean and variance of batch normalization layers. It follows
+   *      the same exact structure as the input EMAs list.
+   */
+  # batch norm settings used for the first batch norm layer
+  mu_bn = 0.1
+  epsilon_bn = 1e-05
+
+  # channel widths: stem and residual layer 1, then residual layers 2-4
+  C_stem = 64
+  C_l2 = 128
+  C_l3 = 256
+  C_l4 = 512
+
+  # unpack parameters of the stem (conv 7x7 + batch norm) and its EMAs
+  W_stem = as.matrix(model[1])
+  gamma_stem = as.matrix(model[2])
+  beta_stem = as.matrix(model[3])
+  ema_mean_stem = as.matrix(ema_means_vars[1])
+  ema_var_stem = as.matrix(ema_means_vars[2])
+
+  # unpack per-layer weight and EMA lists of the four residual layers
+  layer1_weights = as.list(model[4])
+  layer2_weights = as.list(model[5])
+  layer3_weights = as.list(model[6])
+  layer4_weights = as.list(model[7])
+  layer1_emas = as.list(ema_means_vars[3])
+  layer2_emas = as.list(ema_means_vars[4])
+  layer3_emas = as.list(ema_means_vars[5])
+  layer4_emas = as.list(ema_means_vars[6])
+
+  # unpack parameters of the fully connected layer
+  W_fc = as.matrix(model[8])
+  b_fc = as.matrix(model[9])
+
+  # Stem: conv 7x7 (stride 2, padding 3) called with a zero bias
+  b_stem = matrix(0, rows=C_stem, cols=1)
+  [out, Hout, Wout] = conv2d::forward(X=X, W=W_stem, b=b_stem, C=3,
+      Hin=Hin, Win=Win, Hf=7, Wf=7, strideh=2, stridew=2, padh=3, padw=3)
+  # Stem: batch normalization
+  [out, ema_mean_stem_upd, ema_var_stem_upd, c_mean, c_var] = bn2d::forward(
+      X=out, gamma=gamma_stem, beta=beta_stem, C=C_stem, Hin=Hout, Win=Wout,
+      mode=mode, ema_mean=ema_mean_stem, ema_var=ema_var_stem, mu=mu_bn,
+      epsilon=epsilon_bn)
+  # Stem: ReLU
+  out = relu::forward(X=out)
+  # Stem: max pooling 3x3 (stride 2, padding 1)
+  [out, Hout, Wout] = mp2d::forward(X=out, C=C_stem, Hin=Hout, Win=Wout,
+      Hf=3, Wf=3, strideh=2, stridew=2, padh=1, padw=1)
+
+  # residual layer 1: 64 -> 64 channels, stride 1
+  [out, Hout, Wout, layer1_emas_upd] = basic_reslayer_forward(X=out,
+      Hin=Hout, Win=Wout, blocks=2, strideh=1, stridew=1, C_in=C_stem,
+      C_base=C_stem, blocks_weights=layer1_weights, mode=mode,
+      ema_means_vars=layer1_emas)
+  # residual layer 2: 64 -> 128 channels, stride 2
+  [out, Hout, Wout, layer2_emas_upd] = basic_reslayer_forward(X=out,
+      Hin=Hout, Win=Wout, blocks=2, strideh=2, stridew=2, C_in=C_stem,
+      C_base=C_l2, blocks_weights=layer2_weights, mode=mode,
+      ema_means_vars=layer2_emas)
+  # residual layer 3: 128 -> 256 channels, stride 2
+  [out, Hout, Wout, layer3_emas_upd] = basic_reslayer_forward(X=out,
+      Hin=Hout, Win=Wout, blocks=2, strideh=2, stridew=2, C_in=C_l2,
+      C_base=C_l3, blocks_weights=layer3_weights, mode=mode,
+      ema_means_vars=layer3_emas)
+  # residual layer 4: 256 -> 512 channels, stride 2
+  [out, Hout, Wout, layer4_emas_upd] = basic_reslayer_forward(X=out,
+      Hin=Hout, Win=Wout, blocks=2, strideh=2, stridew=2, C_in=C_l3,
+      C_base=C_l4, blocks_weights=layer4_weights, mode=mode,
+      ema_means_vars=layer4_emas)
+
+  # Global Average Pooling
+  [out, Hout, Wout] = ap2d::forward(X=out, C=C_l4, Hin=Hout, Win=Wout)
+  # Affine
+  out = fc::forward(X=out, W=W_fc, b=b_fc)
+
+  # updated EMAs in the same order as the input EMA list
+  ema_means_vars_upd = list(ema_mean_stem_upd, ema_var_stem_upd,
+      layer1_emas_upd, layer2_emas_upd, layer3_emas_upd, layer4_emas_upd)
+}
diff --git
a/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
b/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
index 19fd805bf6..a41fc235f5 100644
--- a/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
@@ -113,6 +113,11 @@ public class NNComponentTest extends TestFolder {
run("logcosh.dml");
}
+ /** Runs the ResNet component test script (resolved by run() to "component/resnet.dml"). */
+ @Test
+ public void resnet() {
+ run("resnet.dml");
+ }
+
@Override
protected void run(String name) {
super.run("component/" + name);
diff --git a/src/test/scripts/applications/nn/component/resnet.dml
b/src/test/scripts/applications/nn/component/resnet.dml
new file mode 100644
index 0000000000..9ac30db3c8
--- /dev/null
+++ b/src/test/scripts/applications/nn/component/resnet.dml
@@ -0,0 +1,629 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("scripts/nn/networks/resnet.dml") as resnet
+source("src/test/scripts/applications/nn/util.dml") as test_util
+
+
+basic_test_basic_block_forward = function(int C_in, int C_base, int strideh,
+    int stridew, int Hin, int Win, int N, Boolean downsample) {
+  /*
+   * Basic input tests for basic_block_forward() function:
+   * runs the block on dummy data and checks the output dimensions.
+   *
+   * Inputs:
+   *  - C_in: Number of input channels.
+   *  - C_base: Number of base channels of the block.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - N: Number of examples.
+   *  - downsample: If TRUE, the weights and EMAs of the downsample
+   *      conv and batch norm are appended to the respective lists.
+   *
+   * Calls test_util::fail() if Hout, Wout or ncol(out) differ from
+   * Hin/strideh, Win/stridew and C_base*Hout*Wout, respectively.
+   */
+  X = matrix(0, rows=N, cols=C_in*Hin*Win)
+  W_conv1 = matrix(1, rows=C_base, cols=C_in*3*3)
+  gamma_bn1 = matrix(1, rows=C_base, cols=1)
+  beta_bn1 = matrix(0, rows=C_base, cols=1)
+  W_conv2 = matrix(0, rows=C_base, cols=C_base*3*3)
+  gamma_bn2 = matrix(1, rows=C_base, cols=1)
+  beta_bn2 = matrix(0, rows=C_base, cols=1)
+  weights = list(W_conv1, gamma_bn1, beta_bn1, W_conv2, gamma_bn2, beta_bn2)
+  if (downsample) {
+    # parameters of the 1x1 downsample conv and its batch norm
+    W_conv3 = matrix(0, rows=C_base, cols=C_in*1*1)
+    gamma_bn3 = matrix(1, rows=C_base, cols=1)
+    beta_bn3 = matrix(0, rows=C_base, cols=1)
+    weights = append(weights, W_conv3)
+    weights = append(weights, gamma_bn3)
+    weights = append(weights, beta_bn3)
+  }
+
+  mode = "train"
+  ema_mean_bn1 = matrix(0, rows=C_base, cols=1)
+  ema_var_bn1 = matrix(1, rows=C_base, cols=1)
+  ema_mean_bn2 = matrix(0, rows=C_base, cols=1)
+  ema_var_bn2 = matrix(1, rows=C_base, cols=1)
+  ema_means_vars = list(ema_mean_bn1, ema_var_bn1, ema_mean_bn2, ema_var_bn2)
+  if (downsample) {
+    ema_mean_bn3 = matrix(0, rows=C_base, cols=1)
+    ema_var_bn3 = matrix(1, rows=C_base, cols=1)
+    ema_means_vars = append(ema_means_vars, ema_mean_bn3)
+    ema_means_vars = append(ema_means_vars, ema_var_bn3)
+  }
+
+  [out, Hout, Wout, ema_means_vars_up] = resnet::basic_block_forward(X,
+      weights, C_in, C_base, Hin, Win, strideh, stridew, mode,
+      ema_means_vars)
+
+  # the expected width depends on the width stride, not the height stride
+  Hout_expected = Hin / strideh
+  Wout_expected = Win / stridew
+  out_cols_expected = C_base * Hout_expected * Wout_expected
+  if (Hout_expected != Hout | Wout_expected != Wout |
+      out_cols_expected != ncol(out)) {
+    test_util::fail("Output shapes of basic_block_forward() are not as expected!")
+  }
+}
+
+
+values_test_basic_block_forward_1 = function() {
+ /*
+ * Testing of values for forward pass of basic block against PyTorch.
+ *
+ * Configuration: stride 1 and C_in == C_base, so only the six
+ * main-path weights and four EMAs are passed (no downsample entries).
+ * The input is all ones; weights and expected outputs are hard-coded.
+ */
+ strideh = 1; stridew = 1
+ C_in = 2; C_base = 2
+ Hin = 4; Win = 4
+ N = 2
+ X = matrix(1, rows=N, cols=C_in*Hin*Win)
+ # hard-coded weights of the first 3x3 conv
+ W_conv1 = matrix("-0.13904892 0.12838013 0.08027278 -0.06143695
0.07755502 -0.16483936 0.06582125 0.00754158 0.10763083 0.04604699
0.03576668 -0.07599333 -0.06836346 0.19890614 0.01955454 -0.02767003
0.21198983 0.12785362
+ 0.04019578 -0.14636862 -0.02285126 -0.00971214
0.12590824 -0.06414033 -0.1034085 -0.23452668 -0.0999288 0.12418596
-0.03290591 0.02420332 0.17950852 0.00047226 0.13068716 -0.00955899
0.03092374 0.05555834",
+ rows=C_base, cols=C_in*3*3)
+ gamma_bn1 = matrix(1, rows=C_base, cols=1)
+ beta_bn1 = matrix(0, rows=C_base, cols=1)
+ # hard-coded weights of the second 3x3 conv
+ W_conv2 = matrix(" 0.10092591 0.15790914 -0.17673795 0.10573213
0.13680543 -0.22161855 -0.10239416 0.10747905 0.03636803 -0.00693908
-0.19976966 -0.15770042 -0.23468268 0.1040463 0.08357517 0.0780759
0.21764557 -0.11318331
+ 0.1958775 0.00366694 0.05713235 -0.0768708
-0.04275537 -0.23076743 -0.029018 -0.02308315 -0.05915356 -0.12383241
0.16292028 0.20669906 -0.19045494 0.10580237 0.21305619 0.19072767
-0.19292024 0.15425198",
+ rows=C_base, cols=C_base*3*3)
+ gamma_bn2 = matrix(1, rows=C_base, cols=1)
+ beta_bn2 = matrix(0, rows=C_base, cols=1)
+ weights = list(W_conv1, gamma_bn1, beta_bn1, W_conv2, gamma_bn2, beta_bn2)
+ mode = "train"
+
+ # all EMAs start at zero
+ ema_mean_bn1 = matrix(0, rows=C_base, cols=1)
+ ema_var_bn1 = matrix(0, rows=C_base, cols=1)
+ ema_mean_bn2 = matrix(0, rows=C_base, cols=1)
+ ema_var_bn2 = matrix(0, rows=C_base, cols=1)
+ ema_means_vars = list(ema_mean_bn1, ema_var_bn1, ema_mean_bn2, ema_var_bn2)
+
+ [out, Hout, Wout, ema_means_vars_up] = resnet::basic_block_forward(X,
weights, C_in, C_base, Hin, Win,
+ strideh,
stridew, mode, ema_means_vars)
+
+ # hard-coded expected output; both rows are equal as both input rows are equal
+ out_expected = matrix("1.1192019 0.13680267 0.3965159 0.09488004
1.5221035 0.60268176 0.870202 0.3568437 0.19053036 2.023053 2.810772
2.7142973 1.4163418 2.2421117 0.22204155 0. 0.31268513 0.14521259
0.57843214 0.65275586 0. 0.02211368 0.4826215 0.65296173 1.2726448
0.6964331 1.6637247 1.2155424 2.3015604 3.9708042 1.6967016 0.40145046
+ 1.1192019 0.13680267 0.3965159 0.09488004
1.5221035 0.60268176 0.870202 0.3568437 0.19053036 2.023053 2.810772
2.7142973 1.4163418 2.2421117 0.22204155 0. 0.31268513 0.14521259
0.57843214 0.65275586 0. 0.02211368 0.4826215 0.65296173 1.2726448
0.6964331 1.6637247 1.2155424 2.3015604 3.9708042 1.6967016 0.40145046",
+ rows=N, cols=Hout*Wout*C_base)
+
+ # NOTE(review): third argument is presumably the comparison tolerance - see util.dml
+ test_util::check_all_close(out, out_expected, 0.00001)
+}
+
+
+values_test_basic_block_forward_2 = function() {
+ /*
+ * Testing of values for forward pass of basic block against PyTorch.
+ *
+ * Configuration: stride 2 with C_in == C_base, so the weights and
+ * EMAs of the downsample conv and batch norm are passed as well
+ * (nine weights, six EMAs). The input is all ones; weights and
+ * expected outputs are hard-coded.
+ */
+ strideh = 2; stridew = 2
+ C_in = 2; C_base = 2
+ Hin = 4; Win = 4
+ N = 2
+ X = matrix(1, rows=N, cols=C_in*Hin*Win)
+ # hard-coded weights of the first 3x3 conv
+ W_conv1 = matrix("-0.14026615 -0.06974511 0.21421503 0.00487083
-0.17600328 -0.05576494 0.08433063 -0.04809754 -0.0021321 -0.1935787
-0.04766957 0.15073563 0.14598249 -0.1946578 -0.01819092 -0.11103764
-0.01316494 -0.14351277
+ -0.0036971 -0.18704589 -0.09860466 0.20417325
-0.20006022 0.00959031 0.13883735 -0.11765242 -0.17820978 -0.03428984
-0.02357996 0.11326601 -0.22515622 0.2001556 -0.0103206 -0.0384565
0.13819869 -0.03230184",
+ rows=C_base, cols=C_in*3*3)
+ gamma_bn1 = matrix(1, rows=C_base, cols=1)
+ beta_bn1 = matrix(0, rows=C_base, cols=1)
+ # hard-coded weights of the second 3x3 conv
+ W_conv2 = matrix(" 0.1952378 -0.13218941 0.20359151 0.21437167
0.20657437 0.07917522 -0.20072569 -0.16550082 0.14789648 0.03155191
0.10938872 -0.18765432 0.2069266 -0.0324703 0.14553984 -0.15199026
-0.01177226 0.05884366
+ -0.16591048 -0.11745082 0.11246873 0.21435808
0.000237 -0.02330555 0.03408287 -0.09445126 0.09905426 -0.022421
-0.01720028 -0.08738072 -0.13018131 0.2277623 -0.22259445 0.06712689
-0.08571149 0.17849205",
+ rows=C_base, cols=C_base*3*3)
+ gamma_bn2 = matrix(1, rows=C_base, cols=1)
+ beta_bn2 = matrix(0, rows=C_base, cols=1)
+
+ # downsampling weights
+ W_conv3 = matrix("-0.44065592 -0.29570574
+ -0.60239863 0.43495506",
+ rows=C_base, cols=C_in*1*1)
+ gamma_bn3 = matrix(1, rows=C_base, cols=1)
+ beta_bn3 = matrix(0, rows=C_base, cols=1)
+
+ weights = list(W_conv1, gamma_bn1, beta_bn1, W_conv2, gamma_bn2, beta_bn2,
W_conv3, gamma_bn3, beta_bn3)
+ mode = "train"
+
+ # all EMAs start at zero
+ ema_mean_bn1 = matrix(0, rows=C_base, cols=1)
+ ema_var_bn1 = matrix(0, rows=C_base, cols=1)
+ ema_mean_bn2 = matrix(0, rows=C_base, cols=1)
+ ema_var_bn2 = matrix(0, rows=C_base, cols=1)
+ ema_mean_bn3 = matrix(0, rows=C_base, cols=1)
+ ema_var_bn3 = matrix(0, rows=C_base, cols=1)
+ ema_means_vars = list(ema_mean_bn1, ema_var_bn1, ema_mean_bn2,
ema_var_bn2, ema_mean_bn3, ema_var_bn3)
+
+ [out, Hout, Wout, ema_means_vars_up] = resnet::basic_block_forward(X,
weights, C_in, C_base, Hin, Win,
+ strideh,
stridew, mode, ema_means_vars)
+
+ # hard-coded expected output; both rows are equal as both input rows are equal
+ out_expected = matrix("0. 0. 0.33147347 1.4695541 0.
0.9726007 0. 0.9382379
+ 0. 0. 0.33147347 1.4695541 0.
0.9726007 0. 0.9382379 ",
+ rows=N, cols=Hout*Wout*C_base)
+
+ # NOTE(review): third argument is presumably the comparison tolerance - see util.dml
+ test_util::check_all_close(out, out_expected, 0.0001)
+}
+
+
+values_test_basic_block_forward_3 = function() {
+ /*
+ * Testing of values for forward pass of basic block against PyTorch.
+ *
+ * Configuration: stride 2 with C_in != C_base (2 -> 4 channels), so
+ * the weights and EMAs of the downsample conv and batch norm are
+ * passed as well (nine weights, six EMAs). The input is all ones;
+ * weights and expected outputs are hard-coded.
+ */
+ strideh = 2; stridew = 2
+ C_in = 2; C_base = 4
+ Hin = 4; Win = 4
+ N = 2
+ X = matrix(1, rows=N, cols=C_in*Hin*Win)
+ # hard-coded weights of the first 3x3 conv
+ W_conv1 = matrix(" 0.23060845 0.01255383 0.10554366 -0.11032246
0.04110865 0.12300454 0.03407042 0.03677405 0.1022801 0.08667193
0.15104999 -0.08343385 -0.16137402 0.2004693 -0.15173802 -0.14732602
-0.16977917 -0.11108103
+ -0.02234201 -0.13497168 -0.15396744 -0.11581142
0.19164546 -0.02277191 0.1283987 -0.06767575 -0.05453977 -0.13944787
-0.16732863 0.08283444 -0.20333174 0.15993242 0.09864204 0.02714513
-0.05416261 -0.16831529
+ -0.02864511 0.17540906 0.10217993 0.16238032
-0.09703699 0.13528688 -0.12437965 0.22771217 0.13020585 0.06029092
0.03008728 0.08683048 0.18321039 -0.0570219 -0.04176761 -0.10389957
-0.21008791 0.21670572
+ -0.17345327 0.05192728 -0.137008 -0.16411738
0.2257642 0.1580276 -0.17054762 -0.15276143 0.18090443 -0.00081959
-0.20020181 -0.15055852 -0.08635178 0.11620347 -0.17689864 -0.17850053
0.14149593 0.05391319",
+ rows=C_base, cols=C_in*3*3)
+ gamma_bn1 = matrix(1, rows=C_base, cols=1)
+ beta_bn1 = matrix(0, rows=C_base, cols=1)
+ # hard-coded weights of the second 3x3 conv
+ W_conv2 = matrix("-0.12502055 0.04218115 -0.01587057 0.02996665
0.03902064 -0.10042268 0.14614715 0.04653394 -0.16000232 0.16580684
-0.00197132 -0.12345098 0.16189565 -0.08695424 -0.02277309 -0.02887423
-0.06450109 -0.02360182 -0.11778401 -0.0124789 -0.14320037 -0.1436774
-0.00914766 -0.01130253 0.13241099 0.03841829 0.14280184 0.07521738
0.11815999 -0.13276237 0.12992252 -0.04222127 -0.03073458 0.01725562
0.15965255 0.10338821
+ -0.1398921 0.05605571 -0.13995157 -0.12097194
-0.1035642 0.03313832 0.13476823 0.02719207 0.02726786 -0.08288203
0.10799147 0.03092675 -0.01539116 -0.08172278 -0.05077231 0.03913508
0.02528121 0.08431648 0.16408543 -0.09090649 -0.09221806 0.00649713
0.08248532 0.05170746 -0.03424133 0.12494816 -0.03637959 0.01817816
0.10356762 0.03744942 -0.10864812 0.10180093 -0.04949838 -0.10033202
-0.10501622 -0.05735092
+ -0.05820473 -0.11734504 0.16419913 -0.05231454
-0.07497393 0.1414146 -0.07572757 0.12433673 0.11995722 -0.08965874
0.01813734 0.07857008 -0.01808423 -0.10376819 0.02973495 -0.06675623
0.12945338 0.11593701 -0.0270998 0.06052397 -0.09865837 0.05997723
-0.09147376 -0.13572678 0.04277535 0.06700458 0.10086431 0.0624107
0.01380444 0.02379382 -0.06924826 0.16204901 0.09410296 0.08837719
0.08246924 -0.02000479
+ -0.05427806 -0.06499916 -0.07410125 -0.09293389
-0.05310518 -0.14569238 0.0565486 0.06677081 0.11938848 0.140608
0.08632036 0.06824701 0.08058478 0.1592765 0.05660491 0.000421
0.08094937 0.02867652 0.00291246 -0.10608425 0.03485543 -0.1534006
-0.05454665 -0.12841377 0.15747286 0.15352447 0.00212862 -0.06103357
-0.05826634 -0.12362739 0.05188899 0.0780222 -0.08359687 -0.07310607
-0.04413005 0.16476734",
+ rows=C_base, cols=C_base*3*3)
+ gamma_bn2 = matrix(1, rows=C_base, cols=1)
+ beta_bn2 = matrix(0, rows=C_base, cols=1)
+
+ # downsampling weights
+ W_conv3 = matrix("-0.434205 -0.51763165
+ -0.48000953 -0.4031318
+ 0.38555235 0.09546977
+ -0.35429037 -0.35176745",
+ rows=C_base, cols=C_in*1*1)
+ gamma_bn3 = matrix(1, rows=C_base, cols=1)
+ beta_bn3 = matrix(0, rows=C_base, cols=1)
+
+ weights = list(W_conv1, gamma_bn1, beta_bn1, W_conv2, gamma_bn2, beta_bn2,
W_conv3, gamma_bn3, beta_bn3)
+ mode = "train"
+
+ # all EMAs start at zero
+ ema_mean_bn1 = matrix(0, rows=C_base, cols=1)
+ ema_var_bn1 = matrix(0, rows=C_base, cols=1)
+ ema_mean_bn2 = matrix(0, rows=C_base, cols=1)
+ ema_var_bn2 = matrix(0, rows=C_base, cols=1)
+ ema_mean_bn3 = matrix(0, rows=C_base, cols=1)
+ ema_var_bn3 = matrix(0, rows=C_base, cols=1)
+ ema_means_vars = list(ema_mean_bn1, ema_var_bn1, ema_mean_bn2,
ema_var_bn2, ema_mean_bn3, ema_var_bn3)
+
+ [out, Hout, Wout, ema_means_vars_up] = resnet::basic_block_forward(X,
weights, C_in, C_base, Hin, Win,
+ strideh,
stridew, mode, ema_means_vars)
+
+ # hard-coded expected output; both rows are equal as both input rows are equal
+ out_expected = matrix("0. 1.576729 0. 0.14989257 0.
0. 1.5978024 0. 1.6740767 0. 0. 0.
0.89626473 1.0869961 0. 0.
+ 0. 1.576729 0. 0.14989257 0.
0. 1.5978024 0. 1.6740767 0. 0. 0.
0.89626473 1.0869961 0. 0. ",
+ rows=N, cols=Hout*Wout*C_base)
+
+ # NOTE(review): third argument is presumably the comparison tolerance - see util.dml
+ test_util::check_all_close(out, out_expected, 0.0001)
+}
+
+values_test_residual_layer_forward = function() {
+ Hin = 4; Win = 4;
+ C_in = 2; C_base = 4;
+ N = 2
+ blocks = 2
+ strideh = 2; stridew = 2
+ mode = "train"
+
+ X = matrix(seq(1, N*Hin*Win*C_in), rows=N, cols=C_in*Hin*Win)
+
+ # weights for block 1
+ W_conv1 = matrix(" 0.18088163 -0.1094396 0.16322954
+ -0.16315436 -0.21879527 -0.13885644
+ -0.17897591 -0.03466086 0.00447898
+
+ -0.02060366 0.07391576 -0.13932472
+ 0.19118743 -0.1424667 0.13759573
+ 0.01909433 -0.09354833 -0.04176301
+
+
+ 0.1272765 0.03356279 0.09614499
+ -0.19478372 0.00336599 -0.11832798
+ -0.167036 -0.21355109 0.09034546
+
+ 0.08218841 0.09808768 -0.13677552
+ 0.09094055 -0.2075183 -0.11292712
+ 0.01358929 0.06142236 0.15713598
+
+
+ -0.07507452 0.17012559 -0.00778644
+ -0.17894624 -0.16280204 -0.20480031
+ 0.17800058 -0.08244179 0.21474986
+
+ 0.13779633 -0.09157459 -0.16890329
+ -0.18981671 0.01013687 0.19278602
+ 0.12158521 0.14861913 0.16235258
+
+
+ -0.06423138 -0.08338779 0.03925912
+ 0.08258285 -0.10299681 -0.02878706
+ 0.09201117 0.1826442 0.04269032
+
+ 0.23345782 -0.19535708 -0.04301685
+ -0.20099604 0.13898961 -0.13627604
+ -0.02552247 0.19012617 0.11112581", rows=C_base,
cols=C_in*3*3)
+ gamma_bn1 = matrix(1, rows=C_base, cols=1)
+ beta_bn1 = matrix(0, rows=C_base, cols=1)
+ W_conv2 = matrix("-0.05845656 -0.10387342 -0.00347954
+ -0.10443704 0.07762128 0.08962621
+ -0.15254496 -0.11626967 -0.07093517
+
+ 0.04007399 0.08974029 -0.08184759
+ -0.03026156 -0.12850255 0.11625351
+ -0.06768056 0.04290299 -0.11035781
+
+ -0.08182923 -0.07303425 -0.14692405
+ 0.07654093 0.01254661 -0.08789502
+ 0.12506868 -0.03339644 0.09265955
+
+ -0.0010149 0.06057359 0.13991983
+ 0.07688661 0.0267771 0.0457534
+ -0.12409463 0.16171788 0.04298592
+
+
+ 0.04832537 -0.03725861 -0.06689632
+ -0.06220818 0.14289887 -0.15294018
+ 0.00162084 0.01681285 -0.16070186
+
+ -0.06003825 0.04543029 -0.10140823
+ 0.12588196 0.0608999 0.12295659
+ -0.06011419 -0.08439511 0.14588566
+
+ -0.06961578 0.03729546 0.10786648
+ -0.04095362 -0.02719462 0.01085408
+ 0.03948607 0.08386 -0.03532495
+
+ 0.00769222 0.04593278 -0.09789048
+ -0.03487056 0.05989157 -0.09724655
+ -0.15379873 -0.07006194 0.14572401
+
+
+ -0.15647712 0.08853446 -0.0191143
+ -0.15235935 0.09655853 0.02425753
+ -0.03920817 -0.13884489 -0.15873407
+
+ -0.16632351 -0.15446958 -0.02832182
+ 0.15606628 -0.07041912 0.09881757
+ -0.07226615 0.07085086 -0.03745939
+
+ -0.0293621 -0.14463028 -0.15513223
+ 0.10102965 -0.11223143 0.16495053
+ 0.00031853 0.158157 -0.13941693
+
+ 0.10985805 0.11699991 -0.06803058
+ -0.08846518 0.13454668 -0.07047473
+ 0.0816289 -0.03807278 -0.01125084
+
+
+ -0.1425793 0.04159175 0.0873092
+ -0.07806729 0.06501202 0.09965105
+ 0.06275028 0.07400332 0.09444918
+
+ -0.08856728 -0.09136113 -0.07333919
+ 0.04255192 -0.10251606 0.15050472
+ 0.07791793 0.10539879 -0.06219628
+
+ 0.12434496 0.16624264 0.08779152
+ 0.00117178 0.13169001 -0.04333049
+ -0.07304269 0.14722325 -0.06679092
+
+ -0.08179037 0.15500171 -0.00718816
+ -0.1278541 -0.08474605 -0.10129128
+ 0.07862437 -0.06843086 -0.04310509", rows=C_base,
cols=C_base*3*3)
+ gamma_bn2 = matrix(1, rows=C_base, cols=1)
+ beta_bn2 = matrix(0, rows=C_base, cols=1)
+ # downsample parameters
+ W_conv3 = matrix("-0.41760927
+
+ -0.0473721
+
+
+ -0.14902914
+
+ -0.3806823
+
+
+ 0.04111391
+
+ 0.6815012
+
+
+ 0.16695487
+
+ 0.5603499 ", rows=C_base, cols=C_in*1*1)
+ gamma_bn3 = matrix(1, rows=C_base, cols=1)
+ beta_bn3 = matrix(0, rows=C_base, cols=1)
+ block1_weights = list(W_conv1, gamma_bn1, beta_bn1, W_conv2, gamma_bn2,
beta_bn2, W_conv3, gamma_bn3, beta_bn3)
+
+ # EMAS for block 1
+ ema_mean_bn1 = matrix(1, rows=C_base, cols=1)
+ ema_var_bn1 = matrix(0, rows=C_base, cols=1)
+ ema_mean_bn2 = matrix(1, rows=C_base, cols=1)
+ ema_var_bn2 = matrix(0, rows=C_base, cols=1)
+ ema_mean_bn3 = matrix(1, rows=C_base, cols=1)
+ ema_var_bn3 = matrix(0, rows=C_base, cols=1)
+ block1_EMAs = list(ema_mean_bn1, ema_var_bn1, ema_mean_bn2, ema_var_bn2,
ema_mean_bn3, ema_var_bn3)
+
+ # Weights for block 2
+ W_conv1_2 = matrix("-7.28789419e-02 7.32977241e-02 -1.16737187e-01
+ -1.09122857e-01 1.51534528e-02 -1.55087024e-01
+ 1.54908732e-01 -1.46781862e-01 5.23662418e-02
+
+ -1.45408154e-01 9.34596211e-02 -1.57577261e-01
+ -1.62250042e-01 -2.06526369e-02 1.58289358e-01
+ -1.38804317e-04 -1.37716874e-01 -1.25059336e-01
+
+ 5.44795990e-02 -1.56691819e-02 4.58848923e-02
+ -1.66475177e-01 -8.56291652e-02 9.89388674e-02
+ -3.30540538e-02 -4.98571396e-02 -1.48154795e-03
+
+ -1.63741872e-01 7.52360672e-02 -7.26198778e-02
+ 7.97830820e-02 -1.53044373e-01 1.05956629e-01
+ 1.42732725e-01 -4.19619307e-02 3.91238928e-03
+
+
+ -3.11080813e-02 6.24701530e-02 -8.18871707e-02
+ -2.80226916e-02 5.65987229e-02 -5.63029051e-02
+ 9.20275897e-02 -1.50979385e-01 3.14275920e-03
+
+ 4.63068485e-02 -7.43940473e-02 -1.23582803e-01
+ -4.84378934e-02 -1.62422940e-01 2.17949301e-02
+ 1.48192182e-01 -9.01084542e-02 9.67378765e-02
+
+ 1.82208419e-02 1.48985460e-01 -1.47735506e-01
+ 8.09304416e-03 1.12461001e-02 2.95447111e-02
+ -1.24866471e-01 3.81960124e-02 -6.68919683e-02
+
+ 5.13156503e-02 6.83855265e-02 -4.90674824e-02
+ -9.68660563e-02 3.40797305e-02 -1.13457203e-01
+ -2.67352313e-02 -8.27323049e-02 -7.86665529e-02
+
+
+ -1.18089557e-01 1.66068569e-01 4.50132638e-02
+ 4.65527624e-02 1.10370263e-01 1.24886349e-01
+ -9.42516923e-02 -1.62573516e-01 7.66497254e-02
+
+ 1.08407423e-01 -4.26756591e-02 -1.11639105e-01
+ -8.21658969e-02 2.47098356e-02 -7.98595399e-02
+ -1.54958516e-02 2.23536491e-02 -1.03785992e-02
+
+ -1.53956562e-02 1.13292173e-01 -2.27067471e-02
+ -2.12994069e-02 8.41291696e-02 1.61149070e-01
+ -1.32289156e-01 -7.05852732e-02 7.90221095e-02
+
+ -1.43424913e-01 -1.21421874e-01 -1.27822340e-01
+ -2.88078189e-02 -5.81898317e-02 -1.99964195e-02
+ 1.16435274e-01 -1.30379200e-03 -4.03594971e-02
+
+
+ -1.00988328e-01 6.64077997e-02 -1.31890640e-01
+ -1.35123342e-01 -1.37298390e-01 8.09081197e-02
+ 3.08579355e-02 7.35761523e-02 1.45316467e-01
+
+ 9.22436267e-02 3.85234505e-02 3.37007642e-02
+ 4.96874899e-02 2.56118029e-02 -2.25261599e-02
+ 6.15442246e-02 7.31151104e-02 1.47016749e-01
+
+ -7.32975453e-02 -6.77637309e-02 2.41905302e-02
+ -1.47778392e-01 -1.27045453e-01 -3.37420404e-03
+ -1.08250901e-01 -8.57535824e-02 -7.87658989e-03
+
+ -3.25922221e-02 -8.58756527e-02 7.44262338e-02
+ -1.37389064e-01 7.38748461e-02 -7.68387318e-02
+ 1.25040159e-01 -6.90028891e-02 -5.00665307e-02",
rows=C_base, cols=C_base*3*3)
+ gamma_bn1_2 = matrix(1, rows=C_base, cols=1)
+ beta_bn1_2 = matrix(0, rows=C_base, cols=1)
+ W_conv2_2 = matrix("-0.02658759 0.02687277 -0.11249679
+ 0.07110029 -0.1287723 0.14960568
+ 0.00347126 0.082368 0.1592095
+
+ 0.01121612 -0.01054858 -0.11533582
+ 0.00191922 -0.09345891 0.0201468
+ 0.16406216 0.1631505 -0.12568823
+
+ 0.10923527 0.0047278 0.08264925
+ 0.03556542 -0.11967081 -0.03904144
+ 0.06578751 -0.11364846 -0.06719196
+
+ -0.02243698 -0.15126523 -0.16504839
+ 0.03257662 0.03516276 -0.00604182
+ 0.00565775 0.15348013 -0.08239889
+
+
+ -0.10838002 0.03725785 -0.01738496
+ 0.01061849 0.00593401 -0.14623034
+ -0.04515064 0.11624385 0.07984154
+
+ 0.02820455 0.14632984 0.11314054
+ 0.1556999 0.06233911 0.04464059
+ 0.16220112 0.12387766 -0.01516332
+
+ 0.11520021 0.03099509 0.11585347
+ 0.0215022 0.09577711 0.12590684
+ 0.12393482 -0.00796187 -0.1233204
+
+ -0.01217443 -0.09484772 -0.13615403
+ -0.06195782 -0.10316825 -0.13738668
+ 0.04165971 -0.16430686 0.1249779
+
+
+ -0.06067413 0.09290363 -0.10419172
+ 0.04424816 0.16639026 -0.01638204
+ -0.01993459 0.16510008 -0.03844319
+
+ -0.06738343 -0.15954465 0.14164312
+ -0.09711097 0.04057109 -0.06419432
+ -0.15190187 0.02492356 0.14873762
+
+ 0.05357671 -0.02110486 -0.07781315
+ -0.12230659 0.13541014 0.04158884
+ 0.13525964 0.07432733 0.04886186
+
+ -0.04131328 0.05893086 0.08948417
+ 0.15411825 0.05368501 -0.13857502
+ 0.15523924 -0.1510986 0.01617928
+
+
+ 0.14878072 0.15607525 -0.12842798
+ 0.00907773 0.09931238 0.03955895
+ 0.04165536 -0.0382842 0.06571688
+
+ -0.0926128 -0.15800306 -0.0235941
+ 0.03582941 -0.13953064 0.03686035
+ -0.15508795 0.06028162 -0.15286762
+
+ 0.00642897 -0.01605938 -0.10140433
+ 0.09824272 0.13854371 0.13406266
+ 0.13023908 -0.10159403 -0.08961493
+
+ -0.0350284 0.08208416 0.11221837
+ -0.07019123 0.00895458 0.10123546
+ -0.04459848 0.15377314 0.04990514", rows=C_base,
cols=C_base*3*3)
+ gamma_bn2_2 = matrix(1, rows=C_base, cols=1)
+ beta_bn2_2 = matrix(0, rows=C_base, cols=1)
+ block2_weights = list(W_conv1_2, gamma_bn1_2, beta_bn1_2, W_conv2_2,
gamma_bn2_2, beta_bn2_2)
+
+ # EMAS for block 2
+ ema_mean_bn1_2 = matrix(1, rows=C_base, cols=1)
+ ema_var_bn1_2 = matrix(0, rows=C_base, cols=1)
+ ema_mean_bn2_2 = matrix(1, rows=C_base, cols=1)
+ ema_var_bn2_2 = matrix(0, rows=C_base, cols=1)
+ block2_EMAs = list(ema_mean_bn1_2, ema_var_bn1_2, ema_mean_bn2_2,
ema_var_bn2_2)
+
+ expected_Hout = 2
+ expected_Wout = 2
+ expected_out = matrix("1.9154322 1.3386528
+ 0. 0.17239538
+
+ 1.1703718 1.765771
+ 0. 0.48162544
+
+ 0. 0.
+ 0.5448834 0.
+
+ 0.5596865 0.
+ 1.7399819 0.
+
+
+ 0. 2.4118602
+ 0. 0.32017422
+
+ 0.45649305 1.0522902
+ 1.0418018 1.2802167
+
+ 3.914177 0.7292799
+ 0. 1.287521
+
+ 0. 0.71220785
+ 2.7913036 0.6481694 ", rows=2,
cols=C_base*expected_Hout*expected_Wout)
+
+ blocks_weights = list(block1_weights, block2_weights)
+ ema_means_vars = list(block1_EMAs, block2_EMAs)
+ [out, Hout, Wout, ema_means_vars_upd] = resnet::basic_reslayer_forward(X,
Hin, Win, blocks,
+ strideh,
stridew, C_in, C_base,
+ blocks_weights,
mode, ema_means_vars)
+
+ test_util::check_all_close(out, expected_out, 0.0001)
+}
+
+
+/*
+ * **** Basic Block Shape Handling Testing ****
+ */
+
+/*
+ * Test case 1:
+ * Basic block forward computation shouldn't raise errors
+ * when given valid inputs with valid shapes without having
+ * to downsample the input, and the output should have the
+ * expected shape.
+ */
+basic_test_basic_block_forward(C_in=4, C_base=4, strideh=1, stridew=1, Hin=4,
Win=4, N=3, downsample=FALSE)
+
+/*
+ * Test case 2:
+ * Basic block forward computation shouldn't raise errors
+ * when given valid inputs that have to be downsampled
+ * because of non-matching channels, and the output should
+ * have the expected (downsampled) shapes.
+ */
+basic_test_basic_block_forward(C_in=2, C_base=4, strideh=1, stridew=1, Hin=4,
Win=4, N=3, downsample=TRUE)
+
+/*
+ * Test case 3:
+ * Basic block forward computation shouldn't raise errors
+ * when given valid inputs that have to be downsampled
+ * because of a stride bigger than 1, and the output should
+ * have the expected (downsampled) shapes.
+ */
+basic_test_basic_block_forward(C_in=4, C_base=4, strideh=2, stridew=2, Hin=4,
Win=4, N=3, downsample=TRUE)
+
+/*
+ * **** Basic Block Value Testing ****
+ * In these test cases, we compare the forward pass
+ * computation of a basic residual block against the
+ * PyTorch implementation. We calculate the PyTorch
+ * values with the NN module
+ * torchvision.models.resnet.BasicBlock and then
+ * hard-code the randomly initialized weights and
+ * biases and the expected output computed by PyTorch
+ * into this file.
+ */
+
+/*
+ * Test case 1:
+ * A simple forward pass of basic block with same
+ * input and output channels and the same input and
+ * output dimensions, i.e. stride 1, in train mode.
+ */
+values_test_basic_block_forward_1()
+
+/*
+ * Test case 2:
+ * A simple forward pass of basic block with
+ * downsampling through a stride of 2 but same number
+ * of channels in train mode.
+ */
+values_test_basic_block_forward_2()
+
+/*
+ * Test case 3:
+ * A simple forward pass of basic block with
+ * downsampling through non-matching channel sizes
+ * and different number of channels in train mode.
+ */
+values_test_basic_block_forward_3()
+
+/*
+ * *** Residual Layer Value Testing ***
+ * A residual layer is a sequence of residual blocks
+ * which all have the same number of base channels. In
+ * residual networks, there are 4 different residual layers.
+ * With this test, we test the correct computation of the
+ * shape and values of the output by comparing it to PyTorch's
+ * residual layers. We modified the PyTorch implementation to
+ * extract the residual layer.
+ */
+
+/*
+ * Test case 1:
+ * A residual layer forward pass with downsampling through
+ * a stride of 2 and non-matching input and base channels
+ * of 2 and 4.
+ */
+values_test_residual_layer_forward()
diff --git a/src/test/scripts/applications/nn/util.dml
b/src/test/scripts/applications/nn/util.dml
index e32a885e64..a983928dee 100644
--- a/src/test/scripts/applications/nn/util.dml
+++ b/src/test/scripts/applications/nn/util.dml
@@ -153,3 +153,47 @@ check_rel_grad_error = function(double dw_a, double dw_n,
double lossph, double
}
}
+fail = function(string message) {
+  /*
+   * Report a failure by printing the message with an "ERROR: " prefix.
+   *
+   * Inputs:
+   *  - message: Text describing the failure.
+   */
+  error_line = "ERROR: " + message
+  print(error_line)
+}
+
+all_close = function(matrix[double] X1, matrix[double] X2, double epsilon)
+    return (boolean all_pretty_close) {
+  /*
+   * Test whether every pair of corresponding entries of two matrices
+   * differs by at most epsilon in absolute value.
+   *
+   * Inputs:
+   *  - X1: Inputs, of shape (any, any).
+   *  - X2: Inputs, of same shape as X1.
+   *  - epsilon: Largest absolute elementwise difference tolerated.
+   *
+   * Outputs:
+   *  - all_pretty_close: Whether or not the values of the two
+   *      matrices are all close.
+   */
+  # 0/1 indicator per cell: 1 where the absolute difference is in range
+  abs_diff = abs(X1 - X2)
+  within_eps = abs_diff <= epsilon
+
+  # The product of the indicators is 1 only if every cell passed
+  all_pretty_close = as.boolean(prod(within_eps))
+}
+
+check_all_close = function(matrix[double] X1, matrix[double] X2,
+    double epsilon) return (boolean all_pretty_close) {
+  /*
+   * Check if all values of two matrices are within range of epsilon
+   * to another, and report any issues.
+   *
+   * Issues an "ERROR" statement (through `fail`) if elements of the
+   * two matrices are not within range epsilon to another.
+   *
+   * Inputs:
+   *  - X1: Inputs, of shape (any, any).
+   *  - X2: Inputs, of same shape as X1.
+   *  - epsilon: Largest absolute elementwise difference tolerated.
+   *
+   * Outputs:
+   *  - all_pretty_close: Whether or not the values of the two
+   *      matrices are all close.
+   */
+  # Determine if matrices are all close
+  all_pretty_close = all_close(X1, X2, epsilon)
+
+  # Report a mismatch through the shared helper, which adds the
+  # "ERROR: " prefix, so all failures are formatted in one place
+  if (!all_pretty_close) {
+    fail("The values of the two matrices are not all close.")
+  }
+}
\ No newline at end of file