Repository: systemml
Updated Branches:
  refs/heads/master e4c74eda6 -> 43b573dfb


[SYSTEMML-445] Select CuDNN operator dynamically based on memory budget

- Refactored the CuDNN operator (algorithm) selection logic into a separate
  file to simplify LibMatrixCuDNN.
- Fixed a minor memory leak in conv2d backward filter (an allocated
  descriptor was not being deallocated).
- Added an optimistic intermediate memory budget for GPU ConvolutionOp.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/43b573df
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/43b573df
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/43b573df

Branch: refs/heads/master
Commit: 43b573dfbb6632d834ab8fa1695e93ce2c4cf15f
Parents: e4c74ed
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Thu Sep 28 21:01:03 2017 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Thu Sep 28 21:10:15 2017 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/hops/ConvolutionOp.java    |  14 +-
 .../runtime/matrix/data/LibMatrixCUDA.java      |   2 +-
 .../runtime/matrix/data/LibMatrixCuDNN.java     | 434 +++++--------------
 .../LibMatrixCuDNNConvolutionAlgorithm.java     | 249 +++++++++++
 .../data/LibMatrixCuDNNInputRowFetcher.java     |  82 ++++
 5 files changed, 456 insertions(+), 325 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/43b573df/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java 
b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index 0ad9182..c5cf667 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -29,6 +29,7 @@ import org.apache.sysml.lops.LopsException;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
 
@@ -191,7 +192,18 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
 //             // TODO: Inserting reblock requires knowing columns apriori
 //             ConvolutionTransform transform1 = new 
ConvolutionTransform(addReblockIfNecessary(et, lopOp, in), lopOp, 
getDataType(), getValueType(), et, k);
 //             setReblockedOutputDimension(et, transform1);
-               ConvolutionTransform transform1 = new ConvolutionTransform(in, 
lopOp, getDataType(), getValueType(), et, k, computeIntermediateMemEstimate(-1, 
-1, -1 ));
+               double cpIntermediateMemEstimate = 
computeIntermediateMemEstimate(-1, -1, -1 );
+               if(et == ExecType.GPU && _dim1 > 0 && _dim2 > 0) {
+                       // This enables us to compile more efficient 
matrix-matrix CuDNN operation instead of 
+                       // row-by-row invocation of multiple vector-matrix 
CuDNN operations.
+                       // This is possible as the operations on GPU are 
single-threaded
+                       double optimisticIntermediateMemEstimate = 
GPUContextPool.initialGPUMemBudget() - getOutputMemEstimate() - 
inputs.get(0).getOutputMemEstimate();
+                       if(in2 != null) {
+                               optimisticIntermediateMemEstimate -= 
inputs.get(1).getOutputMemEstimate();
+                       }
+                       cpIntermediateMemEstimate = 
Math.max(cpIntermediateMemEstimate, optimisticIntermediateMemEstimate);
+               }
+               ConvolutionTransform transform1 = new ConvolutionTransform(in, 
lopOp, getDataType(), getValueType(), et, k, cpIntermediateMemEstimate);
                setOutputDimensions(transform1);
                
                setLineNumbers(transform1);

http://git-wip-us.apache.org/repos/asf/systemml/blob/43b573df/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 4b2cd73..f4a00ab 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -2208,7 +2208,7 @@ public class LibMatrixCUDA {
        //******************* End of Re-org Functions ************************/
        //********************************************************************/
 
