Repository: systemml
Updated Branches:
  refs/heads/master bd1946a3d -> 4744da167


[SYSTEMML-540] Support sparse output for CP im2col operator

- Add support for sparse output in the CP im2col operator, which improved
  CPU performance by 4x for a 1-layer CNN with an input sparsity of 0.04.
- This commit also fixes NeuralNetworkOpTests.

Closes #645.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/4744da16
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/4744da16
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/4744da16

Branch: refs/heads/master
Commit: 4744da1670eb6e8c922394f563f360c0fa233482
Parents: bd1946a
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Wed Aug 30 10:17:04 2017 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Wed Aug 30 10:23:39 2017 -0700

----------------------------------------------------------------------
 .../instructions/gpu/context/CSRPointer.java    |   8 +-
 .../LibMatrixDNNConv2dBackwardFilterHelper.java |   1 -
 .../matrix/data/LibMatrixDNNConv2dHelper.java   |   2 -
 .../matrix/data/LibMatrixDNNIm2ColHelper.java   | 278 ++++---
 .../org/apache/sysml/test/gpu/GPUTests.java     |  58 +-
 .../sysml/test/gpu/NeuralNetworkOpTests.java    | 733 +++++++++----------
 6 files changed, 592 insertions(+), 488 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/4744da16/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
