Repository: incubator-systemml Updated Branches: refs/heads/master 97da0004f -> 16e990928
[SYSTEMML-1428] Fixed maxpooling functions for padding > 0 Closes #437. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/16e99092 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/16e99092 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/16e99092 Branch: refs/heads/master Commit: 16e990928fa0201132688a8f7476856a02253030 Parents: 97da000 Author: Niketan Pansare <npan...@us.ibm.com> Authored: Wed Mar 22 15:25:52 2017 -0800 Committer: Niketan Pansare <npan...@us.ibm.com> Committed: Wed Mar 22 16:25:52 2017 -0700 ---------------------------------------------------------------------- .../sysml/runtime/matrix/data/LibMatrixDNN.java | 78 +++++++++++++++----- .../apache/sysml/yarn/ropt/ResourceConfig.java | 2 +- 2 files changed, 60 insertions(+), 20 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/16e99092/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java index 2547a87..5ab41e0 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java @@ -427,6 +427,12 @@ public class LibMatrixDNN { outputBlock.recomputeNonZeros(); } + /** + * This method computes start and end indexes required for max_pool and max_pool_backward operations. 
+ * Precomputing these indexes once speeds up max_pool and max_pool_backward; each window is clamped to the input bounds, so padded cells are never part of a window. + * + * @param params parameters required for max_pool and max_pool_backward operations + */ private static void fillIndexesArray(ConvolutionParameters params) { params.start_indexes_h = new int[params.P]; params.end_indexes_h = new int[params.P]; @@ -434,17 +440,17 @@ public class LibMatrixDNN { params.end_indexes_w = new int[params.Q]; for (int p = 0; p < params.P; p++) { int start_index_h = p * params.stride_h - params.pad_h; - final int end_index_h = Math.min(start_index_h + params.R, params.H); - start_index_h = Math.max(start_index_h, 0); - params.start_indexes_h[p] = start_index_h; - params.end_indexes_h[p] = end_index_h; + int end_index_h = start_index_h + params.R; + // Note: We do not treat pad as zero + params.start_indexes_h[p] = Math.max(start_index_h, 0); + params.end_indexes_h[p] = Math.min(end_index_h, params.H); } for (int q = 0; q < params.Q; q++) { - int start_index_w = Math.max(q * params.stride_w - params.pad_w, 0); - int end_index_w = Math.min(start_index_w + params.S, params.W); - start_index_w = Math.max(start_index_w, 0); - params.start_indexes_w[q] = start_index_w; - params.end_indexes_w[q] = end_index_w; + int start_index_w = q * params.stride_w - params.pad_w; + int end_index_w = start_index_w + params.S; + // Note: We do not treat pad as zero + params.start_indexes_w[q] = Math.max(start_index_w, 0); + params.end_indexes_w[q] = Math.min(end_index_w, params.W); } } @@ -486,7 +492,8 @@ public class LibMatrixDNN { if(inVal != 0) { final int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W; int maxIndex = getMaxIndexSparse(p, q, inputOffset, n, c, params.input1, params); - outputArray[maxIndex] += inVal; + if(maxIndex != -1) + outputArray[maxIndex] += inVal; } } } @@ -510,7 +517,8 @@ public class LibMatrixDNN { final int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W; int maxIndex = getMaxIndexSparse(p, q, inputOffset, n, c, params.input1, 
params); - outputArray[maxIndex] += ijv.getV(); + if(maxIndex != -1) + outputArray[maxIndex] += ijv.getV(); } } @@ -530,7 +538,8 @@ public class LibMatrixDNN { final int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W; int maxIndex = getMaxIndex(p, q, inputOffset, inputArray, params); - outputArray[maxIndex] += ijv.getV(); + if(maxIndex != -1) + outputArray[maxIndex] += ijv.getV(); } } @@ -543,12 +552,26 @@ public class LibMatrixDNN { for (int p = 0; p < params.P; p++) { for (int q = 0; q < params.Q; q++) { int maxIndex = getMaxIndex(p, q, inputOffset, inputArray, params); - outputArray[maxIndex] += doutArray[outputOffset + p * params.Q + q]; + if(maxIndex != -1) + outputArray[maxIndex] += doutArray[outputOffset + p * params.Q + q]; } } } } + /** + * Returns the index of the cell with the maximum value in the pooling window. This method is optimized for sparse input + * + * @param p row index into the output feature map + * @param q column index into the output feature map + * @param inputOffset offset to be used for input index + * @param n index of the image + * @param c index of the channel + * @param input input matrix + * @param params convolution parameters + * @return index of the cell with maximum value, or -1 if no maximum was found + * @throws DMLRuntimeException if the input is not in sparse format + */ private static int getMaxIndexSparse(int p, int q, int inputOffset, int n, int c, MatrixBlock input, ConvolutionParameters params) throws DMLRuntimeException { if(!input.isInSparseFormat()) throw new DMLRuntimeException("Incorrect usage: Only sparse format supported"); @@ -562,8 +585,12 @@ public class LibMatrixDNN { int start_index_w = params.start_indexes_w[q]; int end_index_w = params.end_indexes_w[q]; - int maxIndex = inputOffset + start_index_h*params.W + start_index_w; + int maxIndex = -1; double maxVal = -Double.MAX_VALUE; + + // Note: We do not treat pad as zero and hence we don't do: + // maxVal = 0 + // if start_index_h < 0 || start_index_w < 0 || end_index_h >= params.H || end_index_w >= params.W // Find maxIndex double currDoutVal = -1; 
@@ -585,15 +612,29 @@ public class LibMatrixDNN { return maxIndex; } + /** + * Returns the index of the cell with the maximum value in the pooling window. This method is optimized for dense input + * + * @param p row index into the output feature map + * @param q column index into the output feature map + * @param inputOffset offset to be used for input index + * @param inputArray input array + * @param params convolution parameters + * @return index of the cell with maximum value, or -1 if no maximum was found + */ private static int getMaxIndex(int p, int q, int inputOffset, double [] inputArray, ConvolutionParameters params) { int start_index_h = params.start_indexes_h[p]; int end_index_h = params.end_indexes_h[p]; int start_index_w = params.start_indexes_w[q]; int end_index_w = params.end_indexes_w[q]; - int maxIndex = inputOffset + start_index_h*params.W + start_index_w; + int maxIndex = -1; double maxVal = -Double.MAX_VALUE; - + + // Note: We do not treat pad as zero and hence we don't do: + // maxVal = 0 + // if start_index_h < 0 || start_index_w < 0 || end_index_h >= params.H || end_index_w >= params.W + // Find maxIndex double currDoutVal = -1; for (int h = start_index_h; h < end_index_h; h++) { @@ -899,7 +940,7 @@ public class LibMatrixDNN { //post-processing: maintain nnz outputBlock.recomputeNonZeros(); } - + private static void doPooling(int n, ConvolutionParameters params) throws DMLRuntimeException { double [] inputArray = null; if (!params.input1.isInSparseFormat()) @@ -936,8 +977,7 @@ public class LibMatrixDNN { for (int q = 0; q < params.Q; q++, out_index++) { for (int h = params.start_indexes_h[p]; h < params.end_indexes_h[p]; h++) { for (int w = params.start_indexes_w[q]; w < params.end_indexes_w[q]; w++) { - double inVal = params.input1.quickGetValue(n, c*HW + h*params.W + w); - outputArray[out_index] = Math.max(outputArray[out_index], inVal); + outputArray[out_index] = Math.max(outputArray[out_index], params.input1.quickGetValue(n, c*HW + h*params.W + w)); } } } 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/16e99092/src/main/java/org/apache/sysml/yarn/ropt/ResourceConfig.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/yarn/ropt/ResourceConfig.java b/src/main/java/org/apache/sysml/yarn/ropt/ResourceConfig.java index 45ad782..bdbf053 100644 --- a/src/main/java/org/apache/sysml/yarn/ropt/ResourceConfig.java +++ b/src/main/java/org/apache/sysml/yarn/ropt/ResourceConfig.java @@ -101,7 +101,7 @@ public class ResourceConfig public long getMaxMRResource() { - double val = Collections.max(_mrres); + double val = (double) Collections.max(_mrres); return (long)val; }