-       protected static int toInt(long num) throws DMLRuntimeException {
+       static int toInt(long num) throws DMLRuntimeException {
                if(num >= Integer.MAX_VALUE || num <= Integer.MIN_VALUE) {
                        throw new DMLRuntimeException("GPU : Exceeded supported 
size " + num);
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/43b573df/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index 654bd9d..25dc604 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -30,12 +30,6 @@ import static 
jcuda.jcudnn.JCudnn.cudnnCreateConvolutionDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnCreateFilterDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnCreatePoolingDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnCreateTensorDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnDestroyConvolutionDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnDestroyFilterDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnDestroyPoolingDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardDataWorkspaceSize;
-import static 
jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardFilterWorkspaceSize;
-import static jcuda.jcudnn.JCudnn.cudnnGetConvolutionForwardWorkspaceSize;
 import static jcuda.jcudnn.JCudnn.cudnnPoolingBackward;
 import static jcuda.jcudnn.JCudnn.cudnnPoolingForward;
 import static jcuda.jcudnn.JCudnn.cudnnSetActivationDescriptor;
@@ -67,13 +61,11 @@ import jcuda.jcudnn.cudnnTensorDescriptor;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
-import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
-import org.apache.sysml.runtime.instructions.gpu.context.CSRPointer;
 import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
 import org.apache.sysml.utils.GPUStatistics;
@@ -154,28 +146,34 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                long CHW = C*H*W; long KPQ = K*P*Q; long CRS = C*R*S; 
                long NCHW = N*CHW; long NKPQ = N*KPQ; long KCRS = K*CRS;
 
-               if(DMLScript.FORCE_ACCELERATOR ||
-                               (NCHW < maxNumDoublesOfCuDNNTensor && NKPQ < 
maxNumDoublesOfCuDNNTensor && KCRS < maxNumDoublesOfCuDNNTensor)) {
+               if(NCHW < maxNumDoublesOfCuDNNTensor && NKPQ < 
maxNumDoublesOfCuDNNTensor && KCRS < maxNumDoublesOfCuDNNTensor) {
                        // Filter and output are accounted as dense in the 
memory estimation for conv2d
                        double overhead = isInSparseFormat(gCtx, filter) ? 
OptimizerUtils.estimateSizeExactSparsity(K, CRS, 1.0) : 0;
                        overhead += isInSparseFormat(gCtx, image) ? 
OptimizerUtils.estimateSizeExactSparsity(N, CHW, 1.0) : 0;
 
                        Pointer filterPointer = getDensePointerForCuDNN(gCtx, 
filter, instName);
                        Pointer dstPointer = getDensePointerForCuDNN(gCtx, 
outputBlock, instName);
-
-                       if(DMLScript.FORCE_ACCELERATOR || overhead <= 
intermediateMemoryBudget) {
-                               // Perform all-input all-channel conv2d
-                               Pointer imagePointer = 
getDensePointerForCuDNN(gCtx, image, instName);
-                               cudnnConv2d(gCtx, instName, imagePointer, 
filterPointer, dstPointer, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, 
stride_w, P, Q);
-                       }
-                       else {
-                               InputRowFetcher imgFetcher = new 
InputRowFetcher(gCtx, instName, image);
-                               for(int n = 0; n < N; n++) {
-                                       // Perform one-input all-channel conv2d
-                                       cudnnConv2d(gCtx, instName, 
imgFetcher.getNthRow(n), filterPointer, 
dstPointer.withByteOffset(n*KPQ*Sizeof.DOUBLE), 
-                                                       1, C, H, W, K, R, S, 
pad_h, pad_w, stride_h, stride_w, P, Q);
+                       
+                       // Required for LibMatrixCuDNNConvolutionAlgorithm
+                       long workspaceLimit = (long) 
(intermediateMemoryBudget-overhead);
+                       int localN = overhead <= intermediateMemoryBudget ? N : 
1;
+                       
+                       try(LibMatrixCuDNNConvolutionAlgorithm algo = 
+                                       
LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionForwardAlgorithm(gCtx, 
instName, 
+                                       localN, C, H, W, K, R, S, pad_h, pad_w, 
stride_h, stride_w, P, Q, workspaceLimit)) {
+                               if(localN == N) {
+                                       // Perform all-input all-channel conv2d
+                                       Pointer imagePointer = 
getDensePointerForCuDNN(gCtx, image, instName);
+                                       cudnnConv2d(gCtx, instName, 
imagePointer, filterPointer, dstPointer, algo);
+                               }
+                               else {
+                                       try(LibMatrixCuDNNInputRowFetcher 
imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image)) {
+                                               for(int n = 0; n < N; n++) {
+                                                       // Perform one-input 
all-channel conv2d
+                                                       cudnnConv2d(gCtx, 
instName, imgFetcher.getNthRow(n), filterPointer, 
dstPointer.withByteOffset(n*KPQ*Sizeof.DOUBLE), algo);
+                                               }
+                                       }
                                }
-                               imgFetcher.close();
                        }
                }
                else {
@@ -228,93 +226,31 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
         * @param image    the input matrix (or image) allocated on the GPU
         * @param filter   the filter allocated on the GPU
         * @param output   the output matrix allocated on the GPU
-        * @param N        number of input images
-        * @param C        number of channels
-        * @param H        height of each image
-        * @param W        width of each image
-        * @param K        number of output "channels"
-        * @param R        height of filter
-        * @param S        width of filter
-        * @param pad_h    padding height
-        * @param pad_w    padding width
-        * @param stride_h stride height
-        * @param stride_w string width
-        * @param P        output height
-        * @param Q        output width
+        * @param algo     cudnn algorithm wrapper
         * @throws DMLRuntimeException if error
         */
-       private static void cudnnConv2d(GPUContext gCtx, String instName, 
Pointer image, Pointer filter, Pointer output, int N,
-                       int C, int H, int W, int K, int R, int S, int pad_h, 
int pad_w, int stride_h, int stride_w, int P, int Q)
+       private static void cudnnConv2d(GPUContext gCtx, String instName, 
Pointer image, Pointer filter, Pointer output, 
+                       LibMatrixCuDNNConvolutionAlgorithm algo)
                                        throws DMLRuntimeException {
                if(LOG.isTraceEnabled()) {
                        LOG.trace("GPU : conv2d" + ", GPUContext=" + gCtx);
                }
-               cudnnFilterDescriptor filterDesc = null;
-               cudnnConvolutionDescriptor convDesc = null;
-               Pointer workSpace = null;
-               long sizeInBytes = 0;
                try {
-                       long t1 = 0, t2 = 0;
-                       // Allocate descriptors
+                       long t1 = 0;
                        if (GPUStatistics.DISPLAY_STATISTICS) t1 = 
System.nanoTime();
-                       cudnnTensorDescriptor srcTensorDesc = 
allocateTensorDescriptor(N, C, H, W);
-                       cudnnTensorDescriptor dstTensorDesc = 
allocateTensorDescriptor(N, K, P, Q);
-                       filterDesc = allocateFilterDescriptor(K, C, R, S);
-
-                       int padding[] = {pad_h, pad_w};
-                       int strides[] = {stride_h, stride_w};
-                       convDesc = allocateConvolutionDescriptor(padding, 
strides);
-
-                       // Select the best algorithm depending on the data and 
supported CUDA
-
-                       int algo = -1;
-                       workSpace = new Pointer();
-
-                       if (CONVOLUTION_PREFERENCE == 
cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_NO_WORKSPACE) {
-                               algo = 
jcuda.jcudnn.cudnnConvolutionFwdAlgo.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
-                       } else if (CONVOLUTION_PREFERENCE == 
cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_PREFER_FASTEST) {
-                               int[] algos = {-1};
-                               // TODO: Look into FFt, Winograd, etc
-                               // Also ensure that GPU has enough memory to 
allocate memory
-                               long sizeInBytesArray[] = {0};
-                               
jcuda.jcudnn.JCudnn.cudnnGetConvolutionForwardAlgorithm(getCudnnHandle(gCtx), 
srcTensorDesc, filterDesc, convDesc, dstTensorDesc,
-                                               CONVOLUTION_PREFERENCE, 
sizeInBytesArray[0], algos);
-                               
cudnnGetConvolutionForwardWorkspaceSize(getCudnnHandle(gCtx), srcTensorDesc, 
filterDesc, convDesc, dstTensorDesc, algos[0], sizeInBytesArray);
-                               if (sizeInBytesArray[0] != 0)
-                                       workSpace = 
gCtx.allocate(sizeInBytesArray[0]);
-                               sizeInBytes = sizeInBytesArray[0];
-                       } else if (CONVOLUTION_PREFERENCE == 
cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT) {
-                               throw new 
DMLRuntimeException("CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT is not 
implemented");
-                       } else {
-                               throw new DMLRuntimeException("Unsupported 
preference criteria for convolution");
-                       }
-                       if (GPUStatistics.DISPLAY_STATISTICS)
-                               GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
-                       if (GPUStatistics.DISPLAY_STATISTICS) t2 = 
System.nanoTime();
                        int status = 
cudnnConvolutionForward(getCudnnHandle(gCtx), one(),
-                                       srcTensorDesc, image,
-                                       filterDesc, filter,
-                                       convDesc, algo, workSpace, sizeInBytes, 
zero(),
-                                       dstTensorDesc, output);
+                                       algo.nchwTensorDesc, image,
+                                       algo.filterDesc, filter,
+                                       algo.convDesc, algo.algo, 
algo.workSpace, algo.sizeInBytes, zero(),
+                                       algo.nkpqTensorDesc, output);
                        if (GPUStatistics.DISPLAY_STATISTICS)
-                               GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CONVOLUTION_FORWARD_LIB, System.nanoTime() - t2);
+                               GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CONVOLUTION_FORWARD_LIB, System.nanoTime() - t1);
                        if (status != cudnnStatus.CUDNN_STATUS_SUCCESS) {
                                throw new DMLRuntimeException("Could not 
executed cudnnConvolutionForward: " + cudnnStatus.stringFor(status));
                        }
                } catch (CudaException e) {
                        throw new DMLRuntimeException("Error in conv2d in 
GPUContext " + gCtx.toString() + " from Thread " + 
Thread.currentThread().toString(), e);
-               } finally {
-                       long t3 = 0;
-                       if (GPUStatistics.DISPLAY_STATISTICS) t3 = 
System.nanoTime();
-                       if (filterDesc != null)
-                               cudnnDestroyFilterDescriptor(filterDesc);
-                       if (convDesc != null)
-                               cudnnDestroyConvolutionDescriptor(convDesc);
-                       if (workSpace != null && sizeInBytes != 0)
-                               gCtx.cudaFreeHelper(instName, workSpace);
-                       if (GPUStatistics.DISPLAY_STATISTICS)
-                               GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t3);
-               }
+               } 
        }
 
        /**
@@ -347,41 +283,46 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                        int Q, double intermediateMemoryBudget) throws 
DMLRuntimeException {
                long CHW = C*H*W; long KPQ = K*P*Q; long CRS = C*R*S; 
                long NCHW = N*CHW; long NKPQ = N*KPQ; long KCRS = K*CRS;
-
-               if(DMLScript.FORCE_ACCELERATOR || 
-                               (NCHW < maxNumDoublesOfCuDNNTensor && NKPQ < 
maxNumDoublesOfCuDNNTensor && KCRS < maxNumDoublesOfCuDNNTensor)) {
+               
+               
+               if(NCHW < maxNumDoublesOfCuDNNTensor && NKPQ < 
maxNumDoublesOfCuDNNTensor && KCRS < maxNumDoublesOfCuDNNTensor) {
                        Pointer dwPointer = getDensePointerForCuDNN(gCtx, 
outputBlock, instName);
                        double overhead = isInSparseFormat(gCtx, image) ? 
OptimizerUtils.estimateSizeExactSparsity(N, CHW, 1.0) : 0;
                        overhead += isInSparseFormat(gCtx, dout) ? 
OptimizerUtils.estimateSizeExactSparsity(N, KPQ, 1.0) : 0;
-                       if(DMLScript.FORCE_ACCELERATOR || overhead <= 
intermediateMemoryBudget) {
-                               // Perform all-input all-channel 
conv2dBackwardFilter
-                               Pointer imagePointer = 
getDensePointerForCuDNN(gCtx, image, instName);
-                               Pointer doutPointer = 
getDensePointerForCuDNN(gCtx, dout, instName);
-                               cudnnConv2dBackwardFilter(gCtx, instName, 
imagePointer, doutPointer, dwPointer, 
-                                               N, C, H, W, K, R, S, pad_h, 
pad_w, stride_h, stride_w, P, Q);
-                       }
-                       else {
-                               // Perform one-input conv2dBackwardFilter
-                               Pointer tempdwPointer = 
gCtx.allocate(KCRS*Sizeof.DOUBLE);
-                               InputRowFetcher imgFetcher = new 
InputRowFetcher(gCtx, instName, image);
-                               InputRowFetcher doutFetcher = new 
InputRowFetcher(gCtx, instName, dout);
-                               for(int n = 0; n < N; n++) {
-                                       long t0 = 
GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
-                                       cudaMemset(tempdwPointer, 0, 
KCRS*Sizeof.DOUBLE);
-                                       if(GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_SET_ZERO, 
System.nanoTime() - t0);
-                                       // Perform one-input 
conv2dBackwardFilter
-                                       cudnnConv2dBackwardFilter(gCtx, 
instName, imgFetcher.getNthRow(n), doutFetcher.getNthRow(n), tempdwPointer, 
-                                                       1, C, H, W, K, R, S, 
pad_h, pad_w, stride_h, stride_w, P, Q);
-                                       
getCudaKernels(gCtx).launchKernel("inplace_add",
-                                                       
ExecutionConfig.getConfigForSimpleMatrixOperations(K, toInt(CRS)),
-                                                       tempdwPointer, 
dwPointer, K, toInt(CRS));
 
+                       // Required for LibMatrixCuDNNConvolutionAlgorithm
+                       long workspaceLimit = (long) 
(intermediateMemoryBudget-overhead);
+                       int localN = overhead <= intermediateMemoryBudget ? N : 
1;
+                       
+                       try(LibMatrixCuDNNConvolutionAlgorithm algo = 
+                                       
LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionBackwardFilterAlgorithm(gCtx,
 instName, 
+                                       localN, C, H, W, K, R, S, pad_h, pad_w, 
stride_h, stride_w, P, Q, workspaceLimit)) {
+                               if(localN == N) {
+                                       // Perform all-input all-channel 
conv2dBackwardFilter
+                                       Pointer imagePointer = 
getDensePointerForCuDNN(gCtx, image, instName);
+                                       Pointer doutPointer = 
getDensePointerForCuDNN(gCtx, dout, instName);
+                                       cudnnConv2dBackwardFilter(gCtx, 
instName, imagePointer, doutPointer, dwPointer, algo);
+                               }
+                               else {
+                                       try(LibMatrixCuDNNInputRowFetcher 
imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
+                                               LibMatrixCuDNNInputRowFetcher 
doutFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout)) {
+                                               // Perform one-input 
conv2dBackwardFilter
+                                               Pointer tempdwPointer = 
gCtx.allocate(KCRS*Sizeof.DOUBLE);
+                                               for(int n = 0; n < N; n++) {
+                                                       long t0 = 
GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+                                                       
cudaMemset(tempdwPointer, 0, KCRS*Sizeof.DOUBLE);
+                                                       
if(GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_SET_ZERO, 
System.nanoTime() - t0);
+                                                       // Perform one-input 
conv2dBackwardFilter
+                                                       
cudnnConv2dBackwardFilter(gCtx, instName, imgFetcher.getNthRow(n), 
doutFetcher.getNthRow(n), tempdwPointer, algo);
+                                                       
getCudaKernels(gCtx).launchKernel("inplace_add",
+                                                                       
ExecutionConfig.getConfigForSimpleMatrixOperations(K, toInt(CRS)),
+                                                                       
tempdwPointer, dwPointer, K, toInt(CRS));
+
+                                               }
+                                               // Deallocate temporary array 
to hold one element of input
+                                               
gCtx.cudaFreeHelper(tempdwPointer, true);
+                                       }
                                }
-
-                               // Deallocate temporary array to hold one 
element of input
-                               gCtx.cudaFreeHelper(tempdwPointer, true);
-                               imgFetcher.close();
-                               doutFetcher.close();
                        }
                }
                else {
@@ -397,81 +338,26 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
         * @param imagePointer pointer to input image
         * @param doutPointer pointer to errors from next layer
         * @param dwPointer  output errors
-        * @param N number of images
-        * @param C number of channels
-        * @param H height
-        * @param W width
-        * @param K number of filters
-        * @param R filter height
-        * @param S filter width
-        * @param pad_h pad height
-        * @param pad_w pad width
-        * @param stride_h stride height
-        * @param stride_w stride width
-        * @param P output activation height
-        * @param Q output activation width
+        * @param algo     cudnn algorithm wrapper
         * @throws DMLRuntimeException if DMLRuntimeException occurs
         */
        private static void cudnnConv2dBackwardFilter(GPUContext gCtx, String 
instName, Pointer imagePointer, Pointer doutPointer,
-                       Pointer dwPointer, int N, int C, int H, int W, int K, 
int R,
-                       int S, int pad_h, int pad_w, int stride_h, int 
stride_w, int P,
-                       int Q) throws DMLRuntimeException {
+                       Pointer dwPointer, LibMatrixCuDNNConvolutionAlgorithm 
algo) throws DMLRuntimeException {
                if(LOG.isTraceEnabled()) {
                        LOG.trace("GPU : conv2dBackwardFilter" + ", 
GPUContext=" + gCtx);
                }
-               cudnnFilterDescriptor dwDesc = null;
-               cudnnConvolutionDescriptor convDesc = null;
-
-               Pointer workSpace = null;
-               long sizeInBytes = 0;
                try {
-
-                       long t1 = 0, t2 = 0;
-                       if (GPUStatistics.DISPLAY_STATISTICS) t1 = 
System.nanoTime();
-                       // Allocate descriptors
-                       cudnnTensorDescriptor xTensorDesc = 
allocateTensorDescriptor(N, C, H, W);
-                       cudnnTensorDescriptor doutTensorDesc = 
allocateTensorDescriptor(N, K, P, Q);
-                       dwDesc = allocateFilterDescriptor(K, C, R, S);
-
-                       // Allocate data
-                       int padding[] = {pad_h, pad_w};
-                       int strides[] = {stride_h, stride_w};
-                       convDesc = allocateConvolutionDescriptor(padding, 
strides);
-                       long sizeInBytesArray[] = {0};
-
-                       // TODO: Select the best algorithm depending on the 
data and supported CUDA
-                       int algo = 
jcuda.jcudnn.cudnnConvolutionBwdFilterAlgo.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
-
-                       workSpace = new Pointer();
-                       
cudnnGetConvolutionBackwardFilterWorkspaceSize(getCudnnHandle(gCtx),
-                                       xTensorDesc, doutTensorDesc, convDesc, 
dwDesc, algo, sizeInBytesArray);
+                       long t1 = GPUStatistics.DISPLAY_STATISTICS ? 
System.nanoTime() : 0;
+                       int status = 
cudnnConvolutionBackwardFilter(getCudnnHandle(gCtx), one(), 
algo.nchwTensorDesc, imagePointer,
+                                       algo.nkpqTensorDesc, doutPointer, 
algo.convDesc, algo.algo, algo.workSpace, algo.sizeInBytes, zero(), 
algo.filterDesc, dwPointer);
                        if (GPUStatistics.DISPLAY_STATISTICS)
-                               GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
-
-                       if (GPUStatistics.DISPLAY_STATISTICS) t2 = 
System.nanoTime();
-                       int status = 
cudnnConvolutionBackwardFilter(getCudnnHandle(gCtx), one(), xTensorDesc, 
imagePointer,
-                                       doutTensorDesc, doutPointer, convDesc, 
algo, workSpace, sizeInBytes, zero(), dwDesc, dwPointer);
-                       if (GPUStatistics.DISPLAY_STATISTICS)
-                               GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CONVOLUTION_BACKWARD_FILTER_LIB, System.nanoTime() - 
t2);
-
+                               GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CONVOLUTION_BACKWARD_FILTER_LIB, System.nanoTime() - 
t1);
                        if (status != 
jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
                                throw new DMLRuntimeException("Could not 
executed cudnnConvolutionBackwardFilter: " + 
jcuda.jcudnn.cudnnStatus.stringFor(status));
                        }
                } catch (CudaException e) {
                        throw new DMLRuntimeException("Error in conv2d in 
GPUContext " + gCtx.toString() + " from Thread " + 
Thread.currentThread().toString(), e);
-               } finally {
-                       long t3=0;
-                       if (GPUStatistics.DISPLAY_STATISTICS) t3 = 
System.nanoTime();
-
-                       if(workSpace != null && sizeInBytes != 0)
-                               gCtx.cudaFreeHelper(instName, workSpace);
-                       if(dwDesc != null)
-                               cudnnDestroyFilterDescriptor(dwDesc);
-
-                       if(convDesc != null)
-                               cudnnDestroyConvolutionDescriptor(convDesc);
-                       if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t3);
-               }
+               } 
        }
 
        /**
@@ -505,26 +391,32 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                long CHW = C*H*W; long KPQ = K*P*Q; long CRS = C*R*S; 
                long NCHW = N*CHW; long NKPQ = N*KPQ; long KCRS = K*CRS;
 
-               if(DMLScript.FORCE_ACCELERATOR ||
-                               (NCHW < maxNumDoublesOfCuDNNTensor && NKPQ < 
maxNumDoublesOfCuDNNTensor && KCRS < maxNumDoublesOfCuDNNTensor)) {
+               if(NCHW < maxNumDoublesOfCuDNNTensor && NKPQ < 
maxNumDoublesOfCuDNNTensor && KCRS < maxNumDoublesOfCuDNNTensor) {
                        // Filter and output are accounted as dense in the 
memory estimation for conv2dBackwardData
                        double overhead = isInSparseFormat(gCtx, filter) ? 
OptimizerUtils.estimateSizeExactSparsity(K, CRS, 1.0) : 0;
                        overhead += isInSparseFormat(gCtx, dout) ? 
OptimizerUtils.estimateSizeExactSparsity(N, KPQ, 1.0) : 0;
                        Pointer filterPointer = getDensePointerForCuDNN(gCtx, 
filter, instName);
                        Pointer dstPointer = getDensePointerForCuDNN(gCtx, 
output, instName);
-                       if(DMLScript.FORCE_ACCELERATOR || overhead <= 
intermediateMemoryBudget) {
-                               // Perform all-input all-channel 
conv2dBackwardData
-                               Pointer doutPointer = 
getDensePointerForCuDNN(gCtx, dout, instName);
-                               cudnnConv2dBackwardData(gCtx, instName, 
filterPointer, doutPointer, dstPointer, 
-                                               N, C, H, W, K, R, S, pad_h, 
pad_w, stride_h, stride_w, P, Q);
-                       }
-                       else {
-                               InputRowFetcher doutFetcher = new 
InputRowFetcher(gCtx, instName, dout);
-                               for(int n = 0; n < N; n++) {
-                                       cudnnConv2d(gCtx, instName, 
doutFetcher.getNthRow(n), filterPointer, 
dstPointer.withByteOffset(n*CHW*Sizeof.DOUBLE), 
-                                                       1, C, H, W, K, R, S, 
pad_h, pad_w, stride_h, stride_w, P, Q);
+                       
+                       // Required for LibMatrixCuDNNConvolutionAlgorithm
+                       long workspaceLimit = (long) 
(intermediateMemoryBudget-overhead);
+                       int localN = overhead <= intermediateMemoryBudget ? N : 
1;
+                       
+                       try(LibMatrixCuDNNConvolutionAlgorithm algo = 
+                                       
LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionBackwardDataAlgorithm(gCtx,
 instName, 
+                                       localN, C, H, W, K, R, S, pad_h, pad_w, 
stride_h, stride_w, P, Q, workspaceLimit)) {
+                               if(localN == N) {
+                                       // Perform all-input all-channel 
conv2dBackwardData
+                                       Pointer doutPointer = 
getDensePointerForCuDNN(gCtx, dout, instName);
+                                       cudnnConv2dBackwardData(gCtx, instName, 
filterPointer, doutPointer, dstPointer, algo);
+                               }
+                               else {
+                                       try(LibMatrixCuDNNInputRowFetcher 
doutFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout)) {
+                                               for(int n = 0; n < N; n++) {
+                                                       
cudnnConv2dBackwardData(gCtx, instName, doutFetcher.getNthRow(n), 
filterPointer, dstPointer.withByteOffset(n*CHW*Sizeof.DOUBLE), algo);
+                                               }
+                                       }
                                }
-                               doutFetcher.close();
                        }
                }
                else {
@@ -540,77 +432,26 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
         * @param w pointer to filter used in conv2d
         * @param dy pointer to errors from next layer
         * @param dx pointer to  output errors
-        * @param N number of images
-        * @param C number of channels
-        * @param H height
-        * @param W width
-        * @param K number of filters
-        * @param R filter height
-        * @param S filter width
-        * @param pad_h pad height
-        * @param pad_w pad width
-        * @param stride_h stride height
-        * @param stride_w stride width
-        * @param P output activation height
-        * @param Q output activation width
+        * @param algo cudnn algorithm wrapper
         * @throws DMLRuntimeException if DMLRuntimeException occurs
         */
        private static void cudnnConv2dBackwardData(GPUContext gCtx, String 
instName, Pointer w, Pointer dy,
-                       Pointer dx, int N, int C, int H, int W, int K, int R,
-                       int S, int pad_h, int pad_w, int stride_h, int 
stride_w, int P,
-                       int Q) throws DMLRuntimeException {
+                       Pointer dx, LibMatrixCuDNNConvolutionAlgorithm algo) 
throws DMLRuntimeException {
                if(LOG.isTraceEnabled()) {
                        LOG.trace("GPU : conv2dBackwardData" + ", GPUContext=" 
+ gCtx);
                }
-               cudnnFilterDescriptor wDesc = null;
-               cudnnConvolutionDescriptor convDesc = null;
-
-               Pointer workSpace = null;
-               long sizeInBytes = 0;
                try {
-                       long t1=0, t2=0;
-                       if (GPUStatistics.DISPLAY_STATISTICS) t1 = 
System.nanoTime();
-                       // Allocate descriptors
-                       wDesc = allocateFilterDescriptor(K, C, R, S);
-                       cudnnTensorDescriptor dyDesc = 
allocateTensorDescriptor(N, K, P, Q);
-                       cudnnTensorDescriptor dxDesc = 
allocateTensorDescriptor(N, C, H, W);
-
-                       int padding [] = { pad_h, pad_w };
-                       int strides [] = { stride_h, stride_w };
-                       convDesc = allocateConvolutionDescriptor(padding, 
strides);
-                       long sizeInBytesArray[] = { 0 };
-
-                       // TODO: Select the best algorithm depending on the 
data and supported CUDA
-                       int algo = 
jcuda.jcudnn.cudnnConvolutionBwdDataAlgo.CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
-                       workSpace = new Pointer();
-                       
cudnnGetConvolutionBackwardDataWorkspaceSize(getCudnnHandle(gCtx),
-                                       wDesc, dyDesc, convDesc, dxDesc, algo, 
sizeInBytesArray);
-                       if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
-
-                       if (GPUStatistics.DISPLAY_STATISTICS) t2 = 
System.nanoTime();
-                       int status = 
cudnnConvolutionBackwardData(getCudnnHandle(gCtx), one(), wDesc, w,
-                                       dyDesc, dy, convDesc, algo, workSpace, 
sizeInBytes, zero(), dxDesc, dx);
-                       if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CONVOLUTION_BACKWARD_DATA_LIB, System.nanoTime() - 
t2);
+                       long t1 = GPUStatistics.DISPLAY_STATISTICS ? 
System.nanoTime() : 0;
+                       int status = 
cudnnConvolutionBackwardData(getCudnnHandle(gCtx), one(), algo.filterDesc, w,
+                                       algo.nkpqTensorDesc, dy, algo.convDesc, 
algo.algo, algo.workSpace, algo.sizeInBytes, zero(), algo.nchwTensorDesc, dx);
+                       if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CONVOLUTION_BACKWARD_DATA_LIB, System.nanoTime() - 
t1);
 
                        if(status != 
jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
                                throw new DMLRuntimeException("Could not 
executed cudnnConvolutionBackwardData: " + 
jcuda.jcudnn.cudnnStatus.stringFor(status));
                        }
                } catch (CudaException e) {
                        throw new DMLRuntimeException("Error in conv2d in 
GPUContext " + gCtx.toString() + " from Thread " + 
Thread.currentThread().toString(), e);
-               }
-               finally {
-                       long t3=0;
-                       if (GPUStatistics.DISPLAY_STATISTICS) t3 = 
System.nanoTime();
-
-                       if(workSpace != null && sizeInBytes != 0)
-                               gCtx.cudaFreeHelper(instName, workSpace);
-                       if(wDesc != null)
-                               cudnnDestroyFilterDescriptor(wDesc);
-                       if(convDesc != null)
-                               cudnnDestroyConvolutionDescriptor(convDesc);
-
-                       if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t3);
-               }
+               }       
        }
 
        /**
@@ -642,18 +483,17 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                long CHW = C*H*W; long CPQ = C*P*Q;  
                long NCHW = N*CHW; long NCPQ = N*CPQ; 
 
-               if(DMLScript.FORCE_ACCELERATOR || 
-                               (NCHW < maxNumDoublesOfCuDNNTensor && NCPQ < 
maxNumDoublesOfCuDNNTensor)) {
+               if(NCHW < maxNumDoublesOfCuDNNTensor && NCPQ < 
maxNumDoublesOfCuDNNTensor) {
                        // Filter and output are accounted as dense in the 
memory estimation for conv2dBackwardData
                        long overhead = isInSparseFormat(gCtx, image) ? 
OptimizerUtils.estimateSizeExactSparsity(N, CHW, 1.0) : 0;
                        Pointer y = getDensePointerForCuDNN(gCtx, outputBlock, 
instName);
-                       if(DMLScript.FORCE_ACCELERATOR || overhead <= 
intermediateMemoryBudget) {
+                       if(overhead <= intermediateMemoryBudget) {
                                Pointer x = getDensePointerForCuDNN(gCtx, 
image, instName);
                                cudnnTensorDescriptor xDesc = 
allocateTensorDescriptor(gCtx, image, N, C, H, W);
                                cudnnMaxpooling(gCtx, instName, x, xDesc, y, N, 
C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
                        }
                        else {
-                               InputRowFetcher imgFetcher = new 
InputRowFetcher(gCtx, instName, image);
+                               LibMatrixCuDNNInputRowFetcher imgFetcher = new 
LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
                                cudnnTensorDescriptor xDesc = 
allocateTensorDescriptor(gCtx, image, N, C, H, W);
                                for(int n = 0; n < N; n++) {
                                        cudnnMaxpooling(gCtx, instName, 
imgFetcher.getNthRow(n), xDesc, y.withByteOffset(n*CPQ*Sizeof.DOUBLE), 1, C, H, 
W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
@@ -666,57 +506,6 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                }
        }
 
-       /**
-        * Performs a slice operation: out = in[(n+1):(n+1), 1:numColumns]
-        */
-       private static class InputRowFetcher {
-               GPUContext gCtx; String instName; int numColumns; boolean 
isInputInSparseFormat; 
-               Object inPointer; // can be either CSRPointer or Pointer 
-               Pointer outPointer;
-
-               /**
-                * Initialize the input fetcher
-                * 
-                * @param gCtx current gpu context
-                * @param instName name of the instruction
-                * @param image input matrix object.
-                * @throws DMLRuntimeException if error
-                */
-               public InputRowFetcher(GPUContext gCtx, String instName, 
MatrixObject image) throws DMLRuntimeException {
-                       this.gCtx = gCtx; this.instName = instName;
-                       numColumns = toInt(image.getNumColumns());
-                       isInputInSparseFormat = isInSparseFormat(gCtx, image);
-                       inPointer = isInputInSparseFormat ? 
getSparsePointer(gCtx, image, instName) : getDensePointerForCuDNN(gCtx, image, 
instName);
-                       outPointer = gCtx.allocate(numColumns*Sizeof.DOUBLE);
-               }
-               /**
-                * Copy the nth row and return the dense pointer
-                * @param n zero-based row index
-                * @return dense pointer containing the nth row. This row is 
reused in the next iteration
-                * @throws DMLRuntimeException
-                */
-               public Pointer getNthRow(int n) throws DMLRuntimeException {
-                       if(isInputInSparseFormat) {
-                               jcuda.runtime.JCuda.cudaDeviceSynchronize();
-                               long t0 = GPUStatistics.DISPLAY_STATISTICS ? 
System.nanoTime() : 0;
-                               cudaMemset(outPointer, 0, 
numColumns*Sizeof.DOUBLE);
-                               jcuda.runtime.JCuda.cudaDeviceSynchronize();
-                               if(GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_SET_ZERO, 
System.nanoTime() - t0);
-                               sliceSparseDense(gCtx, instName, 
(CSRPointer)inPointer, outPointer, n, n, 0, toInt(numColumns-1), numColumns);
-                       }
-                       else {
-                               sliceDenseDense(gCtx, instName, 
(Pointer)inPointer, outPointer, n, n, 0, toInt(numColumns-1), numColumns);
-                       }
-                       return outPointer;
-               }
-               /**
-                * Deallocates temporary pointer
-                */
-               public void close() {
-                       gCtx.cudaFreeHelper(outPointer, true);
-               }
-       }
-
        private static void cudnnMaxpooling(GPUContext gCtx, String instName, 
Pointer x, cudnnTensorDescriptor xDesc,
                        Pointer y, int N, int C, int H, int W, int K, int R,
                        int S, int pad_h, int pad_w, int stride_h, int 
stride_w, int P,
@@ -749,7 +538,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                        long t3=0;
                        if (GPUStatistics.DISPLAY_STATISTICS) t3 = 
System.nanoTime();
                        if(poolingDesc != null)
-                               cudnnDestroyPoolingDescriptor(poolingDesc);
+                               
jcuda.jcudnn.JCudnn.cudnnDestroyPoolingDescriptor(poolingDesc);
                        if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t3);
                }
        }