index a4147a3..9379534 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
@@ -185,13 +185,19 @@ public class CSRPointer {
         * @param rowPtr integer array of row pointers
         * @param colInd integer array of column indices
         * @param values double array of non zero values
+        * @throws DMLRuntimeException if error occurs
         */
-       public static void copyToDevice(CSRPointer dest, int rows, long nnz, 
int[] rowPtr, int[] colInd, double[] values) {
+       public static void copyToDevice(CSRPointer dest, int rows, long nnz, 
int[] rowPtr, int[] colInd, double[] values) throws DMLRuntimeException {
                CSRPointer r = dest;
                long t0 = 0;
                if (DMLScript.STATISTICS)
                        t0 = System.nanoTime();
                r.nnz = nnz;
+               if(rows < 0) throw new DMLRuntimeException("Incorrect input 
parameter: rows=" + rows);
+               if(nnz < 0) throw new DMLRuntimeException("Incorrect input 
parameter: nnz=" + nnz);
+               if(rowPtr.length < rows + 1) throw new DMLRuntimeException("The 
length of rowPtr needs to be greater than or equal to " + (rows + 1));
+               if(colInd.length < nnz) throw new DMLRuntimeException("The 
length of colInd needs to be greater than or equal to " + nnz);
+               if(values.length < nnz) throw new DMLRuntimeException("The 
length of values needs to be greater than or equal to " + nnz);
                cudaMemcpy(r.rowPtr, Pointer.to(rowPtr), getIntSizeOf(rows + 
1), cudaMemcpyHostToDevice);
                cudaMemcpy(r.colInd, Pointer.to(colInd), getIntSizeOf(nnz), 
cudaMemcpyHostToDevice);
                cudaMemcpy(r.val, Pointer.to(values), getDoubleSizeOf(nnz), 
cudaMemcpyHostToDevice);

http://git-wip-us.apache.org/repos/asf/systemml/blob/4744da16/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
index 560f32c..a135f62 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dBackwardFilterHelper.java
@@ -84,7 +84,6 @@ public class LibMatrixDNNConv2dBackwardFilterHelper {
                        int PQ = _params.P*_params.Q; int K = _params.K; int 
CRS = _params.C*_params.R*_params.S;
                        MatrixBlock dout = _params.input2;
                        MatrixBlock im2ColOutBlock = new MatrixBlock(CRS, PQ, 
false);
-                       im2ColOutBlock.allocateDenseBlock();
                        MatrixBlock dout_reshaped = new MatrixBlock(PQ, K, 
false);
                        dout_reshaped.allocateDenseBlock();
                        LibMatrixDNNIm2ColHelper.Im2colWorker im2ColWorker = 
LibMatrixDNNIm2ColHelper.Im2colWorker.getWorker( _params.input1, 
im2ColOutBlock, _params, true);

http://git-wip-us.apache.org/repos/asf/systemml/blob/4744da16/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
index b2c4d67..4c3a3c3 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
@@ -48,7 +48,6 @@ public class LibMatrixDNNConv2dHelper {
                        int PQ = _params.P*_params.Q; int K = _params.K;
                        int RS = _params.R*_params.S;
                        MatrixBlock im2ColOutBlock = new MatrixBlock(RS, PQ, 
false);
-                       im2ColOutBlock.allocateDenseBlock();
                        LibMatrixDNNIm2ColHelper.Im2colWorker im2ColWorker = 
LibMatrixDNNIm2ColHelper.Im2colWorker.getWorker( _params.input1, 
im2ColOutBlock, _params, false);
                        long time1 = 0; long time2 = 0;
                        for(int n = _rl; n < _ru; n++)  {
@@ -129,7 +128,6 @@ public class LibMatrixDNNConv2dHelper {
                public Long call() throws Exception {
                        int PQ = _params.P*_params.Q; int K = _params.K; int 
CRS = _params.C*_params.R*_params.S;
                        MatrixBlock im2ColOutBlock = new MatrixBlock(CRS, PQ, 
false);
-                       im2ColOutBlock.allocateDenseBlock();
                        LibMatrixDNNIm2ColHelper.Im2colWorker im2ColWorker = 
LibMatrixDNNIm2ColHelper.Im2colWorker.getWorker( _params.input1, 
im2ColOutBlock, _params, true);
                        long time1 = 0; long time2 = 0;
                        for(int n = _rl; n < _ru; n++)  {

http://git-wip-us.apache.org/repos/asf/systemml/blob/4744da16/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java
index 9ae39bf..d427a26 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java
@@ -20,38 +20,73 @@ package org.apache.sysml.runtime.matrix.data;
 
 import java.util.Arrays;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
 /**
  * This class contains the different implementation of im2col operation
  */
 public class LibMatrixDNNIm2ColHelper {
-       
+       private static final Log LOG = 
LogFactory.getLog(LibMatrixDNNIm2ColHelper.class.getName());
        static interface Im2colWorker {
                public void execute(int n);
                public void execute(int n, int c);
                public static Im2colWorker getWorker(MatrixBlock input, 
MatrixBlock im2ColOutBlock, ConvolutionParameters params, boolean allChannels) {
-                       if(im2ColOutBlock.isInSparseFormat() || 
im2ColOutBlock.getDenseBlock() == null)
-                               throw new RuntimeException("im2col output is 
always in dense format");
                        if(allChannels) {
                                if(!input.isInSparseFormat()) {
-                                       if (params.stride_h == 1 && 
params.stride_w == 1 && params.pad_h == 0 && params.pad_w == 0) 
+                                       // Note: Only dense im2col operators 
require the im2ColOutBlock to be allocated in the dense format.
+                                       im2ColOutBlock.allocateDenseBlock();
+                                       if (params.stride_h == 1 && 
params.stride_w == 1 && params.pad_h == 0 && params.pad_w == 0)  {
+                                               if(LOG.isTraceEnabled()) 
LOG.trace("Using DenseIm2colWorkerStride1Pad0AllChannels operator to perform 
im2col.");
                                                return new 
DenseIm2colWorkerStride1Pad0AllChannels(input.getDenseBlock(), 
im2ColOutBlock.getDenseBlock(), params);
-                                       else
+                                       }
+                                       else {
+                                               if(LOG.isTraceEnabled()) 
LOG.trace("Using DenseIm2colWorkerAllChannels operator to perform im2col.");
                                                return new 
DenseIm2colWorkerAllChannels(input.getDenseBlock(), 
im2ColOutBlock.getDenseBlock(), params);
+                                       }
+                               }
+                               else {
+                                       if(LOG.isTraceEnabled()) 
LOG.trace("Using SparseIm2colWorkerAllChannels operator to perform im2col.");
+                                       double sparsity = 
Math.min(MatrixBlock.SPARSITY_TURN_POINT, (input.getNonZeros()*2.0) / 
(input.getNumRows()*input.getNumColumns()));
+                                       
initializeSparseIm2ColBlock(im2ColOutBlock, 
(long)Math.ceil(params.P*params.Q*sparsity));
+                                       return new 
SparseSparseIm2colWorkerAllChannels(input, im2ColOutBlock, params);
                                }
-                               else 
-                                       return new 
SparseIm2colWorkerAllChannels(input, im2ColOutBlock, params);
                        }
                        else {
                                if(!input.isInSparseFormat()) {
-                                       if (params.stride_h == 1 && 
params.stride_w == 1 && params.pad_h == 0 && params.pad_w == 0) 
+                                       // Note: Only dense im2col operators 
require the im2ColOutBlock to be allocated in the dense format.
+                                       im2ColOutBlock.allocateDenseBlock();
+                                       if (params.stride_h == 1 && 
params.stride_w == 1 && params.pad_h == 0 && params.pad_w == 0) {
+                                               if(LOG.isTraceEnabled()) 
LOG.trace("Using DenseIm2colWorkerStride1Pad0 operator to perform im2col.");
                                                return new 
DenseIm2colWorkerStride1Pad0(input.getDenseBlock(), 
im2ColOutBlock.getDenseBlock(), params);
-                                       else
+                                       }
+                                       else {
+                                               if(LOG.isTraceEnabled()) 
LOG.trace("Using DenseIm2colWorker operator to perform im2col.");
                                                return new 
DenseIm2colWorker(input.getDenseBlock(), im2ColOutBlock.getDenseBlock(), 
params);
+                                       }
+                               }
+                               else {
+                                       if(LOG.isTraceEnabled()) 
LOG.trace("Using SparseIm2colWorker operator to perform im2col.");
+                                       double sparsity = 
Math.min(MatrixBlock.SPARSITY_TURN_POINT, (input.getNonZeros()*2.0) / 
(input.getNumRows()*input.getNumColumns()));
+                                       
initializeSparseIm2ColBlock(im2ColOutBlock, 
(long)Math.ceil(params.P*params.Q*sparsity));
+                                       return new 
SparseSparseIm2colWorker(input, im2ColOutBlock, params);
                                }
-                               else 
-                                       return new SparseIm2colWorker(input, 
im2ColOutBlock, params);
                        }
                }
+               
+               static void initializeSparseIm2ColBlock(MatrixBlock 
im2ColOutBlock, long worstCaseNNPerRow) {
+                       if(worstCaseNNPerRow >= Integer.MAX_VALUE)
+                               throw new RuntimeException("The dimension of 
intermediate im2col matrix exceeded:" + worstCaseNNPerRow);
+                       // Set to sparse
+                       im2ColOutBlock.sparse = true;
+                       im2ColOutBlock.denseBlock = null;
+                       im2ColOutBlock.allocateSparseRowsBlock();
+                       
+                       for(int r = 0; r < im2ColOutBlock.getNumRows(); r++) {
+                               im2ColOutBlock.getSparseBlock().allocate(r, 
(int) worstCaseNNPerRow);
+                       }
+                       im2ColOutBlock.setNonZeros(0);
+               }
        }
        
        /**
@@ -238,20 +273,23 @@ public class LibMatrixDNNIm2ColHelper {
        }
        
        /**
-        * Performing dense im2col (general case)
+        * Performing sparse im2col for all channels for a given row n of the 
input matrix.
         */
-       static class SparseIm2colWorkerAllChannels implements Im2colWorker {
-               MatrixBlock input; double [] outputArray; 
-               int CRS; int S; int R; int P; int Q; int H; int W; 
-               int stride_h; int stride_w; int pad_h; int pad_w; double [] 
temp;
-               public SparseIm2colWorkerAllChannels(MatrixBlock input, 
MatrixBlock im2ColOutBlock, ConvolutionParameters params) {
+       static class SparseSparseIm2colWorkerAllChannels implements 
Im2colWorker {
+               MatrixBlock input;  MatrixBlock output;
+               int CRS; int S; int R; int P; int Q; int H; int W; int RS; int 
HW;
+               int stride_h; int stride_w; int pad_h; int pad_w;
+               public SparseSparseIm2colWorkerAllChannels(MatrixBlock input, 
MatrixBlock im2ColOutBlock, ConvolutionParameters params) {
                        this.input = input;
-                       this.outputArray = im2ColOutBlock.getDenseBlock();
+                       this.output = im2ColOutBlock;
                        this.CRS = params.C * params.R * params.S;
+                       this.RS = params.R * params.S;
+                       this.HW = params.H * params.W;
                        this.H = params.H; this.W = params.W; this.R = 
params.R; this.S = params.S; this.P = params.P; this.Q = params.Q;
                        this.stride_h = params.stride_h; this.stride_w = 
params.stride_w;
                        this.pad_h = params.pad_h; this.pad_w = params.pad_w;
-                       temp = new double[input.getNumColumns()];
+                       if(!input.isInSparseFormat()) 
+                               throw new RuntimeException("Incorrect operator 
selection. Expected dense input for SparseIm2colWorkerAllChannels");
                }
                
                @Override
@@ -261,71 +299,107 @@ public class LibMatrixDNNIm2ColHelper {
 
                @Override
                public void execute(int n) {
-                       // Using a temporary array improves performance by not 
requiring binary search for getValue
-                       // Since the access pattern depends on 
ConvolutionParameters, this serves as a temporary fix.
-                       fillTemp(input, n);
-                       // final int nOffset = n * params.C*params.H*params.W;
-                       for (int c = 0; c < CRS; ++c) {
-                               int wOffset = c % S;
-                               int hOffset = (c / S) % R;
-                               int cInput = c / R / S;
-                               for (int h = 0; h < P; ++h) {
-                                       int outOffset = (c * P + h) * Q;
-                                       int hPadded = h * stride_h - pad_h + 
hOffset;
-                                       int tempOffset = (cInput * H + hPadded) 
* W;
-                                       // int inputOffset = nOffset + 
tempOffset;
-                                       if (hPadded < 0 || hPadded >= H) {
-                                               Arrays.fill(outputArray, 
outOffset, outOffset+Q, 0);
-                                       } else {
-                                               for (int w = 0; w < Q; ++w) {
-                                                       int wPadded = w * 
stride_w - pad_w + wOffset;
-                                                       if (wPadded >= 0 && 
wPadded < W) 
-                                                               
outputArray[outOffset + w] = temp[tempOffset + wPadded];
-                                                       else
-                                                               
outputArray[outOffset + w] = 0;
-                                               }
-                                       }
-                               }
-                       }
-               }
-               // Returns the row of matrix in dense format
-               private void fillTemp(MatrixBlock input, int n) {
-                       if(input.getNumColumns() != temp.length) {
-                               throw new RuntimeException("Invalid 
parameters");
-                       }
-                       // Use temporary array to avoid binary search
-                       if(input.isInSparseFormat()) {
-                               Arrays.fill(temp, 0);
-                               if( !input.sparseBlock.isEmpty(n) ) {
-                                       int apos = input.sparseBlock.pos(n);
-                                       int alen = input.sparseBlock.size(n);
-                                       int[] aix = 
input.sparseBlock.indexes(n);
-                                       double[] avals = 
input.sparseBlock.values(n);
-                                       for(int j=apos; j<apos+alen; j++)
-                                               temp[ aix[j] ] = avals[j];
+                       if( !input.sparseBlock.isEmpty(n) ) {
+                               output.sparseBlock.reset();
+                               output.setNonZeros(0);
+                               int apos = input.sparseBlock.pos(n);
+                               int alen = input.sparseBlock.size(n);
+                               int[] aix = input.sparseBlock.indexes(n);
+                               double[] avals = input.sparseBlock.values(n);
+                               
+                               // Iterate over the sparse block
+                               for(int j=apos; j<apos+alen; j++) {
+                                       // Note: the input is of shape [N, CHW]
+                                       int chw = aix[j];
+                                       
+                                       // Get individual zero-based c,h,w 
indexes from zero-based 'chw'
+                                       int cInput = chw / HW;
+                                       int hInput = (chw - cInput*HW)/W;
+                                       int wInput = chw % W; 
+                                       
+                                       appendInputValueToIm2colOutput(output, 
cInput, hInput, wInput, avals[j], 
+                                                       R, S, P, Q, stride_h, 
stride_w, pad_h, pad_w);
                                }
+                               // Since the chw are appended in sorted order, 
no need to sort the output rows
+                               // if(meta.sortRows) output.sortSparseRows();
                        }
                        else {
-                               System.arraycopy(input.getDenseBlock(), 
n*input.getNumColumns(), temp, 0, input.getNumColumns());
+                               output.setNonZeros(0);
                        }
                }
        }
        
        /**
-        * Performing dense im2col (general case)
+        * Appends the value corresponding to the given [, cInput, hInput, 
wInput] to the appropriate im2col location of the output
+        * 
+        * @param output output matrix block
+        * @param cInput input channel index (zero-based)
+        * @param hInput input height index (zero-based)
+        * @param wInput input width index (zero-based)
+        * @param value input value
+        * @param R filter height
+        * @param S filter width
+        * @param P output height
+        * @param Q output width
+        * @param stride_h stride height
+        * @param stride_w stride width
+        * @param pad_h pad height
+        * @param pad_w pad width
+        */
+       private static void appendInputValueToIm2colOutput(MatrixBlock output, 
int cInput, int hInput, int wInput, double value, 
+                       int R, int S, int P, int Q, int stride_h, int stride_w, 
int pad_h, int pad_w) {
+               if(value == 0) 
+                       return;
+               int RS = R*S;
+               // For the given h,w index, insert avals[j] into respective 
r,s,p,q locations
+               
+               // Constraints: for(int r = 0; r < R; r++) { if(0 <= p && p < P 
&& (hInput - r + pad_h) % stride_h == 0) { ... } }
+               // Constraint 1: p >= 0 and p = (hInput - r + pad_h)  / stride_h
+               // Therefore,  r <= hInput + pad_h 
+               // Constraint 2: p < P and p = (hInput - r + pad_h)  / stride_h
+               // Therefore,  hInput + pad_h - P*stride_h < r
+               // Math.max(0, hInput + pad_h - P*stride_h + 1) <= r <= 
Math.min(R-1, hInput + pad_h)
+               int rMin = Math.max(0, hInput + pad_h - P*stride_h + 1);
+               int rMax = Math.min(R-1, hInput + pad_h);
+               int sMin = Math.max(0, wInput + pad_w - Q*stride_w + 1);
+               int sMax = Math.min(S-1, wInput + pad_w);
+               // Constraint 3: (hInput - r + pad_h) % stride_h == 0
+               while((hInput - rMin + pad_h) % stride_h != 0 && rMin <= rMax) 
rMin++;
+               while((wInput - sMin + pad_w) % stride_w != 0 && sMin <= sMax) 
sMin++;  
+               
+               for(int r = rMin; r <= rMax; r += stride_h) {
+                       // Only append value if h == hInput, where h = (r - 
pad_h) + p*stride_h and 0 <= p < P
+                       // Therefore, p = (hInput - r + pad_h)  / stride_h. Use 
the same logic for q.
+                       final int p = (hInput - r + pad_h)  / stride_h;
+                       final int pQ = p*Q;
+                       final int outRowIndex = cInput*RS + r*S;
+                       for(int s = sMin; s <= sMax; s += stride_w) {
+                               int q = (wInput - s + pad_w)  / stride_w;
+                               // chw -> [crs, pq]
+                               output.appendValue(outRowIndex + s, pQ + q, 
value);
+                               // Since the chw are appended in sorted order, 
no need to sort the output rows
+                               // if(meta.lastIndexPerRow[outRowIndex + s] > 
p*Q + q) meta.sortRows = true;
+                               // meta.lastIndexPerRow[outRowIndex + s] = p*Q 
+ q;
+                       }
+               }
+       }
+       
+       /**
+        * Performing sparse im2col for a given channel c and for a given row n 
of the input matrix.
         */
-       static class SparseIm2colWorker implements Im2colWorker {
-               MatrixBlock input; double [] outputArray; 
-               int CRS; int S; int R; int P; int Q; int H; int W; 
-               int stride_h; int stride_w; int pad_h; int pad_w; double [] 
temp;
-               public SparseIm2colWorker(MatrixBlock input, MatrixBlock 
im2ColOutBlock, ConvolutionParameters params) {
+       static class SparseSparseIm2colWorker implements Im2colWorker {
+               MatrixBlock input; MatrixBlock output;
+               int CRS; int S; int R; int P; int Q; int H; int W; int HW; int 
RS;
+               int stride_h; int stride_w; int pad_h; int pad_w; 
+               public SparseSparseIm2colWorker(MatrixBlock input, MatrixBlock 
im2ColOutBlock, ConvolutionParameters params) {
                        this.input = input;
-                       this.outputArray = im2ColOutBlock.getDenseBlock();
+                       this.output = im2ColOutBlock;
                        this.CRS = params.C * params.R * params.S;
+                       this.HW = params.H*params.W;
+                       this.RS = params.R*params.S;
                        this.H = params.H; this.W = params.W; this.R = 
params.R; this.S = params.S; this.P = params.P; this.Q = params.Q;
                        this.stride_h = params.stride_h; this.stride_w = 
params.stride_w;
                        this.pad_h = params.pad_h; this.pad_w = params.pad_w;
-                       temp = new double[input.getNumColumns()];
                }
                
                @Override
@@ -335,52 +409,36 @@ public class LibMatrixDNNIm2ColHelper {
 
                @Override
                public void execute(int n, int cInput) {
-                       // Using a temporary array improves performance by not 
requiring binary search for getValue
-                       // Since the access pattern depends on 
ConvolutionParameters, this serves as a temporary fix.
-                       fillTemp(input, n); int RS = R*S;
-                       for (int rs = 0; rs < RS; ++rs) {
-                               int wOffset = rs % S;
-                               int hOffset = rs / S;
-                               for (int h = 0; h < P; ++h) {
-                                       int outOffset = (rs * P + h) * Q;
-                                       int hPadded = h * stride_h - pad_h + 
hOffset;
-                                       int tempOffset = (cInput * H + hPadded) 
* W;
-                                       // int inputOffset = nOffset + 
tempOffset;
-                                       if (hPadded < 0 || hPadded >= H) {
-                                               Arrays.fill(outputArray, 
outOffset, outOffset+Q, 0);
-                                       } else {
-                                               for (int w = 0; w < Q; ++w) {
-                                                       int wPadded = w * 
stride_w - pad_w + wOffset;
-                                                       if (wPadded >= 0 && 
wPadded < W) 
-                                                               
outputArray[outOffset + w] = temp[tempOffset + wPadded];
-                                                       else
-                                                               
outputArray[outOffset + w] = 0;
-                                               }
+                       if( !input.sparseBlock.isEmpty(n) ) {
+                               output.sparseBlock.reset();
+                               output.setNonZeros(0);
+                               int apos = input.sparseBlock.pos(n);
+                               int alen = input.sparseBlock.size(n);
+                               int[] aix = input.sparseBlock.indexes(n);
+                               double[] avals = input.sparseBlock.values(n);
+                               
+                               // Iterate over the sparse block
+                               for(int j=apos; j<apos+alen; j++) {
+                                       // Note: the input is of shape [N, CHW]
+                                       int chw = aix[j];
+                                       
+                                       if(cInput == (chw / HW)) {
+                                               // Get individual zero-based 
c,h,w indexes from zero-based 'chw'
+                                               int hInput = (chw - 
cInput*HW)/W;
+                                               int wInput = chw % W; 
+                                               
+                                               
appendInputValueToIm2colOutput(output, cInput, hInput, wInput, avals[j], 
+                                                               R, S, P, Q, 
stride_h, stride_w, pad_h, pad_w);
                                        }
                                }
-                       }
-               }
-               // Returns the row of matrix in dense format
-               private void fillTemp(MatrixBlock input, int n) {
-                       if(input.getNumColumns() != temp.length) {
-                               throw new RuntimeException("Invalid 
parameters");
-                       }
-                       // Use temporary array to avoid binary search
-                       if(input.isInSparseFormat()) {
-                               Arrays.fill(temp, 0);
-                               if( !input.sparseBlock.isEmpty(n) ) {
-                                       int apos = input.sparseBlock.pos(n);
-                                       int alen = input.sparseBlock.size(n);
-                                       int[] aix = 
input.sparseBlock.indexes(n);
-                                       double[] avals = 
input.sparseBlock.values(n);
-                                       for(int j=apos; j<apos+alen; j++)
-                                               temp[ aix[j] ] = avals[j];
-                               }
+                               // Since the chw are appended in sorted order, 
no need to sort the output rows
+                               // if(meta.sortRows) output.sortSparseRows();
                        }
                        else {
-                               System.arraycopy(input.getDenseBlock(), 
n*input.getNumColumns(), temp, 0, input.getNumColumns());
+                               output.setNonZeros(0);
                        }
                }
+               
        }
 
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/4744da16/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
index d40b7a1..56e0e92 100644
--- a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
@@ -49,7 +49,8 @@ public abstract class GPUTests extends AutomatedTestBase {
        protected final static String TEST_DIR = 
"org/apache/sysml/api/mlcontext";
        protected static SparkSession spark;
        protected final double THRESHOLD = 1e-9;    // for relative error
-
+       private static final boolean PRINT_MAT_ERROR = false;
+       
        @BeforeClass
        public static void beforeClass() {
                spark = createSystemMLSparkSession("GPUTests", "local");
@@ -139,7 +140,7 @@ public abstract class GPUTests extends AutomatedTestBase {
                genMLC.close();
                return in1;
        }
-
+       
        /**
         * Generates a random input matrix with a given size and sparsity
         *
@@ -152,6 +153,22 @@ public abstract class GPUTests extends AutomatedTestBase {
         * @return a random matrix with given size and sparsity
         */
        protected Matrix generateInputMatrix(SparkSession spark, int m, int n, 
double min, double max, double sparsity, int seed) {
+               return generateInputMatrix(spark, m, n, min, max, sparsity, 
seed, false);
+       }
+
+       /**
+        * Generates a random input matrix with a given size and sparsity
+        *
+        * @param spark    valid instance of {@link SparkSession}
+        * @param m        number of rows
+        * @param n        number of columns
+        * @param min      min for RNG
+        * @param max      max for RNG
+        * @param sparsity sparsity (1 = completely dense, 0 = completely 
sparse)
+        * @param performRounding performs rounding after generation of random 
matrix
+        * @return a random matrix with given size and sparsity
+        */
+       protected Matrix generateInputMatrix(SparkSession spark, int m, int n, 
double min, double max, double sparsity, int seed, boolean performRounding) {
                // Generate a random matrix of size m * n
                MLContext genMLC = new MLContext(spark);
                String scriptStr;
@@ -160,12 +177,47 @@ public abstract class GPUTests extends AutomatedTestBase {
                } else {
                        scriptStr = "in1 = rand(rows=" + m + ", cols=" + n + ", 
sparsity = " + sparsity + ", seed= " + seed
                                        + ", min=" + min + ", max=" + max + ")";
+                       if(performRounding)
+                               scriptStr += "; in1 = round(in1)";
                }
                Script generateScript = 
ScriptFactory.dmlFromString(scriptStr).out("in1");
                Matrix in1 = genMLC.execute(generateScript).getMatrix("in1");
                genMLC.close();
                return in1;
        }
+       
+       private void printMatrixIfNotEqual(MatrixBlock expectedMB, MatrixBlock 
actualMB) {
+               long rows = expectedMB.getNumRows();
+               long cols = expectedMB.getNumColumns();
+               boolean matrixNotEqual = false;
+               for (int i = 0; i < rows && !matrixNotEqual; i++) {
+                       for (int j = 0; j < cols; j++) {
+                               double expectedDouble = 
expectedMB.quickGetValue(i, j);
+                               double actualDouble = actualMB.quickGetValue(i, 
j);
+                               if (expectedDouble != 0.0 && 
!Double.isNaN(expectedDouble) && Double.isFinite(expectedDouble)) {
+                                       double relativeError = 
Math.abs((expectedDouble - actualDouble) / expectedDouble);
+                                       if(relativeError >= getTHRESHOLD()) {
+                                               matrixNotEqual = true;
+                                               break;
+                                       }
+                               }
+                       }
+               }
+               if(matrixNotEqual) {
+                       System.out.println("Expected mb != Actual mb. 
Mismatches are as follows:");
+                       for (int i = 0; i < rows; i++) {
+                               for (int j = 0; j < cols; j++) {
+                                       double expectedDouble = 
expectedMB.quickGetValue(i, j);
+                                       double actualDouble = 
actualMB.quickGetValue(i, j);
+                                       if (expectedDouble != 0.0 && 
!Double.isNaN(expectedDouble) && Double.isFinite(expectedDouble)) 
+                                               System.out.print("(" + i + "," 
+ j  + " : " + expectedDouble + " != " + actualDouble + ") ");
+                               }
+                       }
+                       System.out.println();
+               }
+               else
+                       System.out.println("Expected mb = Actual mb");
+       }
 
        /**
         * Asserts that the values in two matrices are in {@link 
UnaryOpTests#THRESHOLD} of each other
@@ -183,6 +235,8 @@ public abstract class GPUTests extends AutomatedTestBase {
                        Assert.assertEquals(rows, actualMB.getNumRows());
                        Assert.assertEquals(cols, actualMB.getNumColumns());
 
+                       if(PRINT_MAT_ERROR) printMatrixIfNotEqual(expectedMB, 
actualMB);
+                       
                        for (int i = 0; i < rows; i++) {
                                for (int j = 0; j < cols; j++) {
                                        double expectedDouble = 
expectedMB.quickGetValue(i, j);

http://git-wip-us.apache.org/repos/asf/systemml/blob/4744da16/src/test/java/org/apache/sysml/test/gpu/NeuralNetworkOpTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/NeuralNetworkOpTests.java 
b/src/test/java/org/apache/sysml/test/gpu/NeuralNetworkOpTests.java
index c53e803..aba0cae 100644
--- a/src/test/java/org/apache/sysml/test/gpu/NeuralNetworkOpTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/NeuralNetworkOpTests.java
@@ -24,9 +24,6 @@ import java.util.HashMap;
 import java.util.List;
 
 import org.apache.sysml.api.mlcontext.Matrix;
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
-import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
 import org.apache.sysml.runtime.util.ConvolutionUtils;
 import org.apache.sysml.test.utils.TestUtils;
 import org.junit.Assert;
@@ -44,28 +41,32 @@ import org.junit.Test;
  * <code>
  * mvn -Dit.test=org.apache.sysml.test.gpu.NeuralNetworkOpTests verify 
-PgpuTests
  * </code>
+ * 
+ * Note: generateInputMatrix(...) method in this test performs rounding of 
input matrix. This helps
+ * to test the correctness of our operators at logical level, but not the 
precision.
+ * 
  */
 public class NeuralNetworkOpTests extends GPUTests {
 
        private final static String TEST_NAME = "NeuralNetworkOpTests";
        // The MAX_OP_SIZE is to take into consideration the memory available 
on the GPU as well as
        // limits set by cudnn (operands need to be less than 2GB)
-       private static final double MAX_OP_SIZE;
-
-       static {
-               double MAX = 0.5 * 1024 * 1024 * 1024; // 0.5 GB (this HAS to 
be less than 2GB)
-               try {
-                       // Cap the maximum allowed operand size to 1/3rd of the 
usable GPU memory or MAX, whichever is lesser
-                       List<GPUContext> gCtxs = 
GPUContextPool.reserveAllGPUContexts();
-                       long availableMemory = 
gCtxs.get(0).getAvailableMemory();
-                       double averageMemoryPerOperand = availableMemory / 3.0;
-                       MAX_OP_SIZE = Math.min(averageMemoryPerOperand, MAX);
-                       GPUContextPool.freeAllGPUContexts();
-               } catch (DMLRuntimeException e) {
-                       throw new RuntimeException(e);
-               }
-
-       }
+       private static final double MAX_OP_SIZE = 0.5 * 1024 * 1024 * 1024; // 
0.5 GB (this HAS to be less than 2GB)
+
+//     static {
+//             double MAX = 0.5 * 1024 * 1024 * 1024; // 0.5 GB (this HAS to 
be less than 2GB)
+//             try {
+//                     // Cap the maximum allowed operand size to 1/3rd of the 
usable GPU memory or MAX, whichever is lesser
+//                     List<GPUContext> gCtxs = 
GPUContextPool.reserveAllGPUContexts();
+//                     long availableMemory = 
gCtxs.get(0).getAvailableMemory();
+//                     double averageMemoryPerOperand = availableMemory / 3.0;
+//                     MAX_OP_SIZE = Math.min(averageMemoryPerOperand, MAX);
+//                     GPUContextPool.freeAllGPUContexts();
+//             } catch (DMLRuntimeException e) {
+//                     throw new RuntimeException(e);
+//             }
+//
+//     }
 
        private final int seed = 42;
 
@@ -74,28 +75,20 @@ public class NeuralNetworkOpTests extends GPUTests {
        private final List<Integer> Nlst = Arrays.asList(128, 64, 32);
     private final List<Integer> Clst = Arrays.asList(30, 20, 3);
     private final List<Integer> Hlst = Arrays.asList(400, 128, 32);
-    private final List<Integer> Wlst = Arrays.asList(400, 128, 32);
     private final List<Integer> Klst = Arrays.asList(30, 20, 10);
     private final List<Integer> Rlst = Arrays.asList(128, 63, 4);
-    private final List<Integer> Slst = Arrays.asList(128, 63, 4);
-    private final List<Integer> strideHeightLst = Arrays.asList(9, 3);
-    private final List<Integer> strideWidthLst = Arrays.asList(9, 3);
-    private final List<Integer> padHeightLst = Arrays.asList(3, 1);
-    private final List<Integer> padWidthLst = Arrays.asList(3, 1);
-    private final List<Double> sparsitylst = Arrays.asList(1.0);    // Only 
test for dense
-    */
-       private final List<Integer> Nlst = Arrays.asList(128, 64);
-       private final List<Integer> Clst = Arrays.asList(30, 3);
-       private final List<Integer> Hlst = Arrays.asList(256, 64);
-       private final List<Integer> Wlst = Arrays.asList(256, 64);
-       private final List<Integer> Klst = Arrays.asList(30, 20);
-       private final List<Integer> Rlst = Arrays.asList(128, 3);
-       private final List<Integer> Slst = Arrays.asList(128, 3);
-       private final List<Integer> strideHeightLst = Arrays.asList(9, 1);
-       private final List<Integer> strideWidthLst = Arrays.asList(9, 1);
-       private final List<Integer> padHeightLst = Arrays.asList(3, 1);
-       private final List<Integer> padWidthLst = Arrays.asList(3, 1);
-       private final List<Double> sparsitylst = Arrays.asList(1.0);    // Only 
test for dense
+    private final List<Integer> strideLst = Arrays.asList(9, 3);
+    private final List<Integer> padLst = Arrays.asList(3, 1);
+    private final List<Double> sparsitylst = Arrays.asList(1.0, 0.1, 0.3);
+        */
+       private final List<Integer> Nlst = Arrays.asList(32);
+       private final List<Integer> Clst = Arrays.asList(3);
+       private final List<Integer> Hlst = Arrays.asList(64);
+       private final List<Integer> Klst = Arrays.asList(3);
+       private final List<Integer> Rlst = Arrays.asList(3,5);
+       private final List<Integer> strideLst = Arrays.asList(1, 2);
+       private final List<Integer> padLst = Arrays.asList(0,1);
+       private final List<Double> sparsitylst = Arrays.asList(1.0, 0.1, 0.3);
 
        @Override
        public void setUp() {
@@ -109,7 +102,6 @@ public class NeuralNetworkOpTests extends GPUTests {
                return 1e-5;
        }
 
-       @Ignore
        @Test
        public void testConv2d() {
                String scriptStr = "O = conv2d(image, filter, padding=[padH, 
padW], stride=[strideH, strideW], input_shape=[N,C,H,W], 
filter_shape=[K,C,R,S])";
@@ -117,75 +109,74 @@ public class NeuralNetworkOpTests extends GPUTests {
                for (long N : Nlst) {
                        for (long C : Clst) {
                                for (long H : Hlst) {
-                                       for (long W : Wlst) {
-                                               for (long K : Klst) {
-                                                       for (long R : Rlst) {
-                                                               for (long S : 
Slst) {
-                                                                       for 
(long strideH : strideHeightLst) {
-                                                                               
for (long strideW : strideWidthLst) {
-                                                                               
        for (long padH : padHeightLst) {
-                                                                               
                for (long padW : padWidthLst) {
-                                                                               
                        for (double sparsity : sparsitylst) {
-
-                                                                               
                                // Make sure ops fit in GPU memory and within 
constraints of cudnn
-                                                                               
                                long imageSize = N * C * H * W * 8l;
-                                                                               
                                if (imageSize > MAX_OP_SIZE)  // image size
-                                                                               
                                        continue;
-                                                                               
                                long filterSize = K * C * R * S * 8l;
-                                                                               
                                if (filterSize > MAX_OP_SIZE)  // filter size
-                                                                               
                                        continue;
-                                                                               
                                // filter is smaller than image + padding
-                                                                               
                                if (R > (H + padH) || S > (W + padW))
-                                                                               
                                        continue;
-
-                                                                               
                                int P = (int) ConvolutionUtils.getP(H, R, 
strideH, padH);
-                                                                               
                                int Q = (int) ConvolutionUtils.getQ(W, S, 
strideW, padW);
-
-                                                                               
                                long doutSize = N * K * P * Q * 8l;
-                                                                               
                                if (doutSize > MAX_OP_SIZE) // dout/output size
-                                                                               
                                        continue;
-
-                                                                               
                                double imageSizeInMB = imageSize / (1024.0 * 
1024.0);
-                                                                               
                                double filterSizeInMB = filterSize / (1024.0 * 
1024.0);
-                                                                               
                                double doutSizeInMB = doutSize / (1024.0 * 
1024.0);
-                                                                               
                                System.out
-                                                                               
                                                .format("conv2d, 
image[%d,%d,%d,%d](%.1fMB), filter[%d,%d,%d,%d](%.1f), 
dout[%d,%d,%d,%d](%.1fMB), stride[%d,%d], padding[%d,%d]",
-                                                                               
                                                                N, C, H, W, 
imageSizeInMB, N, C, R, S,
-                                                                               
                                                                filterSizeInMB, 
N, K, P, Q, doutSizeInMB,
-                                                                               
                                                                strideH, 
strideW, padH, padW);
-                                                                               
                                Matrix image = generateInputMatrix(spark, (int) 
N,
-                                                                               
                                                (int) (C * H * W), -127, 127, 
sparsity, seed);
-                                                                               
                                Matrix filter = generateInputMatrix(spark, 
(int) K,
-                                                                               
                                                (int) (C * R * S), -127, 127, 
sparsity, seed);
-                                                                               
                                HashMap<String, Object> inputs = new 
HashMap<>();
-                                                                               
                                inputs.put("N", N);
-                                                                               
                                inputs.put("C", C);
-                                                                               
                                inputs.put("H", H);
-                                                                               
                                inputs.put("W", W);
-                                                                               
                                inputs.put("K", K);
-                                                                               
                                inputs.put("R", R);
-                                                                               
                                inputs.put("S", S);
-                                                                               
                                inputs.put("strideH", strideH);
-                                                                               
                                inputs.put("strideW", strideW);
-                                                                               
                                inputs.put("padH", padH);
-                                                                               
                                inputs.put("padW", padW);
-                                                                               
                                inputs.put("image", image);
-                                                                               
                                inputs.put("filter", filter);
-                                                                               
                                List<Object> outCPU = runOnCPU(spark, 
scriptStr, inputs,
-                                                                               
                                                Arrays.asList("O"));
-                                                                               
                                List<Object> outGPU = runOnGPU(spark, 
scriptStr, inputs,
-                                                                               
                                                Arrays.asList("O"));
-                                                                               
                                assertHeavyHitterPresent("gpu_conv2d");
-                                                                               
                                assertEqualObjects(outCPU.get(0), 
outGPU.get(0));
-                                                                               
                                clearGPUMemory();
-                                                                               
                        }
-                                                                               
                }
-                                                                               
        }
-                                                                               
}
+                                       long W = H;
+                                       for (long K : Klst) {
+                                               for (long R : Rlst) {
+                                                       long S = R;
+                                                       for (long strideH : 
strideLst) {
+                                                               long strideW = 
strideH;
+                                                               for (long padH 
: padLst) {
+                                                                       long 
padW = padH;
+                                                                       for 
(double sparsity : sparsitylst) {
+
+                                                                               
// Make sure ops fit in GPU memory and within constraints of cudnn
+                                                                               
long imageSize = N * C * H * W * 8l;
+                                                                               
if (imageSize > MAX_OP_SIZE)  // image size
+                                                                               
        continue;
+                                                                               
long filterSize = K * C * R * S * 8l;
+                                                                               
if (filterSize > MAX_OP_SIZE)  // filter size
+                                                                               
        continue;
+                                                                               
// filter is smaller than image + padding
+                                                                               
if (R > (H + padH) || S > (W + padW))
+                                                                               
        continue;
+
+                                                                               
int P = (int) ConvolutionUtils.getP(H, R, strideH, padH);
+                                                                               
int Q = (int) ConvolutionUtils.getQ(W, S, strideW, padW);
+
+                                                                               
long doutSize = N * K * P * Q * 8l;
+                                                                               
if (doutSize > MAX_OP_SIZE) // dout/output size
+                                                                               
        continue;
+
+                                                                               
double imageSizeInMB = imageSize / (1024.0 * 1024.0);
+                                                                               
double filterSizeInMB = filterSize / (1024.0 * 1024.0);
+                                                                               
double doutSizeInMB = doutSize / (1024.0 * 1024.0);
+                                                                               
System.out
+                                                                               
.format("conv2d, image[%d,%d,%d,%d](%.1fMB), filter[%d,%d,%d,%d](%.1f), 
dout[%d,%d,%d,%d](%.1fMB), stride[%d,%d], padding[%d,%d]",
+                                                                               
                N, C, H, W, imageSizeInMB, N, C, R, S,
+                                                                               
                filterSizeInMB, N, K, P, Q, doutSizeInMB,
+                                                                               
                strideH, strideW, padH, padW);
+                                                                               
Matrix image = generateInputMatrix(spark, (int) N,
+                                                                               
                (int) (C * H * W), -127, 127, sparsity, seed, true);
+                                                                               
Matrix filter = generateInputMatrix(spark, (int) K,
+                                                                               
                (int) (C * R * S), -127, 127, sparsity, seed, true);
+                                                                               
HashMap<String, Object> inputs = new HashMap<>();
+                                                                               
inputs.put("N", N);
+                                                                               
inputs.put("C", C);
+                                                                               
inputs.put("H", H);
+                                                                               
inputs.put("W", W);
+                                                                               
inputs.put("K", K);
+                                                                               
inputs.put("R", R);
+                                                                               
inputs.put("S", S);
+                                                                               
inputs.put("strideH", strideH);
+                                                                               
inputs.put("strideW", strideW);
+                                                                               
inputs.put("padH", padH);
+                                                                               
inputs.put("padW", padW);
+                                                                               
inputs.put("image", image);
+                                                                               
inputs.put("filter", filter);
+                                                                               
List<Object> outCPU = runOnCPU(spark, scriptStr, inputs,
+                                                                               
                Arrays.asList("O"));
+                                                                               
List<Object> outGPU = runOnGPU(spark, scriptStr, inputs,
+                                                                               
                Arrays.asList("O"));
+                                                                               
assertHeavyHitterPresent("gpu_conv2d");
+                                                                               
assertEqualObjects(outCPU.get(0), outGPU.get(0));
+                                                                               
clearGPUMemory();
                                                                        }
                                                                }
                                                        }
                                                }
+
+
+
                                        }
                                }
                        }
@@ -237,11 +228,11 @@ public class NeuralNetworkOpTests extends GPUTests {
                double filterSizeInMB = filterSize / (1024.0 * 1024.0);
                double doutSizeInMB = doutSize / (1024.0 * 1024.0);
                System.out
-                               .format("conv2d, image[%d,%d,%d,%d](%.1fMB), 
filter[%d,%d,%d,%d](%.1f), dout[%d,%d,%d,%d](%.1fMB), stride[%d,%d], 
padding[%d,%d]",
-                                               N, C, H, W, imageSizeInMB, N, 
C, R, S, filterSizeInMB, N, K, P, Q, doutSizeInMB, strideH,
-                                               strideW, padH, padW);
-               Matrix image = generateInputMatrix(spark, (int) N, (int) (C * H 
* W), -1, 1, sparsity, seed);
-               Matrix filter = generateInputMatrix(spark, (int) K, (int) (C * 
R * S), -1, 1.0, sparsity, seed);
+               .format("conv2d, image[%d,%d,%d,%d](%.1fMB), 
filter[%d,%d,%d,%d](%.1f), dout[%d,%d,%d,%d](%.1fMB), stride[%d,%d], 
padding[%d,%d]",
+                               N, C, H, W, imageSizeInMB, N, C, R, S, 
filterSizeInMB, N, K, P, Q, doutSizeInMB, strideH,
+                               strideW, padH, padW);
+               Matrix image = generateInputMatrix(spark, (int) N, (int) (C * H 
* W), -1, 1, sparsity, seed, true);
+               Matrix filter = generateInputMatrix(spark, (int) K, (int) (C * 
R * S), -1, 1.0, sparsity, seed, true);
                HashMap<String, Object> inputs = new HashMap<>();
                inputs.put("N", N);
                inputs.put("C", C);
@@ -263,7 +254,6 @@ public class NeuralNetworkOpTests extends GPUTests {
                clearGPUMemory();
        }
 
-       @Ignore
        @Test
        public void testConv2dBackwardFilter() {
                String scriptStr = "O = conv2d_backward_filter(image, dout, 
padding=[padH, padW], stride=[strideH, strideW], input_shape=[N,C,H,W], 
filter_shape=[K,C,R,S])";
@@ -271,83 +261,82 @@ public class NeuralNetworkOpTests extends GPUTests {
                for (long N : Nlst) {
                        for (long C : Clst) {
                                for (long H : Hlst) {
-                                       for (long W : Wlst) {
-                                               for (long K : Klst) {
-                                                       for (long R : Rlst) {
-                                                               for (long S : 
Slst) {
-                                                                       for 
(long strideH : strideHeightLst) {
-                                                                               
for (long strideW : strideWidthLst) {
-                                                                               
        for (long padH : padHeightLst) {
-                                                                               
                for (long padW : padWidthLst) {
-                                                                               
                        for (double sparsity : sparsitylst) {
-
-                                                                               
                                // filter is smaller than image + padding
-                                                                               
                                if (R > (H + padH) || S > (W + padW))
-                                                                               
                                        continue;
-
-                                                                               
                                // Make sure ops fit in GPU memory and within 
constraints of cudnn
-                                                                               
                                long imageSize = N * C * H * W * 8l;
-                                                                               
                                if (imageSize > MAX_OP_SIZE)  // image size
-                                                                               
                                        continue;
-                                                                               
                                long filterSize = K * C * R * S * 8l;
-                                                                               
                                if (filterSize > MAX_OP_SIZE)  // filter size
-                                                                               
                                        continue;
-
-                                                                               
                                int P = (int) ConvolutionUtils.getP(H, R, 
strideH, padH);
-                                                                               
                                int Q = (int) ConvolutionUtils.getQ(W, S, 
strideW, padW);
-
-                                                                               
                                long doutSize = N * K * P * Q * 8l;
-                                                                               
                                if (doutSize > MAX_OP_SIZE) // dout/output size
-                                                                               
                                        continue;
-
-                                                                               
                                double imageSizeInMB = imageSize / (1024.0 * 
1024.0);
-                                                                               
                                double filterSizeInMB = filterSize / (1024.0 * 
1024.0);
-                                                                               
                                double doutSizeInMB = doutSize / (1024.0 * 
1024.0);
-                                                                               
                                System.out
-                                                                               
                                                
.format("conv2d_backward_filter, image[%d,%d,%d,%d](%.1fMB), 
filter[%d,%d,%d,%d](%.1f), dout[%d,%d,%d,%d](%.1fMB), stride[%d,%d], 
padding[%d,%d]",
-                                                                               
                                                                N, C, H, W, 
imageSizeInMB, N, C, R, S,
-                                                                               
                                                                filterSizeInMB, 
N, K, P, Q, doutSizeInMB,
-                                                                               
                                                                strideH, 
strideW, padH, padW);
-                                                                               
                                Matrix image = generateInputMatrix(spark, (int) 
N,
-                                                                               
                                                (int) (C * H * W), -127.0, 127, 
sparsity, seed);
-                                                                               
                                Matrix dout = generateInputMatrix(spark, (int) 
N,
-                                                                               
                                                (int) (K * P * Q), -127.0, 127, 
sparsity, seed);
-                                                                               
                                HashMap<String, Object> inputs = new 
HashMap<>();
-                                                                               
                                inputs.put("N", N);
-                                                                               
                                inputs.put("C", C);
-                                                                               
                                inputs.put("H", H);
-                                                                               
                                inputs.put("W", W);
-                                                                               
                                inputs.put("K", K);
-                                                                               
                                inputs.put("R", R);
-                                                                               
                                inputs.put("S", S);
-                                                                               
                                inputs.put("strideH", strideH);
-                                                                               
                                inputs.put("strideW", strideW);
-                                                                               
                                inputs.put("padH", padH);
-                                                                               
                                inputs.put("padW", padW);
-                                                                               
                                inputs.put("image", image);
-                                                                               
                                inputs.put("dout", dout);
-                                                                               
                                List<Object> outCPU = runOnCPU(spark, 
scriptStr, inputs,
-                                                                               
                                                Arrays.asList("O"));
-                                                                               
                                List<Object> outGPU = runOnGPU(spark, 
scriptStr, inputs,
-                                                                               
                                                Arrays.asList("O"));
-                                                                               
                                
assertHeavyHitterPresent("gpu_conv2d_backward_filter");
-                                                                               
                                assertEqualObjects(outCPU.get(0), 
outGPU.get(0));
-                                                                               
                                clearGPUMemory();
-                                                                               
                        }
-                                                                               
                }
-                                                                               
        }
-                                                                               
}
+                                       long W = H;
+                                       for (long K : Klst) {
+                                               for (long R : Rlst) {
+                                                       long S = R;
+                                                       for (long strideH : 
strideLst) {
+                                                               long strideW = 
strideH;
+                                                               for (long padH 
: padLst) {
+                                                                       long 
padW = padH;
+                                                                       for 
(double sparsity : sparsitylst) {
+
+                                                                               
// filter is smaller than image + padding
+                                                                               
if (R > (H + padH) || S > (W + padW))
+                                                                               
        continue;
+
+                                                                               
// Make sure ops fit in GPU memory and within constraints of cudnn
+                                                                               
long imageSize = N * C * H * W * 8l;
+                                                                               
if (imageSize > MAX_OP_SIZE)  // image size
+                                                                               
        continue;
+                                                                               
long filterSize = K * C * R * S * 8l;
+                                                                               
if (filterSize > MAX_OP_SIZE)  // filter size
+                                                                               
        continue;
+
+                                                                               
int P = (int) ConvolutionUtils.getP(H, R, strideH, padH);
+                                                                               
int Q = (int) ConvolutionUtils.getQ(W, S, strideW, padW);
+
+                                                                               
long doutSize = N * K * P * Q * 8l;
+                                                                               
if (doutSize > MAX_OP_SIZE) // dout/output size
+                                                                               
        continue;
+
+                                                                               
double imageSizeInMB = imageSize / (1024.0 * 1024.0);
+                                                                               
double filterSizeInMB = filterSize / (1024.0 * 1024.0);
+                                                                               
double doutSizeInMB = doutSize / (1024.0 * 1024.0);
+                                                                               
System.out
+                                                                               
.format("conv2d_backward_filter, image[%d,%d,%d,%d](%.1fMB), 
filter[%d,%d,%d,%d](%.1f), dout[%d,%d,%d,%d](%.1fMB), stride[%d,%d], 
padding[%d,%d]",
+                                                                               
                N, C, H, W, imageSizeInMB, N, C, R, S,
+                                                                               
                filterSizeInMB, N, K, P, Q, doutSizeInMB,
+                                                                               
                strideH, strideW, padH, padW);
+                                                                               
Matrix image = generateInputMatrix(spark, (int) N,
+                                                                               
                (int) (C * H * W), -127.0, 127, sparsity, seed, true);
+                                                                               
Matrix dout = generateInputMatrix(spark, (int) N,
+                                                                               
                (int) (K * P * Q), -127.0, 127, sparsity, seed, true);
+                                                                               
HashMap<String, Object> inputs = new HashMap<>();
+                                                                               
inputs.put("N", N);
+                                                                               
inputs.put("C", C);
+                                                                               
inputs.put("H", H);
+                                                                               
inputs.put("W", W);
+                                                                               
inputs.put("K", K);
+                                                                               
inputs.put("R", R);
+                                                                               
inputs.put("S", S);
+                                                                               
inputs.put("strideH", strideH);
+                                                                               
inputs.put("strideW", strideW);
+                                                                               
inputs.put("padH", padH);
+                                                                               
inputs.put("padW", padW);
+                                                                               
inputs.put("image", image);
+                                                                               
inputs.put("dout", dout);
+                                                                               
List<Object> outCPU = runOnCPU(spark, scriptStr, inputs,
+                                                                               
                Arrays.asList("O"));
+                                                                               
List<Object> outGPU = runOnGPU(spark, scriptStr, inputs,
+                                                                               
                Arrays.asList("O"));
+                                                                               
assertHeavyHitterPresent("gpu_conv2d_backward_filter");
+                                                                               
assertEqualObjects(outCPU.get(0), outGPU.get(0));
+                                                                               
clearGPUMemory();
                                                                        }
                                                                }
                                                        }
                                                }
+
+
+
+
                                        }
                                }
                        }
                }
        }
 
-       @Ignore
        @Test
        public void testConv2dBackwardData() {
                String scriptStr = "O = conv2d_backward_data(filter, dout, 
padding=[padH, padW], stride=[strideH, strideW], input_shape=[N,C,H,W], 
filter_shape=[K,C,R,S])";
@@ -355,237 +344,237 @@ public class NeuralNetworkOpTests extends GPUTests {
                for (long N : Nlst) {
                        for (long C : Clst) {
                                for (long H : Hlst) {
-                                       for (long W : Wlst) {
-                                               for (long K : Klst) {
-                                                       for (long R : Rlst) {
-                                                               for (long S : 
Slst) {
-                                                                       for 
(long strideH : strideHeightLst) {
-                                                                               
for (long strideW : strideWidthLst) {
-                                                                               
        for (long padH : padHeightLst) {
-                                                                               
                for (long padW : padWidthLst) {
-                                                                               
                        for (double sparsity : sparsitylst) {
-
-                                                                               
                                // filter is smaller than image + padding
-                                                                               
                                if (R > (H + padH) || S > (W + padW))
-                                                                               
                                        continue;
-
-                                                                               
                                // Make sure ops fit in GPU memory and within 
constraints of cudnn
-                                                                               
                                long imageSize = N * C * H * W * 8l;
-                                                                               
                                if (imageSize > MAX_OP_SIZE)  // image size
-                                                                               
                                        continue;
-                                                                               
                                long filterSize = K * C * R * S * 8l;
-                                                                               
                                if (filterSize > MAX_OP_SIZE)  // filter size
-                                                                               
                                        continue;
-
-                                                                               
                                int P = (int) ConvolutionUtils.getP(H, R, 
strideH, padH);
-                                                                               
                                int Q = (int) ConvolutionUtils.getQ(W, S, 
strideW, padW);
-
-                                                                               
                                long doutSize = N * K * P * Q * 8l;
-                                                                               
                                if (doutSize > MAX_OP_SIZE) // dout/output size
-                                                                               
                                        continue;
-
-                                                                               
                                double imageSizeInMB = imageSize / (1024.0 * 
1024.0);
-                                                                               
                                double filterSizeInMB = filterSize / (1024.0 * 
1024.0);
-                                                                               
                                double doutSizeInMB = doutSize / (1024.0 * 
1024.0);
-                                                                               
                                System.out
-                                                                               
                                                .format("conv2d_backward_data, 
image[%d,%d,%d,%d](%.1fMB), filter[%d,%d,%d,%d](%.1f), 
dout[%d,%d,%d,%d](%.1fMB), stride[%d,%d], padding[%d,%d]",
-                                                                               
                                                                N, C, H, W, 
imageSizeInMB, N, C, R, S,
-                                                                               
                                                                filterSizeInMB, 
N, K, P, Q, doutSizeInMB,
-                                                                               
                                                                strideH, 
strideW, padH, padW);
-
-                                                                               
                                Matrix filter = generateInputMatrix(spark, 
(int) K,
-                                                                               
                                                (int) (C * R * S), -127.0, 127, 
sparsity, seed);
-                                                                               
                                Matrix dout = generateInputMatrix(spark, (int) 
N,
-                                                                               
                                                (int) (K * P * Q), -127.0, 127, 
sparsity, seed);
-                                                                               
                                HashMap<String, Object> inputs = new 
HashMap<>();
-                                                                               
                                inputs.put("N", N);
-                                                                               
                                inputs.put("C", C);
-                                                                               
                                inputs.put("H", H);
-                                                                               
                                inputs.put("W", W);
-                                                                               
                                inputs.put("K", K);
-                                                                               
                                inputs.put("R", R);
-                                                                               
                                inputs.put("S", S);
-                                                                               
                                inputs.put("strideH", strideH);
-                                                                               
                                inputs.put("strideW", strideW);
-                                                                               
                                inputs.put("padH", padH);
-                                                                               
                                inputs.put("padW", padW);
-                                                                               
                                inputs.put("filter", filter);
-                                                                               
                                inputs.put("dout", dout);
-                                                                               
                                List<Object> outCPU = runOnCPU(spark, 
scriptStr, inputs,
-                                                                               
                                                Arrays.asList("O"));
-                                                                               
                                List<Object> outGPU = runOnGPU(spark, 
scriptStr, inputs,
-                                                                               
                                                Arrays.asList("O"));
-                                                                               
                                
assertHeavyHitterPresent("gpu_conv2d_backward_data");
-                                                                               
                                assertEqualObjects(outCPU.get(0), 
outGPU.get(0));
-                                                                               
                                clearGPUMemory();
-                                                                               
                        }
-                                                                               
                }
-                                                                               
        }
-                                                                               
}
+                                       long W = H;
+                                       for (long K : Klst) {
+                                               for (long R : Rlst) {
+                                                       long S = R;
+                                                       for (long strideH : 
strideLst) {
+                                                               long strideW = 
strideH;
+                                                               for (long padH 
: padLst) {
+                                                                       long 
padW = padH;
+                                                                       for 
(double sparsity : sparsitylst) {
+
+                                                                               
// filter is smaller than image + padding
+                                                                               
if (R > (H + padH) || S > (W + padW))
+                                                                               
        continue;
+
+                                                                               
// Make sure ops fit in GPU memory and within constraints of cudnn
+                                                                               
long imageSize = N * C * H * W * 8l;
+                                                                               
if (imageSize > MAX_OP_SIZE)  // image size
+                                                                               
        continue;
+                                                                               
long filterSize = K * C * R * S * 8l;
+                                                                               
if (filterSize > MAX_OP_SIZE)  // filter size
+                                                                               
        continue;
+
+                                                                               
int P = (int) ConvolutionUtils.getP(H, R, strideH, padH);
+                                                                               
int Q = (int) ConvolutionUtils.getQ(W, S, strideW, padW);
+
+                                                                               
long doutSize = N * K * P * Q * 8l;
+                                                                               
if (doutSize > MAX_OP_SIZE) // dout/output size
+                                                                               
        continue;
+
+                                                                               
double imageSizeInMB = imageSize / (1024.0 * 1024.0);
+                                                                               
double filterSizeInMB = filterSize / (1024.0 * 1024.0);
+                                                                               
double doutSizeInMB = doutSize / (1024.0 * 1024.0);
+                                                                               
System.out
+                                                                               
.format("conv2d_backward_data, image[%d,%d,%d,%d](%.1fMB), 
filter[%d,%d,%d,%d](%.1f), dout[%d,%d,%d,%d](%.1fMB), stride[%d,%d], 
padding[%d,%d]",
+                                                                               
                N, C, H, W, imageSizeInMB, N, C, R, S,
+                                                                               
                filterSizeInMB, N, K, P, Q, doutSizeInMB,
+                                                                               
                strideH, strideW, padH, padW);
+
+                                                                               
Matrix filter = generateInputMatrix(spark, (int) K,
+                                                                               
                (int) (C * R * S), -127.0, 127, sparsity, seed, true);
+                                                                               
Matrix dout = generateInputMatrix(spark, (int) N,
+                                                                               
                (int) (K * P * Q), -127.0, 127, sparsity, seed, true);
+                                                                               
HashMap<String, Object> inputs = new HashMap<>();
+                                                                               
inputs.put("N", N);
+                                                                               
inputs.put("C", C);
+                                                                               
inputs.put("H", H);
+                                                                               
inputs.put("W", W);
+                                                                               
inputs.put("K", K);
+                                                                               
inputs.put("R", R);
+                                                                               
inputs.put("S", S);
+                                                                               
inputs.put("strideH", strideH);
+                                                                               
inputs.put("strideW", strideW);
+                                                                               
inputs.put("padH", padH);
+                                                                               
inputs.put("padW", padW);
+                                                                               
inputs.put("filter", filter);
+                                                                               
inputs.put("dout", dout);
+                                                                               
List<Object> outCPU = runOnCPU(spark, scriptStr, inputs,
+                                                                               
                Arrays.asList("O"));
+                                                                               
List<Object> outGPU = runOnGPU(spark, scriptStr, inputs,
+                                                                               
                Arrays.asList("O"));
+                                                                               
assertHeavyHitterPresent("gpu_conv2d_backward_data");
+                                                                               
assertEqualObjects(outCPU.get(0), outGPU.get(0));
+                                                                               
clearGPUMemory();
                                                                        }
                                                                }
                                                        }
                                                }
                                        }
+
+
+
+
                                }
                        }
                }
        }
 
-       @Ignore
        @Test
+       @Ignore
        public void testMaxPool() {
                String scriptStr = "O = max_pool(image, padding=[padH, padW], 
stride=[strideH, strideW], input_shape=[N,C,H,W], pool_size=[R,S])";
 
                for (long N : Nlst) {
                        for (long C : Clst) {
                                for (long H : Hlst) {
-                                       for (long W : Wlst) {
-                                               for (long R : Rlst) {
-                                                       for (long S : Slst) {
-                                                               for (long 
strideH : strideHeightLst) {
-                                                                       for 
(long strideW : strideWidthLst) {
-                                                                               
for (long padH : padHeightLst) {
-                                                                               
        for (long padW : padWidthLst) {
-                                                                               
                for (double sparsity : sparsitylst) {
-
-                                                                               
                        // pool is smaller than image + padding
-                                                                               
                        if (R > (H + padH) || S > (W + padW))
-                                                                               
                                continue;
-
-                                                                               
                        // Make sure ops fit in GPU memory and within 
constraints of cudnn
-                                                                               
                        long imageSize = N * C * H * W * 8l;
-                                                                               
                        if (imageSize > MAX_OP_SIZE)  // image size
-                                                                               
                                continue;
-                                                                               
                        long poolSize = R * S * 8l;
-                                                                               
                        if (poolSize > MAX_OP_SIZE)  // filter size
-                                                                               
                                continue;
-
-                                                                               
                        int P = (int) ConvolutionUtils.getP(H, R, strideH, 
padH);
-                                                                               
                        int Q = (int) ConvolutionUtils.getQ(W, S, strideW, 
padW);
-
-                                                                               
                        long doutSize = N * C * P * Q * 8l;
-                                                                               
                        if (doutSize > MAX_OP_SIZE) // dout/output size
-                                                                               
                                continue;
-
-                                                                               
                        double imageSizeInMB = imageSize / (1024.0 * 1024.0);
-                                                                               
                        double poolSizeInMB = poolSize / (1024.0 * 1024.0);
-                                                                               
                        double doutSizeInMB = doutSize / (1024.0 * 1024.0);
-                                                                               
                        System.out
-                                                                               
                                        .format("max_pool, 
image[%d,%d,%d,%d](%.1fMB), pool[%d,%d](%.1f), dout[%d,%d,%d,%d](%.1fMB), 
stride[%d,%d], padding[%d,%d]",
-                                                                               
                                                        N, C, H, W, 
imageSizeInMB, R, S, poolSizeInMB, N, C,
-                                                                               
                                                        P, Q, doutSizeInMB, 
strideH, strideW, padH, padW);
-
-                                                                               
                        Matrix image = generateInputMatrix(spark, (int) N,
-                                                                               
                                        (int) (C * H * W), -127.0, 127, 
sparsity, seed);
-                                                                               
                        HashMap<String, Object> inputs = new HashMap<>();
-                                                                               
                        inputs.put("N", N);
-                                                                               
                        inputs.put("C", C);
-                                                                               
                        inputs.put("H", H);
-                                                                               
                        inputs.put("W", W);
-                                                                               
                        inputs.put("R", R);
-                                                                               
                        inputs.put("S", S);
-                                                                               
                        inputs.put("strideH", strideH);
-                                                                               
                        inputs.put("strideW", strideW);
-                                                                               
                        inputs.put("padH", padH);
-                                                                               
                        inputs.put("padW", padW);
-                                                                               
                        inputs.put("image", image);
-                                                                               
                        List<Object> outCPU = runOnCPU(spark, scriptStr, inputs,
-                                                                               
                                        Arrays.asList("O"));
-                                                                               
                        List<Object> outGPU = runOnGPU(spark, scriptStr, inputs,
-                                                                               
                                        Arrays.asList("O"));
-                                                                               
                        assertHeavyHitterPresent("gpu_maxpooling");
-                                                                               
                        assertEqualObjects(outCPU.get(0), outGPU.get(0));
-                                                                               
                        clearGPUMemory();
-                                                                               
                }
-                                                                               
        }
-                                                                               
}
-                                                                       }
+                                       long W = H;
+                                       for (long R : Rlst) {
+                                               long S = R;
+                                               for (long strideH : strideLst) {
+                                                       long strideW = strideH;
+                                                       for (long padH : 
padLst) {
+                                                               long padW = 
padH;
+                                                               for (double 
sparsity : sparsitylst) {
+
+                                                                       // pool 
is smaller than image + padding
+                                                                       if (R > 
(H + padH) || S > (W + padW))
+                                                                               
continue;
+
+                                                                       // Make 
sure ops fit in GPU memory and within constraints of cudnn
+                                                                       long 
imageSize = N * C * H * W * 8l;
+                                                                       if 
(imageSize > MAX_OP_SIZE)  // image size
+                                                                               
continue;
+                                                                       long 
poolSize = R * S * 8l;
+                                                                       if 
(poolSize > MAX_OP_SIZE)  // filter size
+                                                                               
continue;
+
+                                                                       int P = 
(int) ConvolutionUtils.getP(H, R, strideH, padH);
+                                                                       int Q = 
(int) ConvolutionUtils.getQ(W, S, strideW, padW);
+
+                                                                       long 
doutSize = N * C * P * Q * 8l;
+                                                                       if 
(doutSize > MAX_OP_SIZE) // dout/output size
+                                                                               
continue;
+
+                                                                       double 
imageSizeInMB = imageSize / (1024.0 * 1024.0);
+                                                                       double 
poolSizeInMB = poolSize / (1024.0 * 1024.0);
+                                                                       double 
doutSizeInMB = doutSize / (1024.0 * 1024.0);
+                                                                       
System.out
+                                                                       
.format("max_pool, image[%d,%d,%d,%d](%.1fMB), pool[%d,%d](%.1f), 
dout[%d,%d,%d,%d](%.1fMB), stride[%d,%d], padding[%d,%d]",
+                                                                               
        N, C, H, W, imageSizeInMB, R, S, poolSizeInMB, N, C,
+                                                                               
        P, Q, doutSizeInMB, strideH, strideW, padH, padW);
+
+                                                                       Matrix 
image = generateInputMatrix(spark, (int) N,
+                                                                               
        (int) (C * H * W), -127.0, 127, sparsity, seed, true);
+                                                                       
HashMap<String, Object> inputs = new HashMap<>();
+                                                                       
inputs.put("N", N);
+                                                                       
inputs.put("C", C);
+                                                                       
inputs.put("H", H);
+                                                                       
inputs.put("W", W);
+                                                                       
inputs.put("R", R);
+                                                                       
inputs.put("S", S);
+                                                                       
inputs.put("strideH", strideH);
+                                                                       
inputs.put("strideW", strideW);
+                                                                       
inputs.put("padH", padH);
+                                                                       
inputs.put("padW", padW);
+                                                                       
inputs.put("image", image);
+                                                                       
List<Object> outCPU = runOnCPU(spark, scriptStr, inputs,
+                                                                               
        Arrays.asList("O"));
+                                                                       
List<Object> outGPU = runOnGPU(spark, scriptStr, inputs,
+                                                                               
        Arrays.asList("O"));
+                                                                       
assertHeavyHitterPresent("gpu_maxpooling");
+                                                                       
assertEqualObjects(outCPU.get(0), outGPU.get(0));
+                                                                       
clearGPUMemory();
                                                                }
                                                        }
                                                }
                                        }
+
+
+
+
                                }
                        }
                }
        }
 
-       @Ignore
        @Test
+       @Ignore
        public void testMaxPoolBackward() {
                String scriptStr = "O = max_pool_backward(image, dout, 
padding=[padH, padW], stride=[strideH, strideW], input_shape=[N,C,H,W], 
pool_size=[R,S])";
 
                for (long N : Nlst) {
                        for (long C : Clst) {
                                for (long H : Hlst) {
-                                       for (long W : Wlst) {
-                                               for (long R : Rlst) {
-                                                       for (long S : Slst) {
-                                                               for (long 
strideH : strideHeightLst) {
-                                                                       for 
(long strideW : strideWidthLst) {
-                                                                               
for (long padH : padHeightLst) {
-                                                                               
        for (long padW : padWidthLst) {
-                                                                               
                for (double sparsity : sparsitylst) {
-
-                                                                               
                        // pool is smaller than image + padding
-                                                                               
                        if (R > (H + padH) || S > (W + padW))
-                                                                               
                                continue;
-
-                                                                               
                        // Make sure ops fit in GPU memory and within 
constraints of cudnn
-                                                                               
                        long imageSize = N * C * H * W * 8l;
-                                                                               
                        if (imageSize > MAX_OP_SIZE)  // image size
-                                                                               
                                continue;
-                                                                               
                        long poolSize = R * S * 8l;
-                                                                               
                        if (poolSize > MAX_OP_SIZE)  // filter size
-                                                                               
                                continue;
-
-                                                                               
                        int P = (int) ConvolutionUtils.getP(H, R, strideH, 
padH);
-                                                                               
                        int Q = (int) ConvolutionUtils.getQ(W, S, strideW, 
padW);
-
-                                                                               
                        long doutSize = N * C * P * Q * 8l;
-                                                                               
                        if (doutSize > MAX_OP_SIZE) // dout/output size
-                                                                               
                                continue;
-
-                                                                               
                        double imageSizeInMB = imageSize / (1024.0 * 1024.0);
-                                                                               
                        double poolSizeInMB = poolSize / (1024.0 * 1024.0);
-                                                                               
                        double doutSizeInMB = doutSize / (1024.0 * 1024.0);
-                                                                               
                        System.out
-                                                                               
                                        .format("max_pool_backward, 
image[%d,%d,%d,%d](%.1fMB), pool[%d,%d](%.1f), dout[%d,%d,%d,%d](%.1fMB), 
stride[%d,%d], padding[%d,%d]",
-                                                                               
                                                        N, C, H, W, 
imageSizeInMB, R, S, poolSizeInMB, N, C,
-                                                                               
                                                        P, Q, doutSizeInMB, 
strideH, strideW, padH, padW);
-
-                                                                               
                        Matrix image = generateInputMatrix(spark, (int) N,
-                                                                               
                                        (int) (C * H * W), -127.0, 127, 
sparsity, seed);
-                                                                               
                        Matrix dout = generateInputMatrix(spark, (int) N, (int) 
(C * P * Q),
-                                                                               
                                        -127.0, 127, sparsity, seed);
-                                                                               
                        HashMap<String, Object> inputs = new HashMap<>();
-                                                                               
                        inputs.put("N", N);
-                                                                               
                        inputs.put("C", C);
-                                                                               
                        inputs.put("H", H);
-                                                                               
                        inputs.put("W", W);
-                                                                               
                        inputs.put("R", R);
-                                                                               
                        inputs.put("S", S);
-                                                                               
                        inputs.put("strideH", strideH);
-                                                                               
                        inputs.put("strideW", strideW);
-                                                                               
                        inputs.put("padH", padH);
-                                                                               
                        inputs.put("padW", padW);
-                                                                               
                        inputs.put("image", image);
-                                                                               
                        inputs.put("dout", dout);
-                                                                               
                        List<Object> outCPU = runOnCPU(spark, scriptStr, inputs,
-                                                                               
                                        Arrays.asList("O"));
-                                                                               
                        List<Object> outGPU = runOnGPU(spark, scriptStr, inputs,
-                                                                               
                                        Arrays.asList("O"));
-                                                                               
                        assertHeavyHitterPresent("gpu_maxpooling_backward");
-                                                                               
                        assertEqualObjects(outCPU.get(0), outGPU.get(0));
-                                                                               
                        clearGPUMemory();
-                                                                               
                }
-                                                                               
        }
-                                                                               
}
-                                                                       }
+                                       long W = H;
+                                       for (long R : Rlst) {
+                                               long S = R;
+                                               for (long strideH : strideLst) {
+                                                       long strideW = strideH;
+                                                       for (long padH : 
padLst) {
+                                                               long padW = 
padH;
+                                                               for (double 
sparsity : sparsitylst) {
+
+                                                                       // pool 
is smaller than image + padding
+                                                                       if (R > 
(H + padH) || S > (W + padW))
+                                                                               
continue;
+
+                                                                       // Make 
sure ops fit in GPU memory and within constraints of cudnn
+                                                                       long 
imageSize = N * C * H * W * 8l;
+                                                                       if 
(imageSize > MAX_OP_SIZE)  // image size
+                                                                               
continue;
+                                                                       long 
poolSize = R * S * 8l;
+                                                                       if 
(poolSize > MAX_OP_SIZE)  // filter size
+                                                                               
continue;
+
+                                                                       int P = 
(int) ConvolutionUtils.getP(H, R, strideH, padH);
+                                                                       int Q = 
(int) ConvolutionUtils.getQ(W, S, strideW, padW);
+
+                                                                       long 
doutSize = N * C * P * Q * 8l;
+                                                                       if 
(doutSize > MAX_OP_SIZE) // dout/output size
+                                                                               
continue;
+
+                                                                       double 
imageSizeInMB = imageSize / (1024.0 * 1024.0);
+                                                                       double 
poolSizeInMB = poolSize / (1024.0 * 1024.0);
+                                                                       double 
doutSizeInMB = doutSize / (1024.0 * 1024.0);
+                                                                       
System.out
+                                                                       
.format("max_pool_backward, image[%d,%d,%d,%d](%.1fMB), pool[%d,%d](%.1f), 
dout[%d,%d,%d,%d](%.1fMB), stride[%d,%d], padding[%d,%d]",
+                                                                               
        N, C, H, W, imageSizeInMB, R, S, poolSizeInMB, N, C,
+                                                                               
        P, Q, doutSizeInMB, strideH, strideW, padH, padW);
+
+                                                                       Matrix 
image = generateInputMatrix(spark, (int) N,
+                                                                               
        (int) (C * H * W), -127.0, 127, sparsity, seed, true);
+                                                                       Matrix 
dout = generateInputMatrix(spark, (int) N, (int) (C * P * Q),
+                                                                               
        -127.0, 127, sparsity, seed, true);
+                                                                       
HashMap<String, Object> inputs = new HashMap<>();
+                                                                       
inputs.put("N", N);
+                                                                       
inputs.put("C", C);
+                                                                       
inputs.put("H", H);
+                                                                       
inputs.put("W", W);
+                                                                       
inputs.put("R", R);
+                                                                       
inputs.put("S", S);
+                                                                       
inputs.put("strideH", strideH);
+                                                                       
inputs.put("strideW", strideW);
+                                                                       
inputs.put("padH", padH);
+                                                                       
inputs.put("padW", padW);
+                                                                       
inputs.put("image", image);
+                                                                       
inputs.put("dout", dout);
+                                                                       
List<Object> outCPU = runOnCPU(spark, scriptStr, inputs,
+                                                                               
        Arrays.asList("O"));
+                                                                       
List<Object> outGPU = runOnGPU(spark, scriptStr, inputs,
+                                                                               
        Arrays.asList("O"));
+                                                                       
assertHeavyHitterPresent("gpu_maxpooling_backward");
+                                                                       
assertEqualObjects(outCPU.get(0), outGPU.get(0));
+                                                                       
clearGPUMemory();
                                                                }
                                                        }
                                                }
                                        }
+
+
+
+
                                }
                        }
                }

Reply via email to