Repository: incubator-systemml

Updated Branches:
  refs/heads/master a3e7e5c49 -> d04d2381f
[SYSTEMML-1483] Add Deconvolution layer in nn library.

Closes #490.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/d04d2381
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/d04d2381
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/d04d2381

Branch: refs/heads/master
Commit: d04d2381f369bc29c4c33e98381bcdc8a4d0aebb
Parents: a3e7e5c
Author: prithvirajsen <s...@us.ibm.com>
Authored: Fri May 12 13:59:20 2017 -0700
Committer: prithvirajsen <s...@us.ibm.com>
Committed: Fri May 12 14:04:15 2017 -0700

----------------------------------------------------------------------
 scripts/nn/layers/conv2d_transpose.dml | 248 ++++++++++++++++++++++++++++
 scripts/nn/test/grad_check.dml         | 109 ++++++++++++
 scripts/nn/test/run_tests.dml          |   2 +
 scripts/nn/test/test.dml               |  37 +++++
 4 files changed, 396 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d04d2381/scripts/nn/layers/conv2d_transpose.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/conv2d_transpose.dml b/scripts/nn/layers/conv2d_transpose.dml
new file mode 100644
index 0000000..ecd7a1c
--- /dev/null
+++ b/scripts/nn/layers/conv2d_transpose.dml
@@ -0,0 +1,248 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * 2D transpose convolutional layer.
+ *
+ * Utilizes built-in convolution operators for higher performance.
+ */
+
+forward = function(matrix[double] X, matrix[double] W, matrix[double] b,
+                   int C, int Hin, int Win, int Hf, int Wf,
+                   int strideh, int stridew, int padh, int padw,
+                   int out_padh, int out_padw)
+    return (matrix[double] out, int Hout, int Wout) {
+  /*
+   * Computes the forward pass for a 2D spatial transpose
+   * convolutional layer with F filters.  The input data has N
+   * examples, each represented as a 3D tensor flattened into a
+   * single vector.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - W: Weights, of shape (F, C*Hf*Wf).
+   *  - b: Biases, of shape (F, 1).
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *  - padw: Padding for left and right sides.
+   *  - out_padh: Extra padding added to the output height; should
+   *      lie in [0, strideh-1].
+   *  - out_padw: Extra padding added to the output width; should
+   *      lie in [0, stridew-1].
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, F*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   */
+  N = nrow(X)
+  F = nrow(W)
+  Hout = strideh * (Hin-1) - 2*padh + Hf + out_padh
+  Wout = stridew * (Win-1) - 2*padw + Wf + out_padw
+
+  /*
+   * Transpose convolution aims to go in the opposite direction of
+   * (direct) convolution, i.e., given input X, produce output O such
+   * that running convolution on O recovers X.  This is achieved with
+   * conv2d_backward_data (since the derivative wrt the data must
+   * produce output of the same size as the input to conv2d).  Reusing
+   * a built-in operator is efficient and keeps the number of built-in
+   * operators at a manageable level.  Moreover, most other
+   * deep-learning packages use the same strategy, so this
+   * implementation of transpose convolution stays in sync with them.
+   *
+   * One potential downside of reusing conv2d_backward_data is that it
+   * rotates the filter by 180 degrees before applying it.  This needs
+   * to be kept in mind when interpreting the output of transpose
+   * convolution.
+   */
+  out = conv2d_backward_data(W, X, stride=[strideh,stridew], padding=[padh,padw],
+                             input_shape=[N,C,Hout,Wout], filter_shape=[F,C,Hf,Wf])
+
+  out = bias_add(out, b)
+}
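+
+/*
+ * As a worked example of the output-size formula above (borrowing
+ * the setting of the equivalency test in nn/test/test.dml, added
+ * below): a 2x2 single-channel input with a 3x3 filter, stride 1,
+ * no padding, and no output padding yields
+ * Hout = 1*(2-1) - 2*0 + 3 + 0 = 4.
+ *
+ *   X = matrix(seq(1,4), rows=1, cols=4)   # N=1, C=1, Hin=Win=2
+ *   W = matrix(seq(1,9), rows=1, cols=9)   # F=1, Hf=Wf=3
+ *   b = matrix(0, rows=1, cols=1)
+ *   [out, Hout, Wout] = forward(X, W, b, 1, 2, 2, 3, 3, 1, 1, 0, 0, 0, 0)
+ *   # Hout = Wout = 4, so out has shape (1, 16)
+ */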
+
+backward = function(matrix[double] dout, int Hout, int Wout,
+                    matrix[double] X, matrix[double] W, matrix[double] b,
+                    int C, int Hin, int Win, int Hf, int Wf,
+                    int strideh, int stridew, int padh, int padw)
+    return (matrix[double] dX, matrix[double] dW, matrix[double] db) {
+  /*
+   * Computes the backward pass for a 2D spatial transpose
+   * convolutional layer with F filters.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of
+   *      shape (N, F*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - W: Weights, of shape (F, C*Hf*Wf).
+   *  - b: Biases, of shape (F, 1).
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *  - padw: Padding for left and right sides.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
+   *  - dW: Gradient wrt `W`, of shape (F, C*Hf*Wf).
+   *  - db: Gradient wrt `b`, of shape (F, 1).
+   */
+  N = nrow(X)
+  F = nrow(W)
+
+  /*
+   * conv2d_backward_filter takes the input and delta map as its
+   * first and second args, respectively.  Since we need the gradient
+   * wrt the filter for transpose convolution, where the roles of the
+   * input and output are reversed, we reverse the order of the args
+   * (and set input_shape to the shape of the delta map).
+   * Effectively, we are running a direct convolution with X as the
+   * filter and dout as the input.  To convince oneself that the
+   * interconnections between the cells of the filter, input, and
+   * delta map are preserved, keep in mind that the forward pass of
+   * transpose convolution rotates the filter by 180 degrees before
+   * applying it.
+   */
+  dW = conv2d_backward_filter(dout, X, stride=[strideh,stridew], padding=[padh,padw],
+                              input_shape=[N,C,Hout,Wout], filter_shape=[F,C,Hf,Wf])
+
+  /*
+   * Since the forward pass of transpose convolution makes a call to
+   * conv2d_backward_data, we can compute its derivative wrt the data
+   * by running conv2d, i.e., by applying the filter to the delta map
+   * (this makes sense because transpose convolution is the 'reverse'
+   * of convolution).  It is easy to see that this produces output of
+   * the required size.  To convince oneself that conv2d respects the
+   * interconnections between the cells in the delta map and the
+   * filter, keep in mind that the forward pass rotates the filter by
+   * 180 degrees before applying it.
+   */
+  dX = conv2d(dout, W, input_shape=[N,C,Hout,Wout], filter_shape=[F,C,Hf,Wf],
+              stride=[strideh,stridew], padding=[padh,padw])
+
+  db = rowSums(matrix(colSums(dout), rows=F, cols=Hout*Wout))
+}
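+
+/*
+ * For concreteness, a shape sketch of the backward pass, using the
+ * configuration of the gradient check in nn/test/grad_check.dml
+ * (added below): N=2, C=F=2, Hin=Win=3, Hf=Wf=3, stride=2, pad=1,
+ * out_pad=1, so Hout = Wout = 2*(3-1) - 2*1 + 3 + 1 = 6.
+ *
+ *   dout: (2, 2*6*6)
+ *   dW = conv2d_backward_filter(dout, X, ...)  # (2, 2*3*3) = (F, C*Hf*Wf)
+ *   dX = conv2d(dout, W, ...)                  # (2, 2*3*3) = (N, C*Hin*Win)
+ *   db = rowSums(matrix(colSums(dout), rows=2, cols=36))  # (F, 1)
+ */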
+
+init = function(int F, int C, int Hf, int Wf)
+    return (matrix[double] W, matrix[double] b) {
+  /*
+   * Utility function to initialize the parameters of this layer.
+   *
+   * We use the heuristic by He et al., which limits the magnification
+   * of inputs/gradients during forward/backward passes by scaling
+   * unit-Gaussian weights by a factor of sqrt(2/n), under the
+   * assumption of relu neurons.
+   *  - http://arxiv.org/abs/1502.01852
+   *
+   * Inputs:
+   *  - F: Number of filters.
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *
+   * Outputs:
+   *  - W: Weights, of shape (F, C*Hf*Wf).
+   *  - b: Biases, of shape (F, 1).
+   */
+  W = rand(rows=F, cols=C*Hf*Wf, pdf="normal") * sqrt(2/(C*Hf*Wf))
+  b = matrix(0, rows=F, cols=1)
+}
+
+init_bilinear = function(int C, int K)
+    return (matrix[double] W, matrix[double] b) {
+  /*
+   * Utility function to initialize this layer for bilinear
+   * upsampling.
+   *
+   * Upsampling the input by a factor f (on each side) requires
+   * channel-wise independent kernels of size K = 2f - f%2, with
+   * stride = f and pad = ceil((f-1)/2).  The weights are set via
+   * bilinear interpolation; the bias is set to 0.
+   *
+   * Inputs:
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - K: Kernel size (upsampling requires a square filter
+   *      of size K x K).
+   *
+   * Outputs:
+   *  - W: Weights, of shape (C, C*K*K).
+   *  - b: Biases, of shape (C, 1).
+   */
+  factor_up = ceil(K / 2)
+  center = (2 * factor_up - factor_up %% 2 - 1) / 2 / factor_up
+  vect = 1 - abs(seq(0, K-1) / factor_up - center)
+  weights = matrix(vect %*% t(vect), rows=1, cols=K*K)
+
+  /*
+   * To create a multi-channel, channel-independent upsampling filter,
+   * we need to intersperse the filter weights with 0s.  For instance,
+   * consider the case of 2x upsampling.  In this case, K=4 and we
+   * have K^2=16 weights to include in the 3D tensor representing the
+   * filter, which should look like the following (assuming 3
+   * channels):
+   *
+   *   <-16 weights-> <---------32 0s--------->
+   *   X X ...... X X 0 0 0 ............. 0 0 0
+   *   0 .......... 0 X X .... X X 0 ...... 0 0
+   *   0 0 0 ............... 0 0 0 X X .... X X
+   *
+   * To be clear, the second row has 16 0s, followed by the 16
+   * weights, followed by another 16 0s.
+   *
+   * To create the above filter, we take advantage of the fact that
+   * between two sets of non-zero weights there is always a run of
+   * C*K*K 0s.  In the above example, C*K^2 = 48 (i.e., 32 trailing
+   * 0s in the first row plus 16 leading 0s in the second row).
+   *
+   * Note that the construction below also covers the special case of
+   * C=1, where no interspersing with 0s is needed (with only one
+   * channel, channel-wise independence is trivial): the final
+   * indexing simply clips off the appended 0s, leaving W = weights.
+   */
+  /*
+   * Append C*K*K trailing 0s to the K*K kernel and replicate the
+   * resulting row C times.
+   */
+  repl_weights = matrix(1, rows=C, cols=1) %*% cbind(weights, matrix(0, rows=1, cols=C*K*K))
+
+  /*
+   * The above operation added extra C*K*K trailing 0s in the last row
+   * that we do not need.  Thus, we need to:
+   *  1) reshape the resulting matrix into a row
+   *  2) 'clip off' the last few 0s using indexing and reshape the
+   *     result into the expected filter shape ([C, C, K, K])
+   */
+  repl_weights_row = matrix(repl_weights, rows=1, cols=C*(C+1)*K^2)
+  W = matrix(repl_weights_row[1,1:(C*K)^2], rows=C, cols=C*K^2)
+
+  b = matrix(0, rows=C, cols=1)
+}
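+
+/*
+ * Example: 2x upsampling (f=2) of a C-channel input, following the
+ * formulas above: K = 2*2 - 2%2 = 4, stride = 2,
+ * pad = ceil((2-1)/2) = 1, and no output padding.
+ *
+ *   [W, b] = init_bilinear(C, 4)
+ *   [out, Hout, Wout] = forward(X, W, b, C, Hin, Win, 4, 4, 2, 2, 1, 1, 0, 0)
+ *   # Hout = 2*(Hin-1) - 2*1 + 4 + 0 = 2*Hin, and likewise Wout = 2*Win
+ */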
+ */ + print("Grad checking the 2D convolution transpose layer with L2 loss.") + + N = 2 + C = 2 + Hin = 3 + Win = 3 + F = 2 + Hf = 3 + Wf = 3 + stride = 2 + pad = 1 + out_pad = 1 + + X = rand(rows=N, cols=C*Hin*Win) + + [W,b] = conv2d_transpose::init(F, C, Hf, Wf) + + [out, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, + pad, pad, out_pad, out_pad) + + y = rand(rows=N, cols=F*Hout*Wout) + + dout = l2_loss::backward(out,y) + + [dX, dW, db] = conv2d_transpose::backward(dout, Hout, Wout, X, W, b, C, Hin, Win, Hf, Wf, + stride, stride, pad, pad) + + h = 1e-5 + print(" - Grad checking X.") + for (i in 1:nrow(X)) { + for (j in 1:ncol(X)) { + old = as.scalar(X[i,j]) + X[i,j] = old - h + + [outmh, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, + pad, pad, out_pad, out_pad) + + lossmh = l2_loss::forward(outmh, y) + + X[i,j] = old + h + [outph, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, + pad, pad, out_pad, out_pad) + + lossph = l2_loss::forward(outph, y) + + X[i,j] = old + + dX_num = (lossph-lossmh) / (2*h) + + rel_error = test_util::check_rel_grad_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh) + } + } + + print(" - Grad checking W.") + for (i in 1:nrow(W)) { + for (j in 1:ncol(W)) { + old = as.scalar(W[i,j]) + W[i,j] = old - h + + [outmh, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, + pad, pad, out_pad, out_pad) + + lossmh = l2_loss::forward(outmh, y) + + W[i,j] = old + h + [outph, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, + pad, pad, out_pad, out_pad) + + lossph = l2_loss::forward(outph, y) + + W[i,j] = old + + dW_num = (lossph-lossmh) / (2*h) + + rel_error = test_util::check_rel_grad_error(as.scalar(dW[i,j]), dW_num, lossph, lossmh) + } + } + + print(" - Grad checking b.") + for (i in 1:nrow(b)) { + for (j in 1:ncol(b)) { + old = as.scalar(b[i,j]) + b[i,j] = old - h + + [outmh, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, + pad, pad, out_pad, out_pad) + + lossmh = l2_loss::forward(outmh, y) + + b[i,j] = old + h + [outph, Hout, Wout] = conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, + pad, pad, out_pad, out_pad) + + lossph = l2_loss::forward(outph, y) + + b[i,j] = old + + db_num = (lossph-lossmh) / (2*h) + + rel_error = test_util::check_rel_grad_error(as.scalar(db[i,j]), db_num, lossph, lossmh) + } + } +} + cross_entropy_loss = function() { /* * Gradient check for the cross-entropy loss function. 
+
 cross_entropy_loss = function() {
   /*
    * Gradient check for the cross-entropy loss function.


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d04d2381/scripts/nn/test/run_tests.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/run_tests.dml b/scripts/nn/test/run_tests.dml
index b061d26..b48606c 100644
--- a/scripts/nn/test/run_tests.dml
+++ b/scripts/nn/test/run_tests.dml
@@ -45,6 +45,7 @@ grad_check::batch_norm2d()
 grad_check::conv2d()
 grad_check::conv2d_builtin()
 grad_check::conv2d_simple()
+grad_check::conv2d_transpose()
 grad_check::dropout()
 grad_check::lstm()
 grad_check::max_pool2d()
@@ -85,6 +86,7 @@ print("---")
 test::batch_norm1d()
 test::batch_norm2d()
 test::conv2d()
+test::conv2d_transpose()
 test::cross_entropy_loss()
 test::im2col()
 test::max_pool2d()


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d04d2381/scripts/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml
index a5cb497..0df7f0f 100644
--- a/scripts/nn/test/test.dml
+++ b/scripts/nn/test/test.dml
@@ -26,6 +26,7 @@ source("nn/layers/batch_norm1d.dml") as batch_norm1d
 source("nn/layers/batch_norm2d.dml") as batch_norm2d
 source("nn/layers/conv2d.dml") as conv2d
 source("nn/layers/conv2d_builtin.dml") as conv2d_builtin
+source("nn/layers/conv2d_transpose.dml") as conv2d_transpose
 source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
 source("nn/layers/max_pool2d.dml") as max_pool2d
 source("nn/layers/max_pool2d_builtin.dml") as max_pool2d_builtin
@@ -108,6 +109,42 @@ conv2d = function() {
   }
 }
 
+conv2d_transpose = function() {
+  /*
+   * Test for the 2D convolution transpose function.
+   */
+  print("Testing the 2D convolution transpose function.")
+
+  N = 1
+  C = 1
+  Hin = 2
+  Win = 2
+  F = 1
+  Hf = 3
+  Wf = 3
+  stride = 1
+  pad = 0
+  out_pad = 0
+
+  X = matrix(seq(1,N*C*Hin*Win), rows=N, cols=C*Hin*Win)
+  W = matrix(seq(1,F*C*Hf*Wf), rows=F, cols=C*Hf*Wf)
+  b = matrix(0, rows=F, cols=1)
+
+  # Forward
+  [out, Hout, Wout] =
+      conv2d_transpose::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad,
+                                out_pad, out_pad)
+
+  # Equivalency check: each input cell X[i,j] stamps a copy of the
+  # 3x3 filter, scaled by X[i,j], into the 4x4 output (e.g., the
+  # second output entry is 1*2 + 2*1 = 4).
+  target = matrix("1 4 7 6 7 23 33 24 19 53 63 42 21 52 59 36", rows=N, cols=C*Hout*Wout)
+
+  for (i in 1:nrow(out)) {
+    for (j in 1:ncol(out)) {
+      rel_error = test_util::check_rel_error(as.scalar(out[i,j]),
+                                             as.scalar(target[i,j]), 1e-3, 1e-4)
+    }
+  }
+}
+
 cross_entropy_loss = function() {
   /*
    * Test for the cross-entropy loss function.