@@ -785,20 +574,19 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                long CHW = C*H*W; long CPQ = C*P*Q;  
                long NCHW = N*CHW; long NCPQ = N*CPQ; 
 
-               if(DMLScript.FORCE_ACCELERATOR || 
-                               (NCHW < maxNumDoublesOfCuDNNTensor && NCPQ < 
maxNumDoublesOfCuDNNTensor)) {
+               if(NCHW < maxNumDoublesOfCuDNNTensor && NCPQ < 
maxNumDoublesOfCuDNNTensor) {
                        // Filter and output are accounted as dense in the 
memory estimation for conv2dBackwardData
                        long overhead = isInSparseFormat(gCtx, image) ? 
OptimizerUtils.estimateSizeExactSparsity(N, CHW, 1.0) : 0;
                        overhead += isInSparseFormat(gCtx, dout) ? 
OptimizerUtils.estimateSizeExactSparsity(N, CPQ, 1.0) : 0;
                        Pointer dx = getDensePointerForCuDNN(gCtx, outputBlock, 
instName);
-                       if(DMLScript.FORCE_ACCELERATOR || overhead <= 
intermediateMemoryBudget) {
+                       if(overhead <= intermediateMemoryBudget) {
                                Pointer x = getDensePointerForCuDNN(gCtx, 
image, instName);
                                Pointer dy = getDensePointerForCuDNN(gCtx, 
dout, instName);
                                cudnnMaxpoolingBackward(gCtx, instName, x, dy, 
dx, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
                        }
                        else {
-                               InputRowFetcher imgFetcher = new 
InputRowFetcher(gCtx, instName, image);
-                               InputRowFetcher doutFetcher = new 
InputRowFetcher(gCtx, instName, dout);
+                               LibMatrixCuDNNInputRowFetcher imgFetcher = new 
LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
+                               LibMatrixCuDNNInputRowFetcher doutFetcher = new 
LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout);
                                for(int n = 0; n < N; n++) {
                                        cudnnMaxpoolingBackward(gCtx, instName, 
imgFetcher.getNthRow(n), doutFetcher.getNthRow(n), 
                                                        
dx.withByteOffset(n*CHW*Sizeof.DOUBLE), 
@@ -868,13 +656,13 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                        if(y != null)
                                gCtx.cudaFreeHelper(instName, y);
                        if(poolingDesc != null)
-                               cudnnDestroyPoolingDescriptor(poolingDesc);
+                               
jcuda.jcudnn.JCudnn.cudnnDestroyPoolingDescriptor(poolingDesc);
 
                        if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t4);
                }
        }
 
-       private static cudnnConvolutionDescriptor 
allocateConvolutionDescriptor(int padding [], int strides []) {
+       static cudnnConvolutionDescriptor allocateConvolutionDescriptor(int 
padding [], int strides []) {
                cudnnConvolutionDescriptor convDesc = new 
cudnnConvolutionDescriptor();
                cudnnCreateConvolutionDescriptor(convDesc);
                cudnnSetConvolution2dDescriptor(convDesc, padding[0], 
padding[1], strides[0], strides[1], 1, 1, CUDNN_CROSS_CORRELATION);
@@ -914,7 +702,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
         * @return cudnn tensor descriptor
         * @throws DMLRuntimeException if the input descriptor and matrix 
dimensions don't match
         */
-       private static cudnnTensorDescriptor allocateTensorDescriptor(int N, 
int C, int H, int W) throws DMLRuntimeException {
+       static cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int 
H, int W) throws DMLRuntimeException {
                cudnnTensorDescriptor tensorDescriptor = new 
cudnnTensorDescriptor();
                cudnnCreateTensorDescriptor(tensorDescriptor);
                cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, 
CUDNN_DATA_DOUBLE, N, C, H, W);

http://git-wip-us.apache.org/repos/asf/systemml/blob/43b573df/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
new file mode 100644
index 0000000..363cf78
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.matrix.data;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+import org.apache.sysml.utils.GPUStatistics;
+
+import jcuda.Pointer;
+import jcuda.jcudnn.cudnnConvolutionBwdDataPreference;
+import jcuda.jcudnn.cudnnConvolutionBwdFilterPreference;
+import jcuda.jcudnn.cudnnConvolutionDescriptor;
+import jcuda.jcudnn.cudnnConvolutionFwdPreference;
+import jcuda.jcudnn.cudnnFilterDescriptor;
+import jcuda.jcudnn.cudnnTensorDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnDestroyConvolutionDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnDestroyFilterDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
+
+/**
+ * This class is a wrapper that contains the data structures necessary to invoke
+ * the cuDNN cudnnConvolution* functions (such as cudnnConvolutionForward).
+ * 
+ * It implements AutoCloseable to simplify the LibMatrixCuDNN code and to avoid
+ * potential memory leaks.
+ * 
+ * The caller has to use the factory methods: 
cudnnGetConvolutionForwardAlgorithm, 
+ * cudnnGetConvolutionBackwardFilterAlgorithm and 
cudnnGetConvolutionBackwardDataAlgorithm
+ * to get the LibMatrixCuDNNConvolutionAlgorithm object.
+ * The naming of these methods is consistent with that of the cuDNN library.
+ *  
+ */
+public class LibMatrixCuDNNConvolutionAlgorithm implements 
java.lang.AutoCloseable {
+       // cuDNN algorithm id passed as the 'algo' argument of the cudnnConvolution* call;
+       // -1 until one of the factory methods selects an algorithm
+       public int algo = -1;
+       // Device workspace for the selected algorithm; replaced by gCtx.allocate(...) only
+       // when cuDNN reports a non-zero workspace size
+       public Pointer workSpace = new Pointer();
+       // Size of workSpace in bytes; close() frees workSpace only when this is non-zero
+       public long sizeInBytes = 0;
+       // Descriptor of the N x C x H x W tensor (input image for forward, dx for backward data)
+       cudnnTensorDescriptor nchwTensorDesc = null;
+       // Descriptor of the N x K x P x Q tensor (output for forward, dy for backward data)
+       cudnnTensorDescriptor nkpqTensorDesc = null;
+       // Descriptor of the K x C x R x S filter
+       cudnnFilterDescriptor filterDesc = null;
+       // Convolution descriptor holding the padding and strides
+       cudnnConvolutionDescriptor convDesc = null;
+       // Context and instruction name used by close() to free the workspace and record statistics
+       GPUContext gCtx = null; String instName = null;
+       
+       /**
+        * Allocates the convolution, tensor and filter descriptors shared by the
+        * cudnnConvolution* calls. The algorithm id and the workspace are filled in
+        * afterwards by the factory methods.
+        * <p>
+        * If any allocation fails, the descriptors created so far are released via
+        * {@link #close()} before the exception is propagated. A failed construction
+        * therefore does not leak cuDNN descriptors.
+        * <p>
+        * The dimension, padding and stride arguments have the same meaning as in the
+        * public factory methods.
+        *
+        * @throws DMLRuntimeException if a descriptor cannot be allocated
+        */
+       private LibMatrixCuDNNConvolutionAlgorithm(GPUContext gCtx, String instName, int N, int C, int H, int W, int K, int R, int S,
+                       int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q) throws DMLRuntimeException {
+               // Set these first: close() needs them if cleanup is required below
+               this.gCtx = gCtx;
+               this.instName = instName;
+               int padding[] = {pad_h, pad_w};
+               int strides[] = {stride_h, stride_w};
+               try {
+                       convDesc = LibMatrixCuDNN.allocateConvolutionDescriptor(padding, strides);
+                       nchwTensorDesc = LibMatrixCuDNN.allocateTensorDescriptor(N, C, H, W);
+                       nkpqTensorDesc = LibMatrixCuDNN.allocateTensorDescriptor(N, K, P, Q);
+                       filterDesc = LibMatrixCuDNN.allocateFilterDescriptor(K, C, R, S);
+               } catch(Throwable t) {
+                       // Release whatever was allocated before the failure, keeping the original error
+                       try {
+                               close();
+                       } catch(Throwable suppressed) {
+                               t.addSuppressed(suppressed);
+                       }
+                       throw t;
+               }
+       }
+       
+       /**
+        * Deallocates the tensor and filter descriptors as well as allocated 
workspace
+        */
+       @Override
+       public void close() {
+               long t3 = 0;
+               if (GPUStatistics.DISPLAY_STATISTICS) t3 = System.nanoTime();
+               if(nchwTensorDesc != null)
+                       cudnnDestroyTensorDescriptor(nchwTensorDesc);
+               if(nkpqTensorDesc != null)
+                       cudnnDestroyTensorDescriptor(nkpqTensorDesc);
+               if(filterDesc != null)
+                       cudnnDestroyFilterDescriptor(filterDesc);
+               if(convDesc != null)
+                       cudnnDestroyConvolutionDescriptor(convDesc);
+               if(sizeInBytes != 0)
+                       gCtx.cudaFreeHelper(instName, workSpace);
+               if(GPUStatistics.DISPLAY_STATISTICS)
+                       GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t3);
+       }
+       
+       /**
+        * Factory method to get the algorithm wrapper for convolution forward.
+        * <p>
+        * If {@code workspaceLimit <= 0}, CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is selected
+        * and no workspace is allocated. Otherwise cuDNN is asked for an algorithm whose
+        * workspace does not exceed {@code workspaceLimit} bytes, and the workspace it reports
+        * is allocated on the device.
+        * <p>
+        * The returned object owns native resources and must be closed by the caller, for
+        * example with try-with-resources. If this method fails, the descriptors and workspace
+        * acquired so far are released before the exception is propagated.
+        * 
+        * @param gCtx     a valid {@link GPUContext}
+        * @param instName the invoking instruction's name, used for recording statistics
+        * @param N        number of input images
+        * @param C        number of channels
+        * @param H        height of each image
+        * @param W        width of each image
+        * @param K        number of output "channels"
+        * @param R        height of filter
+        * @param S        width of filter
+        * @param pad_h    padding height
+        * @param pad_w    padding width
+        * @param stride_h stride height
+        * @param stride_w stride width
+        * @param P        output height
+        * @param Q        output width
+        * @param workspaceLimit maximum intermediate memory (in bytes) to use
+        * @return algorithm wrapper
+        * @throws DMLRuntimeException if cuDNN reports an error or the workspace cannot be allocated
+        */
+       public static LibMatrixCuDNNConvolutionAlgorithm cudnnGetConvolutionForwardAlgorithm(
+                       GPUContext gCtx, String instName, int N, int C, int H, int W, int K, int R, int S, 
+                       int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, long workspaceLimit) throws DMLRuntimeException {
+               long t1 = GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+               LibMatrixCuDNNConvolutionAlgorithm ret = new LibMatrixCuDNNConvolutionAlgorithm(gCtx, instName, N, C, H, W, K, R, S, 
+                               pad_h, pad_w, stride_h, stride_w, P, Q);
+               try {
+                       if(workspaceLimit <= 0) {
+                               // If overhead is greater than intermediate allocated memory, prefer the cudnn operator with no memory requirement, 
+                               // i.e. CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
+                               ret.algo = jcuda.jcudnn.cudnnConvolutionFwdAlgo.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+                       }
+                       else {
+                               int[] algos = {-1};
+                               long sizeInBytesArray[] = {workspaceLimit};
+                               int status = jcuda.jcudnn.JCudnn.cudnnGetConvolutionForwardAlgorithm(LibMatrixCuDNN.getCudnnHandle(gCtx), 
+                                               ret.nchwTensorDesc, ret.filterDesc, ret.convDesc, ret.nkpqTensorDesc,
+                                               cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, sizeInBytesArray[0], algos);
+                               if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
+                                       throw new DMLRuntimeException("Could not execute cudnnGetConvolutionForwardAlgorithm: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
+                               }
+                               status = jcuda.jcudnn.JCudnn.cudnnGetConvolutionForwardWorkspaceSize(LibMatrixCuDNN.getCudnnHandle(gCtx), 
+                                               ret.nchwTensorDesc, ret.filterDesc, ret.convDesc, ret.nkpqTensorDesc, algos[0], sizeInBytesArray);
+                               if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
+                                       throw new DMLRuntimeException("Could not execute cudnnGetConvolutionForwardWorkspaceSize: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
+                               }
+                               if (sizeInBytesArray[0] != 0)
+                                       ret.workSpace = gCtx.allocate(sizeInBytesArray[0]);
+                               // Recorded only after a successful allocation so close() frees exactly what was allocated
+                               ret.sizeInBytes = sizeInBytesArray[0];
+                               ret.algo = algos[0];
+                       }
+               } catch(Throwable t) {
+                       // Do not leak the descriptors (or the workspace) if algorithm selection fails
+                       try {
+                               ret.close();
+                       } catch(Throwable suppressed) {
+                               t.addSuppressed(suppressed);
+                       }
+                       throw t;
+               }
+               if (GPUStatistics.DISPLAY_STATISTICS)
+                       GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
+               return ret;
+       }
+       
+       /**
+        * Factory method to get the algorithm wrapper for convolution backward filter.
+        * <p>
+        * If {@code workspaceLimit <= 0}, CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0 is selected and
+        * no workspace is allocated. Otherwise cuDNN is asked for an algorithm whose workspace
+        * does not exceed {@code workspaceLimit} bytes, and the workspace it reports is
+        * allocated on the device.
+        * <p>
+        * The returned object owns native resources and must be closed by the caller, for
+        * example with try-with-resources. If this method fails, the descriptors and workspace
+        * acquired so far are released before the exception is propagated.
+        * 
+        * @param gCtx     a valid {@link GPUContext}
+        * @param instName the invoking instruction's name, used for recording statistics
+        * @param N        number of input images
+        * @param C        number of channels
+        * @param H        height of each image
+        * @param W        width of each image
+        * @param K        number of output "channels"
+        * @param R        height of filter
+        * @param S        width of filter
+        * @param pad_h    padding height
+        * @param pad_w    padding width
+        * @param stride_h stride height
+        * @param stride_w stride width
+        * @param P        output height
+        * @param Q        output width
+        * @param workspaceLimit maximum intermediate memory (in bytes) to use
+        * @return algorithm wrapper
+        * @throws DMLRuntimeException if cuDNN reports an error or the workspace cannot be allocated
+        */
+       public static LibMatrixCuDNNConvolutionAlgorithm cudnnGetConvolutionBackwardFilterAlgorithm(
+                       GPUContext gCtx, String instName, int N, int C, int H, int W, int K, int R, int S, 
+                       int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, long workspaceLimit) throws DMLRuntimeException {
+               long t1 = GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+               LibMatrixCuDNNConvolutionAlgorithm ret = new LibMatrixCuDNNConvolutionAlgorithm(gCtx, instName, N, C, H, W, K, R, S, 
+                               pad_h, pad_w, stride_h, stride_w, P, Q);
+               try {
+                       if(workspaceLimit <= 0) {
+                               // If overhead is greater than intermediate allocated memory, prefer the cudnn operator with no memory requirement
+                               // i.e. CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0
+                               ret.algo = jcuda.jcudnn.cudnnConvolutionBwdFilterAlgo.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
+                       }
+                       else {
+                               int[] algos = {-1};
+                               long sizeInBytesArray[] = {workspaceLimit};
+                               int status = jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardFilterAlgorithm(
+                                               LibMatrixCuDNN.getCudnnHandle(gCtx), 
+                                               ret.nchwTensorDesc, ret.nkpqTensorDesc, ret.convDesc, ret.filterDesc, 
+                                               cudnnConvolutionBwdFilterPreference.CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, sizeInBytesArray[0], algos);
+                               if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
+                                       throw new DMLRuntimeException("Could not execute cudnnGetConvolutionBackwardFilterAlgorithm: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
+                               }
+                               status = jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardFilterWorkspaceSize(LibMatrixCuDNN.getCudnnHandle(gCtx), 
+                                               ret.nchwTensorDesc, ret.nkpqTensorDesc, ret.convDesc, ret.filterDesc, algos[0], sizeInBytesArray);
+                               if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
+                                       throw new DMLRuntimeException("Could not execute cudnnGetConvolutionBackwardFilterWorkspaceSize: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
+                               }
+                               if (sizeInBytesArray[0] != 0)
+                                       ret.workSpace = gCtx.allocate(sizeInBytesArray[0]);
+                               // Recorded only after a successful allocation so close() frees exactly what was allocated
+                               ret.sizeInBytes = sizeInBytesArray[0];
+                               ret.algo = algos[0];
+                       }
+               } catch(Throwable t) {
+                       // Do not leak the descriptors (or the workspace) if algorithm selection fails
+                       try {
+                               ret.close();
+                       } catch(Throwable suppressed) {
+                               t.addSuppressed(suppressed);
+                       }
+                       throw t;
+               }
+               if (GPUStatistics.DISPLAY_STATISTICS)
+                       GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
+               return ret;
+       }
+       
+       /**
+        * Factory method to get the algorithm wrapper for convolution backward 
data
+        * 
+        * @param gCtx     a valid {@link GPUContext}
+        * @param instName the invoking instruction's name for record {@link 
Statistics}.
+        * @param N        number of input images
+        * @param C        number of channels
+        * @param H        height of each image
+        * @param W        width of each image
+        * @param K        number of output "channels"
+        * @param R        height of filter
+        * @param S        width of filter
+        * @param pad_h    padding height
+        * @param pad_w    padding width
+        * @param stride_h stride height
+        * @param stride_w string width
+        * @param P        output height
+        * @param Q        output width
+        * @param workspaceLimit maximum intermediate memory to use
+        * @return algorithm wrapper
+        * @throws DMLRuntimeException if error occurs
+        */
+       public static LibMatrixCuDNNConvolutionAlgorithm 
cudnnGetConvolutionBackwardDataAlgorithm(
+                       GPUContext gCtx, String instName, int N, int C, int H, 
int W, int K, int R, int S, 
+                       int pad_h, int pad_w, int stride_h, int stride_w, int 
P, int Q, long workspaceLimit) throws DMLRuntimeException {
+               long t1 = GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() 
: 0;
+               LibMatrixCuDNNConvolutionAlgorithm ret = new 
LibMatrixCuDNNConvolutionAlgorithm(gCtx, instName, N, C, H, W, K, R, S, 
+                               pad_h, pad_w, stride_h, stride_w, P, Q);
+               
+               if(workspaceLimit <= 0) {
+                       // If overhead is greater than intermediate allocated 
memory, prefer the cudnn operator with no memory requirement
+                       // i.e. CUDNN_CONVOLUTION_BWD_DATA_ALGO_0
+                       ret.algo = 
jcuda.jcudnn.cudnnConvolutionBwdDataAlgo.CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
+               }
+               else {
+                       int[] algos = {-1};
+                       long sizeInBytesArray[] = {workspaceLimit};
+                       
jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardDataAlgorithm(
+                                       LibMatrixCuDNN.getCudnnHandle(gCtx), 
+                                       ret.filterDesc, ret.nkpqTensorDesc, 
ret.convDesc, ret.nchwTensorDesc,
+                                       
cudnnConvolutionBwdDataPreference.CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
 sizeInBytesArray[0], algos);
+                       
jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardDataWorkspaceSize(LibMatrixCuDNN.getCudnnHandle(gCtx),
 
+                                       ret.filterDesc, ret.nkpqTensorDesc, 
ret.convDesc, ret.nchwTensorDesc, algos[0], sizeInBytesArray);
+                       if (sizeInBytesArray[0] != 0)
+                               ret.workSpace = 
gCtx.allocate(sizeInBytesArray[0]);
+                       ret.sizeInBytes = sizeInBytesArray[0];
+                       ret.algo = algos[0];
+               }
+               if (GPUStatistics.DISPLAY_STATISTICS)
+                       GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
+               return ret;
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/43b573df/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNInputRowFetcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNInputRowFetcher.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNInputRowFetcher.java
new file mode 100644
index 0000000..b9619c8
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNInputRowFetcher.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.matrix.data;
+
+import static jcuda.runtime.JCuda.cudaMemset;
+import jcuda.Pointer;
+import jcuda.Sizeof;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
+import org.apache.sysml.runtime.instructions.gpu.context.CSRPointer;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+import org.apache.sysml.utils.GPUStatistics;
+
+/**
+ * Performs a slice operation: out = in[(n+1):(n+1), 1:numColumns]
+ */
+public class LibMatrixCuDNNInputRowFetcher implements java.lang.AutoCloseable {
+       GPUContext gCtx; String instName; int numColumns; boolean 
isInputInSparseFormat; 
+       Object inPointer; // can be either CSRPointer or Pointer 
+       Pointer outPointer;
+
+       /**
+        * Initialize the input fetcher
+        * 
+        * @param gCtx current gpu context
+        * @param instName name of the instruction
+        * @param image input matrix object.
+        * @throws DMLRuntimeException if error
+        */
+       public LibMatrixCuDNNInputRowFetcher(GPUContext gCtx, String instName, 
MatrixObject image) throws DMLRuntimeException {
+               this.gCtx = gCtx; this.instName = instName;
+               numColumns = LibMatrixCuDNN.toInt(image.getNumColumns());
+               isInputInSparseFormat = LibMatrixCuDNN.isInSparseFormat(gCtx, 
image);
+               inPointer = isInputInSparseFormat ? 
LibMatrixCuDNN.getSparsePointer(gCtx, image, instName) : 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, image, instName);
+               outPointer = gCtx.allocate(numColumns*Sizeof.DOUBLE);
+       }
+       /**
+        * Copy the nth row and return the dense pointer
+        * @param n zero-based row index
+        * @return dense pointer containing the nth row. This row is reused in 
the next iteration
+        * @throws DMLRuntimeException
+        */
+       public Pointer getNthRow(int n) throws DMLRuntimeException {
+               if(isInputInSparseFormat) {
+                       jcuda.runtime.JCuda.cudaDeviceSynchronize();
+                       long t0 = GPUStatistics.DISPLAY_STATISTICS ? 
System.nanoTime() : 0;
+                       cudaMemset(outPointer, 0, numColumns*Sizeof.DOUBLE);
+                       jcuda.runtime.JCuda.cudaDeviceSynchronize();
+                       if(GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_SET_ZERO, 
System.nanoTime() - t0);
+                       LibMatrixCuDNN.sliceSparseDense(gCtx, instName, 
(CSRPointer)inPointer, outPointer, n, n, 0, LibMatrixCuDNN.toInt(numColumns-1), 
numColumns);
+               }
+               else {
+                       LibMatrixCuDNN.sliceDenseDense(gCtx, instName, 
(Pointer)inPointer, outPointer, n, n, 0, LibMatrixCuDNN.toInt(numColumns-1), 
numColumns);
+               }
+               return outPointer;
+       }
+       /**
+        * Deallocates temporary pointer
+        */
+       @Override
+       public void close() {
+               gCtx.cudaFreeHelper(outPointer, true);
+       }
+}
\ No newline at end of file

Reply via email to