Repository: systemml
Updated Branches:
  refs/heads/master ef842da9c -> bd34292d4


http://git-wip-us.apache.org/repos/asf/systemml/blob/bd34292d/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index 4ad4155..0424114 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -38,6 +38,15 @@ import org.apache.sysml.runtime.util.DnnUtils;
 import org.apache.sysml.utils.GPUStatistics;
 
 public class DnnGPUInstruction extends GPUInstruction {
+       
+       /**
+        * Implementation choices for the GPU lstm / lstm_backward instructions.
+        */
+       public static enum LstmOperator {
+               CUDNN,    // force the cuDNN LSTM kernel
+               DENSE_NN, // request the non-cuDNN implementation (LibMatrixCuDNN.nnLstm / nnLstmBackward)
+               NONE      // no override: use cuDNN when W is dense and the estimated memory fits, otherwise the non-cuDNN path
+       }
+       
+       /**
+        * Global override selecting the LSTM implementation; NONE lets the instruction decide.
+        * Note: when the batch size of out0 differs from the batch size of X, the cuDNN kernel
+        * is used regardless of this setting.
+        * NOTE(review): mutable public static; presumably set from configuration or tests - confirm where it is written.
+        */
+       public static LstmOperator FORCED_LSTM_OP = LstmOperator.NONE;
+       
        private CPOperand _input1;
        private CPOperand _input2;
        private CPOperand _input3;
@@ -638,43 +647,36 @@ public class DnnGPUInstruction extends GPUInstruction {
                return (int)num;
        }
        
+       /**
+        * Estimates, in bytes, the GPU memory used by the cuDNN-based LSTM implementation.
+        * The lstm and lstm_backward GPU instructions compare this estimate against
+        * GPUMemoryManager.canAllocate to decide whether the cuDNN kernel can be used.
+        *
+        * @param N batch size
+        * @param T sequence length
+        * @param M hidden size
+        * @param D number of input features
+        * @param return_sequences whether dout covers all time steps (N x T*M) or only the last one (N x M)
+        * @return heuristic estimate in bytes, including a workspace allowance of 1.2x the cuDNN weight buffer
+        */
+       public static long getMemRequiredForCuDNNLSTMBackward(long N, long T, long M, long D, boolean return_sequences) {
+               double memRequired = (D+M)*4*M // sysmlWPointer
+                               + 2*(D+M+2)*(4*M) // cudnnWPointer and cudnnDwPointer
+                               + 3*N*T*D  // cudnnInput, cudnnDx and smlDx
+                               + 2*N*T*M // dy and yPointer
+                               + N*(return_sequences ? T*M : M); // dout has shape (N, T*M) or (N, M), so it scales with N
+               memRequired *= LibMatrixCUDA.sizeOfDataType;
+               // Assume the workspace to be proportional to cudnnWPointer (add 20% additional overhead for workspace)
+               memRequired += 1.2*(D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType;
+               return (long)memRequired;
+       }
+       
        private void processLstmBackwardInstruction(ExecutionContext ec) throws 
DMLRuntimeException {
                MatrixObject out0 = getMatrixInputForGPUInstruction(ec, 
_input4.getName());
                long M = out0.getNumColumns(); // hiddenSize .. since out0: (N, 
M)
+               long N1 = out0.getNumRows();
                Pointer out0Pointer =  LibMatrixCUDA.getDensePointer(gCtx, 
out0, instName);
                
                MatrixObject W = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
                MatrixObject bias = getMatrixInputForGPUInstruction(ec, 
_input3.getName());
                long numRowsW = W.getNumRows();
-               long D = numRowsW - M; // since W:(D+M, 4M) ... numFeatures 
-               Pointer sysmlWPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
-               Pointer sysmlBiasPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
-               Pointer cudnnWPointer = gCtx.allocate(instName, 
(D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
-               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
-                               
ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
-                               sysmlWPointer, sysmlBiasPointer, cudnnWPointer, 
D, M);
-               ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-               
-               
+               long D = numRowsW - M; // since W:(D+M, 4M) ... numFeatures
                MatrixObject X = getMatrixInputForGPUInstruction(ec, 
_input1.getName());
-               Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, 
instName); 
                int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D)
                long numColsX = X.getNumColumns();
                int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength
-               Pointer cudnnInput = gCtx.allocate(instName, 
(N*T*D)*LibMatrixCUDA.sizeOfDataType);
-               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
-                               
ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
-                               xPointer, cudnnInput, N, D, T*D, N*T*D);
-               ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-               
-               Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, 
getMatrixInputForGPUInstruction(ec, _input5.getName()), instName);
                boolean return_sequences = ec.getScalarInput(_input6.getName(), 
_input6.getValueType(), _input6.isLiteral()).getBooleanValue();
                
-               // LibMatrixCuDNN.lstm(ec, gCtx, instName, 
-                               // cudnnInput, cudnnWPointer, out0Pointer, 
c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T);
-                               // String xName, Pointer hx, Pointer cx, 
Pointer wPointer, String doutName, String dcyName,  // input
-                               // String dxName, String dwName, String dbName, 
String dhxName, String dcxName,         // output
+               // long memRequired = getMemRequiredForCuDNNLSTMBackward(N, T, 
M, D, return_sequences);
+                
                String dxName = _output.getName();
                String dwName = _output2.getName();
                String dbName = _output3.getName();
@@ -682,12 +684,95 @@ public class DnnGPUInstruction extends GPUInstruction {
                String dcxName = _output5.getName();
                String doutName = _input7.getName();
                String dcyName = _input8.getName();
-               LibMatrixCuDNN.lstmBackward(ec, gCtx, instName, 
-                               cudnnInput, out0Pointer, c0Pointer, 
cudnnWPointer, doutName, dcyName,  // input
-                               dxName, dwName, dbName, dhxName, dcxName, // 
output 
-                               return_sequences, N, M, D, T);
-               gCtx.cudaFreeHelper(instName, cudnnWPointer, 
gCtx.EAGER_CUDA_FREE);
-               gCtx.cudaFreeHelper(instName, cudnnInput, gCtx.EAGER_CUDA_FREE);
+               
+               long memRequired = getMemRequiredForCuDNNLSTMBackward(N, T, M, 
D, return_sequences);
+               
+               boolean isWSparse = LibMatrixCUDA.isInSparseFormat(gCtx, W);
+               
+               
+               
+               if(FORCED_LSTM_OP == LstmOperator.CUDNN || 
+                       N != N1 || // Use CuDNN operator when batch size of 
previous iteration is different that current iteration
+                       (!isWSparse && // Don't use CuDNN kernel when w is 
sparse.
+                       // When an operator is not forced, then prefer CuDNN 
kernel if it can fit in the GPU memory
+                       FORCED_LSTM_OP == LstmOperator.NONE && 
gCtx.getMemoryManager().canAllocate(instName, memRequired))) {
+                       // Use CuDNN LSTM kernel
+                       Pointer sysmlWPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
+                       Pointer sysmlBiasPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
+                       Pointer cudnnWPointer = gCtx.allocate(instName, 
(D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
+                       
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
+                                       
ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
+                                       sysmlWPointer, sysmlBiasPointer, 
cudnnWPointer, D, M);
+                       
ec.releaseMatrixInputForGPUInstruction(_input2.getName());
+                       
ec.releaseMatrixInputForGPUInstruction(_input3.getName());
+                       Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, 
X, instName); 
+                       Pointer cudnnInput = gCtx.allocate(instName, 
(N*T*D)*LibMatrixCUDA.sizeOfDataType);
+                       
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
+                                       
ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
+                                       xPointer, cudnnInput, N, D, T*D, N*T*D);
+                       
ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+                       Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, 
getMatrixInputForGPUInstruction(ec, _input5.getName()), instName);
+                       LibMatrixCuDNN.cuDNNLstmBackward(ec, gCtx, instName, 
+                                       cudnnInput, out0Pointer, c0Pointer, 
cudnnWPointer, doutName, dcyName,  // input
+                                       dxName, dwName, dbName, dhxName, 
dcxName, // output 
+                                       return_sequences, N, M, D, T);
+                       gCtx.cudaFreeHelper(instName, cudnnWPointer, 
gCtx.EAGER_CUDA_FREE);
+                       gCtx.cudaFreeHelper(instName, cudnnInput, 
gCtx.EAGER_CUDA_FREE);
+               }
+               else {
+                       if(N != N1) {
+                               throw new DMLRuntimeException("Unsupported 
operation: The batch size of previous iteration " + N1 + 
+                                               " is different than the batch 
size of current iteration " + N);
+                       }
+                       
+                       Pointer sysmlBiasPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
+                       Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, 
X, instName); 
+                       Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, 
getMatrixInputForGPUInstruction(ec, _input5.getName()), instName);
+                       
+                       Pointer doutPointer = 
LibMatrixCuDNN.getDenseInputPointer(ec, gCtx, instName, doutName, N, 
return_sequences ? T*M : M);
+                       Pointer dcyPointer = 
LibMatrixCuDNN.getDenseInputPointer(ec, gCtx, instName, dcyName, N, M);
+                       
+                       Pointer dxPointer = 
LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, dxName, N, T*D);
+                       Pointer dwPointer = 
LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, dwName, D+M, 4*M);
+                       Pointer dbPointer = 
LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, dbName, 1, 4*M);
+                       Pointer dhxPointer = 
LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, dhxName, N, M);
+                       Pointer dcxPointer = 
LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, dcxName, N, M);
+                       
+                       // Donot skip cache as it is required in the backward 
pass
+                       Pointer cache_out = gCtx.allocate(instName, 
T*N*M*LibMatrixCUDA.sizeOfDataType);
+                       Pointer cache_c = gCtx.allocate(instName, 
T*N*M*LibMatrixCUDA.sizeOfDataType);
+                       Pointer cache_ifog = gCtx.allocate(instName, 
T*N*4*M*LibMatrixCUDA.sizeOfDataType);
+                       
+                       Pointer cyPointer = gCtx.allocate(instName, 
N*M*LibMatrixCUDA.sizeOfDataType);
+                       Pointer sysmlYPointer = gCtx.allocate(instName, 
(return_sequences ? N*(T*M) : N*M)*LibMatrixCUDA.sizeOfDataType);
+                       LibMatrixCuDNN.nnLstm(ec, gCtx, instName, xPointer, W, 
sysmlBiasPointer, out0Pointer, 
+                                       c0Pointer, return_sequences, 
sysmlYPointer, cyPointer, 
+                                       cache_out, cache_c, cache_ifog, 
+                                       N, M,  D, T);
+                       gCtx.cudaFreeHelper(instName, sysmlYPointer, 
gCtx.EAGER_CUDA_FREE);
+                       gCtx.cudaFreeHelper(instName, cyPointer, 
gCtx.EAGER_CUDA_FREE);
+                       
+                       LibMatrixCuDNN.nnLstmBackward(ec, gCtx, instName,
+                                       xPointer, out0Pointer, c0Pointer, W, 
doutPointer, dcyPointer,  // input
+                                       cache_out, cache_c, cache_ifog,
+                                       dxPointer, dwPointer, dbPointer, 
dhxPointer, dcxPointer,        // output
+                                       return_sequences, N, M, D, T);
+                       
+                       gCtx.cudaFreeHelper(instName, cache_out, 
gCtx.EAGER_CUDA_FREE);
+                       gCtx.cudaFreeHelper(instName, cache_c, 
gCtx.EAGER_CUDA_FREE);
+                       gCtx.cudaFreeHelper(instName, cache_ifog, 
gCtx.EAGER_CUDA_FREE);
+                       
ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+                       
ec.releaseMatrixInputForGPUInstruction(_input2.getName()); // W
+                       
ec.releaseMatrixInputForGPUInstruction(_input3.getName()); // bias
+                       ec.releaseMatrixInputForGPUInstruction(doutName);
+                       ec.releaseMatrixInputForGPUInstruction(dcyName);
+                       ec.releaseMatrixOutputForGPUInstruction(dxName);
+                       ec.releaseMatrixOutputForGPUInstruction(dwName);
+                       ec.releaseMatrixOutputForGPUInstruction(dbName);
+                       ec.releaseMatrixOutputForGPUInstruction(dhxName);
+                       ec.releaseMatrixOutputForGPUInstruction(dcxName);
+                       
+               }
                
                // release inputs/outputs
                ec.releaseMatrixInputForGPUInstruction(_input4.getName());
@@ -702,42 +787,79 @@ public class DnnGPUInstruction extends GPUInstruction {
                // out: (N, T*M) or (N, M) ==> (T, M, N)
                MatrixObject out0 = getMatrixInputForGPUInstruction(ec, 
_input4.getName());
                long M = out0.getNumColumns(); // hiddenSize .. since out0: (N, 
M)
+               long N1 = out0.getNumRows();
                Pointer out0Pointer =  LibMatrixCUDA.getDensePointer(gCtx, 
out0, instName);
                
                MatrixObject W = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
                MatrixObject bias = getMatrixInputForGPUInstruction(ec, 
_input3.getName());
                long numRowsW = W.getNumRows();
                long D = numRowsW - M; // since W:(D+M, 4M) ... numFeatures
-               
-               Pointer sysmlWPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
-               Pointer sysmlBiasPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
-               Pointer cudnnWPointer = gCtx.allocate(instName, 
(D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
-               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
-                               
ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
-                               sysmlWPointer, sysmlBiasPointer, cudnnWPointer, 
D, M);
-               ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-               
-               boolean return_sequences = ec.getScalarInput(_input6.getName(), 
_input6.getValueType(), _input6.isLiteral()).getBooleanValue();
-               
-               // Beause the matrices are released immediately, the output for 
transpose need not be taken into account
                MatrixObject X = getMatrixInputForGPUInstruction(ec, 
_input1.getName());
-               Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, 
instName); 
-               int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D)
+               long N = X.getNumRows(); // batchSize .. since X:(N, T*D)
                long numColsX = X.getNumColumns();
-               int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength
-               Pointer cudnnInput = gCtx.allocate(instName, 
(N*T*D)*LibMatrixCUDA.sizeOfDataType);
-               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
-                               
ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
-                               xPointer, cudnnInput, N, D, T*D, N*T*D);
-               ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+               long T = numColsX/D; // since X:(N, T*D) ... seqLength
+               boolean return_sequences = ec.getScalarInput(_input6.getName(), 
_input6.getValueType(), _input6.isLiteral()).getBooleanValue();
+               
+               long memRequired = getMemRequiredForCuDNNLSTMBackward(N, T, M, 
D, return_sequences);
                
-               Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, 
getMatrixInputForGPUInstruction(ec, _input5.getName()), instName); 
+               boolean isWSparse = LibMatrixCUDA.isInSparseFormat(gCtx, W);
                
-               LibMatrixCuDNN.lstm(ec, gCtx, instName, cudnnInput, 
cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), 
_output2.getName(), 
-                               toInt(N), toInt(M), toInt(D), toInt(T));
-               gCtx.cudaFreeHelper(instName, cudnnWPointer, 
gCtx.EAGER_CUDA_FREE);
-               gCtx.cudaFreeHelper(instName, cudnnInput, gCtx.EAGER_CUDA_FREE);
+               if(FORCED_LSTM_OP == LstmOperator.CUDNN || 
+                       N != N1 || // Use CuDNN operator when batch size of 
previous iteration is different that current iteration
+                       (!isWSparse && // Don't use CuDNN kernel when w is 
sparse.
+                       // When an operator is not forced, then prefer CuDNN 
kernel if it can fit in the GPU memory
+                       FORCED_LSTM_OP == LstmOperator.NONE && 
gCtx.getMemoryManager().canAllocate(instName, memRequired))) {
+                       Pointer sysmlWPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
+                       Pointer sysmlBiasPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
+                       Pointer cudnnWPointer = gCtx.allocate(instName, 
(D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
+                       
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
+                                       
ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
+                                       sysmlWPointer, sysmlBiasPointer, 
cudnnWPointer, toInt(D), toInt(M));
+                       
ec.releaseMatrixInputForGPUInstruction(_input2.getName()); // W
+                       
ec.releaseMatrixInputForGPUInstruction(_input3.getName()); // bias
+                       // Beause the matrices are released immediately, the 
output for transpose need not be taken into account
+                       Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, 
X, instName); 
+                       Pointer cudnnInput = gCtx.allocate(instName, 
(N*T*D)*LibMatrixCUDA.sizeOfDataType);
+                       
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
+                                       
ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
+                                       xPointer, cudnnInput, toInt(N), 
toInt(D), toInt(T*D), toInt(N*T*D));
+                       
ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+                       Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, 
getMatrixInputForGPUInstruction(ec, _input5.getName()), instName); 
+                       LibMatrixCuDNN.cuDNNLstm(ec, gCtx, instName, 
cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, 
_output.getName(), _output2.getName(), 
+                                       toInt(N), toInt(M), toInt(D), toInt(T));
+                       gCtx.cudaFreeHelper(instName, cudnnWPointer, 
gCtx.EAGER_CUDA_FREE);
+                       gCtx.cudaFreeHelper(instName, cudnnInput, 
gCtx.EAGER_CUDA_FREE);
+               }
+               else {
+                       if(N != N1) {
+                               throw new DMLRuntimeException("Unsupported 
operation: The batch size of previous iteration " + N1 + 
+                                               " is different than the batch 
size of current iteration " + N);
+                       }
+                       
+                       Pointer sysmlBiasPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
+                       Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, 
X, instName); 
+                       Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, 
getMatrixInputForGPUInstruction(ec, _input5.getName()), instName);
+                       Pointer sysmlYPointer = 
LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, _output.getName(), N, 
+                                       return_sequences ? (T*M) : M);
+                       Pointer cyPointer = 
LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName,  _output2.getName(), 
N, M);
+                       
+                       // Skip cache in forward for now. We can revisit this 
when we add stateful operators.
+                       Pointer cache_out = null; // gCtx.allocate(instName, 
T*N*M*LibMatrixCUDA.sizeOfDataType);
+                       Pointer cache_c = null;  // gCtx.allocate(instName, 
T*N*M*LibMatrixCUDA.sizeOfDataType);
+                       Pointer cache_ifog = null; // gCtx.allocate(instName, 
T*N*4*M*LibMatrixCUDA.sizeOfDataType);
+                       
+                       LibMatrixCuDNN.nnLstm(ec, gCtx, instName, xPointer, W, 
sysmlBiasPointer, out0Pointer, 
+                                       c0Pointer, return_sequences, 
sysmlYPointer, cyPointer, 
+                                       cache_out, cache_c, cache_ifog, 
+                                       N, M,  D, T);
+                       
+                       // gCtx.cudaFreeHelper(instName, cache_out, 
gCtx.EAGER_CUDA_FREE);
+                       // gCtx.cudaFreeHelper(instName, cache_c, 
gCtx.EAGER_CUDA_FREE);
+                       // gCtx.cudaFreeHelper(instName, cache_ifog, 
gCtx.EAGER_CUDA_FREE);
+                       
ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+                       
ec.releaseMatrixInputForGPUInstruction(_input2.getName()); // W
+                       
ec.releaseMatrixInputForGPUInstruction(_input3.getName()); // bias
+               }
                
                // release inputs/outputs
                ec.releaseMatrixInputForGPUInstruction(_input4.getName());

http://git-wip-us.apache.org/repos/asf/systemml/blob/bd34292d/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
index a08d4fd..6a04d97 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
@@ -224,6 +224,10 @@ public class GPUMemoryManager {
                        return "->" + stackTrace[index].getClassName() + "." + 
stackTrace[index].getMethodName() + "(" + stackTrace[index].getFileName() + ":" 
+ stackTrace[index].getLineNumber() + ")";
        }
        
+       /**
+        * Reports whether the underlying allocator can currently satisfy a request of the given size.
+        * Unlike canAllocateWithoutEviction, the lazily-freed cache is not consulted.
+        *
+        * @param opcode instruction name (currently unused)
+        * @param size   requested size in bytes
+        * @return true if the allocator reports that the request can be satisfied
+        */
+       public boolean canAllocate(String opcode, long size) {
+               final boolean fitsInAllocator = allocator.canAllocate(size);
+               return fitsInAllocator;
+       }
+       
        
        public boolean canAllocateWithoutEviction(String opcode, long size) {
                return lazyCudaFreeMemoryManager.contains(opcode, size) || 
allocator.canAllocate(size) ||

http://git-wip-us.apache.org/repos/asf/systemml/blob/bd34292d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 00aa578..fd06578 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -227,6 +227,23 @@ public class LibMatrixCUDA {
                                A, ret, numElems);
                return ret;
        }
+       
+       /**
+        * Debugging helper: copies a dense, row-major device buffer to the host and prints it to
+        * standard output, one row per line, each value with three decimal places followed by a space.
+        *
+        * @param ptr     device pointer holding at least rows*cols values of the configured GPU data type
+        * @param rows    number of rows to print
+        * @param cols    number of columns to print
+        * @param matName label printed on the line before the values
+        * @throws DMLRuntimeException if the GPU data type is neither double nor single precision
+        */
+       public static void printPointerForDebugging(Pointer ptr, int rows, int cols, String matName) {
+               if(sizeOfDataType != jcuda.Sizeof.DOUBLE && sizeOfDataType != jcuda.Sizeof.FLOAT) {
+                       throw new DMLRuntimeException("The method printPointerForDebugging is only supported for single and double precision.");
+               }
+               // multiplyExact fails loudly instead of silently wrapping for very large buffers
+               int numElems = Math.multiplyExact(rows, cols);
+               // Byte count in long arithmetic: numElems*sizeOfDataType can exceed Integer.MAX_VALUE
+               long numBytes = ((long) numElems) * sizeOfDataType;
+               double[] hostData = new double[numElems];
+               if(sizeOfDataType == jcuda.Sizeof.DOUBLE) {
+                       cudaMemcpy(Pointer.to(hostData), ptr, numBytes, jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost);
+               }
+               else {
+                       // Single precision: copy into a float buffer and widen to double for printing
+                       float[] floatData = new float[numElems];
+                       cudaMemcpy(Pointer.to(floatData), ptr, numBytes, jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost);
+                       for(int i = 0; i < numElems; i++) {
+                               hostData[i] = floatData[i];
+                       }
+               }
+               System.out.println(matName + ":");
+               for(int i = 0; i < rows; i++) {
+                       StringBuilder row = new StringBuilder();
+                       for(int j = 0; j < cols; j++) {
+                               row.append(String.format("%.3f", hostData[i*cols+j])).append(' ');
+                       }
+                       System.out.println(row);
+               }
+       }
 
        //********************************************************************/
        //************************ End of UTILS ******************************/
@@ -1425,7 +1442,7 @@ public class LibMatrixCUDA {
         * @param isRightTransposed true if right matrix is transposed
         * @param op                operator
         */
-       private static void matrixMatrixOp(ExecutionContext ec, GPUContext 
gCtx, String instName, MatrixObject in1, MatrixObject in2,
+       static void matrixMatrixOp(ExecutionContext ec, GPUContext gCtx, String 
instName, MatrixObject in1, MatrixObject in2,
                        String outputName, boolean isLeftTransposed, boolean 
isRightTransposed, BinaryOperator op) {
                if (ec.getGPUContext(0) != gCtx)
                        throw new DMLRuntimeException("GPU : Invalid internal 
state, the GPUContext set with the ExecutionContext is not the same used to run 
this LibMatrixCUDA function");
@@ -1502,7 +1519,7 @@ public class LibMatrixCUDA {
         * @param c                                             output matrix 
of size (maxRlen, maxClen) allocated on GPU
         * @param op                                    the operation to perform
         */
-       private static void matrixMatrixOp(GPUContext gCtx, String instName, 
Pointer a, Pointer b, int maxRlen, int maxClen, int vecStatusA, int vecStatusB, 
Pointer c, BinaryOperator op) {
+       static void matrixMatrixOp(GPUContext gCtx, String instName, Pointer a, 
Pointer b, int maxRlen, int maxClen, int vecStatusA, int vecStatusB, Pointer c, 
BinaryOperator op) {
                if(LOG.isTraceEnabled()) {
                        LOG.trace("GPU : matrix_matrix_cellwise_op" + ", 
GPUContext=" + gCtx);
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/bd34292d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index 8051cbc..413c550 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -54,11 +54,14 @@ import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.functionobjects.Plus;
 import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.context.CSRPointer;
 import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+import 
org.apache.sysml.runtime.matrix.data.LibMatrixCuMatMult.CuMatMultParameters;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysml.utils.GPUStatistics;
 import org.apache.sysml.utils.Statistics;
 
@@ -846,19 +849,231 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                }
        }
        
-       static Pointer getDenseInputPointer(ExecutionContext ec, GPUContext 
gCtx, String instName, String inputName,
+       public static Pointer getDenseInputPointer(ExecutionContext ec, 
GPUContext gCtx, String instName, String inputName, // widened from package-private to public so callers outside this package can use it
                        long numRows, long numCols) throws DMLRuntimeException {
                MatrixObject output = 
ec.getMatrixInputForGPUInstruction(inputName, instName);
                return LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, output, 
instName, numRows, numCols);
        }
        
-       static Pointer getDenseOutputPointer(ExecutionContext ec, GPUContext 
gCtx, String instName, String outputName,
+       public static Pointer getDenseOutputPointer(ExecutionContext ec, 
GPUContext gCtx, String instName, String outputName, // widened to public; allocates the dense output before returning its pointer
                        long numRows, long numCols) throws DMLRuntimeException {
                MatrixObject output = ec.getMatrixObject(outputName);
                getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, 
numRows, numCols); // Allocated the dense output matrix
                return getDensePointerForCuDNN(gCtx, output, instName, numRows, 
numCols);
        }
        
+       /** LSTM backward pass (mirrors nn/layers/lstm.dml) using cuBLAS matmuls and custom CUDA kernels instead of cuDNN.
+        *  dW (beta = 1) and db (db + colSums) are accumulated into, so callers presumably must zero them first -- TODO confirm. */
+       public static void nnLstmBackward(ExecutionContext ec, GPUContext gCtx, String instName,
+                       Pointer X, Pointer out0, Pointer c0, MatrixObject W, Pointer dout, Pointer dc,  // input
+                       Pointer cache_out, Pointer cache_c, Pointer cache_ifog,
+                       Pointer dX, Pointer dW, Pointer db, Pointer dout0, Pointer dc0,         // output
+                       boolean return_sequences, long N, long M, long D, long T) throws DMLRuntimeException {
+               Pointer input = gCtx.allocate(instName, N*(D+M)*sizeOfDataType);
+               Pointer difog_raw = gCtx.allocate(instName, N*4*M*sizeOfDataType);
+               Pointer dct = copy(gCtx, instName, dc, N*M);
+               Pointer dinput = gCtx.allocate(instName, N*(D+M)*sizeOfDataType); // (N, D+M)
+               Pointer tmpDb = gCtx.allocate(instName, 4*M*sizeOfDataType); // (1, 4M)
+               
+               // dW = dW + t(input) %*% difog_raw  # shape (D+M, 4M)
+               CuMatMultParameters param1 = new CuMatMultParameters(N, D+M,
+                               N, 4*M, true, false, one(), one());
+               
+               // dinput = difog_raw %*% t(W)  # shape (N, D+M)
+               CuMatMultParameters param2 = new CuMatMultParameters(N, 4*M,
+                               D+M, 4*M, false, true);
+               BinaryOperator plusOp = new BinaryOperator(Plus.getPlusFnObject()); // db = db + colSums(difog_raw); loop-invariant, built once
+               CSRPointer wSparsePointer = null;
+               Pointer wDensePointer = null;
+               
+               // TODO: Only dense weight supported for now
+               boolean isWSparse = false; // isInSparseFormat(gCtx, W);
+               if(isWSparse)
+                       wSparsePointer = W.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
+               else
+                       wDensePointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
+               
+               Pointer dout_t = return_sequences ? gCtx.allocate(instName, N*M*sizeOfDataType) : copy(gCtx, instName, dout, N*M);
+               if(return_sequences) {
+                       getCudaKernels(gCtx).launchKernel("initializeDoutWhenReturnSeq",
+                                       ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*M)),
+                                       dout, dout_t, toInt(T-1), toInt(M), toInt(T*M), toInt(N*M)); // every scalar kernel arg is passed as an int
+               }
+               
+               for(int t = toInt(T); t >= 1; t--) {
+                       // if (t == 1) { out_prev = out0; } else { out_prev = matrix(cache_out[t-1,], rows=N, cols=M) }
+                       Pointer out_prev = (t == 1) ? out0 : cache_out.withByteOffset((t-2)*N*M*sizeOfDataType); // since read-only
+                       
+                       // X_t = X[,(t-1)*D+1:t*D]  # shape (N, D)
+                       // input = cbind(X_t, out_prev)  # shape (N, D+M)
+                       getCudaKernels(gCtx).launchKernel("prepareInputNNLstm",
+                                       ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*(D+M))),
+                                       X, out_prev, input, (t-1), toInt(M), toInt(D), toInt(T*D), toInt(D+M), toInt(N*(D+M)));
+                       
+                       // ct = matrix(cache_c[t,], rows=N, cols=M)  # shape (N, M)
+                       Pointer ct = cache_c.withByteOffset((t-1)*N*M*sizeOfDataType); // since read-only
+                       
+                       // ifog = matrix(cache_ifog[t,], rows=N, cols=4*M)
+                       Pointer ifog = cache_ifog.withByteOffset((t-1)*N*4*M*sizeOfDataType); // since read-only
+                       
+                       // i = ifog[,1:M]  # input gate, shape (N, M)
+                       // f = ifog[,M+1:2*M]  # forget gate, shape (N, M)
+                       // o = ifog[,2*M+1:3*M]  # output gate, shape (N, M)
+                       // g = ifog[,3*M+1:4*M]  # g gate, shape (N, M)
+                       // dct = dct + o*tanh::backward(dout_t, ct)  # shape (N, M)
+                       // do = tanh::forward(ct) * dout_t  # output gate, shape (N, M)
+                       // df = c_prev * dct  # forget gate, shape (N, M)
+                       // dc_prev = f * dct  # shape (N, M)
+                       // di = g * dct  # input gate, shape (N, M)
+                       // dg = i * dct  # g gate, shape (N, M)
+                       // di_raw = i * (1-i) * di
+                       // df_raw = f * (1-f) * df
+                       // do_raw = o * (1-o) * do
+                       // dg_raw = (1-g^2) * dg
+                       // difog_raw = cbind(di_raw, df_raw, do_raw, dg_raw)  # shape (N, 4M)
+                       getCudaKernels(gCtx).launchKernel("computeDifog_raw",
+                                       ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*M)),
+                                       ifog, ct, dout_t, cache_c, c0,
+                                       difog_raw, dct, dc0, // output
+                                       return_sequences ? 1 : 0, t-1, toInt(T), toInt(M), toInt(N*M));
+                       
+                       // dW = dW + t(input) %*% difog_raw  # shape (D+M, 4M)
+                       LibMatrixCuMatMult.denseDenseMatMult(gCtx.getCublasHandle(), instName, dW, input, difog_raw, param1);
+                       
+                       // dinput = difog_raw %*% t(W)  # shape (N, D+M)
+                       if(isWSparse) {
+                               if(wSparsePointer.nnz == 0) {
+                                       cudaMemset(dinput, 0, N*(D+M)*sizeOfDataType);
+                               }
+                               else {
+                                       LibMatrixCuMatMult.denseSparseMatMult(gCtx.getCusparseHandle(), instName, dinput, difog_raw, wSparsePointer, param2);
+                               }
+                       }
+                       else
+                               LibMatrixCuMatMult.denseDenseMatMult(gCtx.getCublasHandle(), instName, dinput, difog_raw, wDensePointer, param2);
+                       
+                       // db = db + colSums(difog_raw)  # shape (1, 4M)
+                       reduceCol(gCtx, instName, "reduce_col_sum", difog_raw, tmpDb, 1, toInt(4*M));
+                       matrixMatrixOp(gCtx, instName, tmpDb, db, 1, toInt(4*M), VectorShape.NONE.code(), VectorShape.NONE.code(), db,
+                                       plusOp);
+                       
+                       // Scatter dinput: dX[,(t-1)*D+1:t*D] = dinput[,1:D], and dout_prev = dinput[,D+1:D+M] feeds dout0 (t == 1) or the next dout_t.
+                       // NOTE(review): these semantics are presumed from the arguments of postProcessNNLstmBackward -- confirm against the kernel.
+                       int size = toInt(Math.max(N*D, N*M));
+                       getCudaKernels(gCtx).launchKernel("postProcessNNLstmBackward",
+                                       ExecutionConfig.getConfigForSimpleVectorOperations(size),
+                                       dinput, dout0, dout, dout_t, dX, return_sequences ? 1 : 0, t-1, toInt(N), toInt(D), toInt(M),
+                                       toInt(N*D), toInt(N*M), toInt(T*D), toInt(T*M), toInt(D+M), size);
+               }
+               
+               gCtx.cudaFreeHelper(instName, dout_t, gCtx.EAGER_CUDA_FREE);
+               gCtx.cudaFreeHelper(instName, input, gCtx.EAGER_CUDA_FREE);
+               gCtx.cudaFreeHelper(instName, difog_raw, gCtx.EAGER_CUDA_FREE);
+               gCtx.cudaFreeHelper(instName, dct, gCtx.EAGER_CUDA_FREE);
+               gCtx.cudaFreeHelper(instName, dinput, gCtx.EAGER_CUDA_FREE);
+               gCtx.cudaFreeHelper(instName, tmpDb, gCtx.EAGER_CUDA_FREE);
+       }
+       
+       /** LSTM forward pass (mirrors nn/layers/lstm.dml) using cuBLAS/cuSPARSE matmuls and custom CUDA kernels instead of cuDNN.
+        *  Pass either all three cache pointers (per-step out, c and ifog are then recorded for nnLstmBackward) or none of them. */
+       public static void nnLstm(ExecutionContext ec, GPUContext gCtx, String instName,
+                       Pointer X,  MatrixObject W, Pointer b, Pointer out0, Pointer c0, boolean return_sequences,
+                       Pointer out, Pointer c,  // output matrices
+                       Pointer cache_out, Pointer cache_c, Pointer cache_ifog, // temporary workspace passed to the backward function
+                       long N, long M, long D, long T) throws DMLRuntimeException {
+               boolean skipCache = cache_out == null || cache_c == null || cache_ifog == null;
+               boolean allCachesNull = cache_out == null && cache_c == null && cache_ifog == null;
+               if(skipCache && !allCachesNull) { // some, but not all, cache pointers are null
+                       throw new DMLRuntimeException("Either all cache pointers should be null or all should be not null");
+               }
+               
+               // out_prev = out0
+               Pointer out_prev = copy(gCtx, instName, out0, N*M);
+               // c_prev = c0
+               Pointer c_prev = copy(gCtx, instName, c0, N*M);
+               // c = c_prev
+               cudaMemcpy(c, c_prev, N*M*sizeOfDataType, cudaMemcpyDeviceToDevice);
+               
+               Pointer input = gCtx.allocate(instName, N*(D+M)*sizeOfDataType);
+               Pointer ifog = gCtx.allocate(instName, N*4*M*sizeOfDataType);
+               
+               boolean isWSparse = isInSparseFormat(gCtx, W);
+               CSRPointer wSparsePointer = null;
+               Pointer wDensePointer = null;
+               if(isWSparse)
+                       wSparsePointer = W.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
+               else
+                       wDensePointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
+               
+               CuMatMultParameters param = new CuMatMultParameters(N, D+M,
+                               D+M, 4*M, false, false); // ifog = input %*% W: same shapes at every time step, so built once
+               for(int t = 1; t <= T; t++) {
+                       // X_t = X[,(t-1)*D+1:t*D]  # shape (N, D)
+                       // input = cbind(X_t, out_prev)  # shape (N, D+M)
+                       getCudaKernels(gCtx).launchKernel("prepareInputNNLstm",
+                                       ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*(D+M))),
+                                       X, out_prev, input, (t-1), toInt(M), toInt(D), toInt(T*D), toInt(D+M), toInt(N*(D+M)));
+                       
+                       // ifog = input %*% W
+                       if(isWSparse) {
+                               if(wSparsePointer.nnz == 0) {
+                                       cudaMemset(ifog, 0, N*4*M*sizeOfDataType);
+                               }
+                               else {
+                                       LibMatrixCuMatMult.denseSparseMatMult(gCtx.getCusparseHandle(), instName, ifog, input, wSparsePointer, param);
+                               }
+                       }
+                       else
+                               LibMatrixCuMatMult.denseDenseMatMult(gCtx.getCublasHandle(), instName, ifog, input, wDensePointer, param);
+                       
+                       // ifog = ifog + b
+                       // ifog[,1:3*M] = sigmoid::forward(ifog[,1:3*M])  # i,f,o gates squashed with sigmoid
+                       // ifog[,3*M+1:4*M] = tanh::forward(ifog[,3*M+1:4*M])  # g gate squashed with tanh
+                       getCudaKernels(gCtx).launchKernel("squashIFOG",
+                                       ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*4*M)),
+                                       ifog, b, toInt(M), toInt(N*4*M));
+                       
+                       // c = ifog[,M+1:2*M]*c_prev + ifog[,1:M]*ifog[,3*M+1:4*M]
+                       // out_t = ifog[,2*M+1:3*M] * tanh::forward(c)
+                       // if (return_sequences) {
+                       //   out[,(t-1)*M+1:t*M] = out_t
+                       // }
+                       // else {
+                       //   out = out_t
+                       // }
+                       // out_prev = out_t
+                       // c_prev = c
+                       // cache_out[t,] = matrix(out_t, rows=1, cols=N*M)
+                       // cache_c[t,] = matrix(c, rows=1, cols=N*M)
+                       if(skipCache) {
+                               getCudaKernels(gCtx).launchKernel("postProcessNNLstmForwardSkipCache",
+                                               ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*M)),
+                                       ifog, c,  out_prev, c_prev, out,
+                                       return_sequences ? 1 : 0, t-1, toInt(T), toInt(M), toInt(N*M));
+                       }
+                       else {
+                               getCudaKernels(gCtx).launchKernel("postProcessNNLstmForward",
+                                               ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*M)),
+                                       ifog, c,  out_prev, c_prev, out, cache_out, cache_c,
+                                       return_sequences ? 1 : 0, t-1, toInt(T), toInt(M), toInt(N*M));
+                               
+                               // cache_ifog[t,] = matrix(ifog, rows=1, cols=N*4*M)  # reshape
+                               cudaMemcpy(cache_ifog.withByteOffset((t-1)*N*4*M*sizeOfDataType), ifog, N*4*M*sizeOfDataType, cudaMemcpyDeviceToDevice);
+                       }
+               }
+               
+               gCtx.cudaFreeHelper(instName, out_prev, gCtx.EAGER_CUDA_FREE);
+               gCtx.cudaFreeHelper(instName, c_prev, gCtx.EAGER_CUDA_FREE);
+               gCtx.cudaFreeHelper(instName, input, gCtx.EAGER_CUDA_FREE);
+               gCtx.cudaFreeHelper(instName, ifog, gCtx.EAGER_CUDA_FREE);
+       }
+       
+       private static Pointer copy(GPUContext gCtx, String instName, Pointer ptr, long numElems) { // new device buffer holding a copy of ptr[0, numElems)
+               final Pointer duplicate = gCtx.allocate(instName, numElems*sizeOfDataType);
+               cudaMemcpy(duplicate, ptr, numElems*sizeOfDataType, cudaMemcpyDeviceToDevice); // device-to-device, no host round trip
+               return duplicate;
+       }
+       
        /**
         * Computes the forward pass for an LSTM layer with M neurons.
         * The input data has N sequences of T examples, each with D features.
@@ -879,13 +1094,13 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
         * @param T sequence length
         * @throws DMLRuntimeException if error
         */
-       public static void lstm(ExecutionContext ec, GPUContext gCtx, String 
instName,
+       public static void cuDNNLstm(ExecutionContext ec, GPUContext gCtx, 
String instName, // cuDNN-backed LSTM forward; nnLstm is the kernel-based alternative in this class
                        Pointer X,  Pointer wPointer, Pointer out0, Pointer c0, 
boolean return_sequences,
                        String outputName, String cyName, int N, int M, int D, 
int T) throws DMLRuntimeException {
-               singleLayerUnidirectionalRNNForward(ec, gCtx, instName, X, 
out0, c0, wPointer, outputName, cyName, "lstm", return_sequences, N, M, D, T);
+               cuDNNSingleLayerUnidirectionalRNNForward(ec, gCtx, instName, X, 
out0, c0, wPointer, outputName, cyName, "lstm", return_sequences, N, M, D, T); // rnnMode "lstm"
        }
        
-       private static void 
singleLayerUnidirectionalRNNForward(ExecutionContext ec, GPUContext gCtx, 
String instName,
+       private static void 
cuDNNSingleLayerUnidirectionalRNNForward(ExecutionContext ec, GPUContext gCtx, 
String instName,
                        Pointer x, Pointer hx, Pointer cx, Pointer wPointer,  
// input
                        String outputName, String cyName,                       
                 // output
                        String rnnMode, boolean return_sequences, int N, int M, 
int D, int T) throws DMLRuntimeException {
@@ -924,13 +1139,20 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                gCtx.cudaFreeHelper(instName, cudnnYPointer, 
gCtx.EAGER_CUDA_FREE);
        }
        
-       public static void lstmBackward(ExecutionContext ec, GPUContext gCtx, 
String instName,
+       public static void cuDNNLstmBackward(ExecutionContext ec, GPUContext 
gCtx, String instName,
                        Pointer x, Pointer hx, Pointer cx, Pointer wPointer, 
String doutName, String dcyName,  // input
                        String dxName, String dwName, String dbName, String 
dhxName, String dcxName,    // output
                        boolean return_sequences, long N, long M, long D, long 
T) throws DMLRuntimeException {
                
                if(LOG.isDebugEnabled()) {
-                       long memRequired = (N*T*M + (return_sequences ? T*M : 
M) + N*T*M + 2*N*T*D + (D+M+2)*(4*M))*sizeOfDataType;
+                       long memRequired = (D+M)*4*M // sysmlWPointer
+                                       + 2*(D+M+2)*(4*M) // cudnnWPointer and 
cudnnDwPointer
+                                       + 3*N*T*D  // cudnnInput, cudnnDx and 
smlDx
+                                       + 2*N*T*M // dy and yPointer
+                                       + (return_sequences ? T*M : M); // dout
+                       memRequired *= LibMatrixCUDA.sizeOfDataType;
+                       // Assume the workspace to be proportional to 
cudnnWPointer
+                       // memRequired += 
(D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType;
                        LOG.debug("Memory required for invoking lstmBackward is 
" + memRequired + " bytes + workspace + reserve space + memory for 
descriptors.");
                }
                

http://git-wip-us.apache.org/repos/asf/systemml/blob/bd34292d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuMatMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuMatMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuMatMult.java
index 9833456..6dacf28 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuMatMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuMatMult.java
@@ -45,9 +45,9 @@ public class LibMatrixCuMatMult extends LibMatrixCUDA {
 
        private static final Log LOG = 
LogFactory.getLog(LibMatrixCuMatMult.class.getName());
 
-       private static class CuMatMultParameters {
+       public static class CuMatMultParameters {
                /*
-                * For the operation, C = op(A) %*% op(B), the below parameters 
are used
+                * For the operation, C = alpha * op(A) %*% op(B) + beta*C, the 
below parameters are used
                 * to invoke the corresponding kernels in CuBLAS and CuSPARSE.
                 * 
                 * All the below values have to be valid or else this class has 
to throw
@@ -68,8 +68,16 @@ public class LibMatrixCuMatMult extends LibMatrixCUDA {
                public long rightNumCols; // number of cols of B
                private boolean isLeftTransposed; // is op(A) = t(A)
                private boolean isRightTransposed; // is op(B) = t(B)
+               private Pointer alpha = one();
+               private Pointer beta = zero();
 
                public CuMatMultParameters(long leftNumRows1, long 
leftNumCols1, long rightNumRows1, long rightNumCols1,
+                               boolean isLeftTransposed1, boolean 
isRightTransposed1, Pointer alpha1, Pointer beta1) {
+                       this(leftNumRows1, leftNumCols1, rightNumRows1, 
rightNumCols1, isLeftTransposed1, isRightTransposed1);
+                       alpha = alpha1;
+                       beta = beta1;
+               }
+               public CuMatMultParameters(long leftNumRows1, long 
leftNumCols1, long rightNumRows1, long rightNumCols1,
                                boolean isLeftTransposed1, boolean 
isRightTransposed1) {
                        leftNumRows = leftNumRows1;
                        leftNumCols = leftNumCols1;
@@ -281,7 +289,7 @@ public class LibMatrixCuMatMult extends LibMatrixCUDA {
                        // Transpose: C = t(output)
                        long t0 = 
ConfigurationManager.isFinegrainedStatistics() ? System.nanoTime() : 0;
                        cudaSupportFunctions.cublasgeam(gCtx.getCublasHandle(), 
cublasOperation.CUBLAS_OP_T, cublasOperation.CUBLAS_OP_T,
-                                       toInt(outCLen), toInt(outRLen), one(), 
output, toInt(outRLen), zero(), new Pointer(),
+                                       toInt(outCLen), toInt(outRLen), 
params.alpha, output, toInt(outRLen), params.beta, new Pointer(),
                                        toInt(outRLen), C, toInt(outCLen));
                        if (!gCtx.EAGER_CUDA_FREE)
                                JCuda.cudaDeviceSynchronize();
@@ -310,7 +318,7 @@ public class LibMatrixCuMatMult extends LibMatrixCUDA {
         * @param param
         *            BLAS parameters
         */
-       private static void denseSparseMatMult(cusparseHandle handle, String 
instName, Pointer C, Pointer A, CSRPointer B,
+       static void denseSparseMatMult(cusparseHandle handle, String instName, 
Pointer C, Pointer A, CSRPointer B,
                        CuMatMultParameters param) {
                long t0 = ConfigurationManager.isFinegrainedStatistics() ? 
System.nanoTime() : 0;
                String kernel = 
GPUInstruction.MISC_TIMER_SPARSE_MATRIX_DENSE_MATRIX_LIB;
@@ -322,8 +330,8 @@ public class LibMatrixCuMatMult extends LibMatrixCUDA {
                        int m = toInt(param.rightNumRows);
                        int n = toInt(param.rightNumCols);
                        int transa = 
reverseCusparseOp(cusparseOp(param.isLeftTransposed));
-                       cudaSupportFunctions.cusparsecsrmv(handle, transa, m, 
n, toInt(B.nnz), one(), B.descr, B.val, B.rowPtr, B.colInd, A,
-                                       zero(), C);
+                       cudaSupportFunctions.cusparsecsrmv(handle, transa, m, 
n, toInt(B.nnz), param.alpha, B.descr, B.val, B.rowPtr, B.colInd, A,
+                                       param.beta, C);
                        kernel = 
GPUInstruction.MISC_TIMER_SPARSE_MATRIX_DENSE_VECTOR_LIB;
                } else {
                        int m = toInt(param.rightNumRows);
@@ -333,8 +341,8 @@ public class LibMatrixCuMatMult extends LibMatrixCUDA {
                        int transa = 
reverseCusparseOp(cusparseOp(param.isLeftTransposed));
                        int transb = cusparseOp(param.isRightTransposed);
                        LOG.debug(" GPU Sparse-Dense Matrix Multiply (rhs 
transpose) ");
-                       cudaSupportFunctions.cusparsecsrmm2(handle, transa, 
transb, m, param.n, k, toInt(B.nnz), one(), B.descr, B.val,
-                                       B.rowPtr, B.colInd, A, param.ldb, 
zero(), C, param.ldc);
+                       cudaSupportFunctions.cusparsecsrmm2(handle, transa, 
transb, m, param.n, k, toInt(B.nnz), param.alpha, B.descr, B.val,
+                                       B.rowPtr, B.colInd, A, param.ldb, 
param.beta, C, param.ldc);
                }
                if (ConfigurationManager.isFinegrainedStatistics())
                        GPUStatistics.maintainCPMiscTimes(instName, kernel, 
System.nanoTime() - t0);
@@ -359,7 +367,7 @@ public class LibMatrixCuMatMult extends LibMatrixCUDA {
         * @param param
         *            BLAS parameters
         */
-       private static void denseDenseMatMult(cublasHandle handle, String 
instName, Pointer C, Pointer A, Pointer B,
+       static void denseDenseMatMult(cublasHandle handle, String instName, 
Pointer C, Pointer A, Pointer B,
                        CuMatMultParameters param) {
                long t0 = ConfigurationManager.isFinegrainedStatistics() ? 
System.nanoTime() : 0;
                String kernel = null;
@@ -388,19 +396,19 @@ public class LibMatrixCuMatMult extends LibMatrixCUDA {
                        transb = reverseCublasOp(transb);
                        int rightNumRows = (transb == 
CUSPARSE_OPERATION_TRANSPOSE) ? param.k : param.n;
                        int rightNumCols = (transb == 
CUSPARSE_OPERATION_TRANSPOSE) ? param.n : param.k;
-                       cudaSupportFunctions.cublasgemv(handle, transb, 
rightNumRows, rightNumCols, one(), B, param.ldb, A, 1, zero(), C, 1);
+                       cudaSupportFunctions.cublasgemv(handle, transb, 
rightNumRows, rightNumCols, param.alpha, B, param.ldb, A, 1, param.beta, C, 1);
                        kernel = 
GPUInstruction.MISC_TIMER_DENSE_VECTOR_DENSE_MATRIX_LIB;
                } else if (param.n == 1) {
                        // Matrix-vector multiply
                        LOG.debug(" GPU Dense Matrix-Vector Multiply");
                        int leftNumRows = (transa == 
CUSPARSE_OPERATION_NON_TRANSPOSE) ? param.m : param.k;
                        int leftNumCols = (transa == 
CUSPARSE_OPERATION_NON_TRANSPOSE) ? param.k : param.m;
-                       cudaSupportFunctions.cublasgemv(handle, transa, 
leftNumRows, leftNumCols, one(), A, param.lda, B, 1, zero(), C, 1);
+                       cudaSupportFunctions.cublasgemv(handle, transa, 
leftNumRows, leftNumCols, param.alpha, A, param.lda, B, 1, param.beta, C, 1);
                        kernel = 
GPUInstruction.MISC_TIMER_DENSE_MATRIX_DENSE_VECTOR_LIB;
                } else {
                        LOG.debug(" GPU Dense-Dense Matrix Multiply ");
-                       cudaSupportFunctions.cublasgemm(handle, transa, transb, 
param.m, param.n, param.k, one(), A, param.lda, B, param.ldb,
-                                       zero(), C, param.ldc);
+                       cudaSupportFunctions.cublasgemm(handle, transa, transb, 
param.m, param.n, param.k, param.alpha, A, param.lda, B, param.ldb,
+                                       param.beta, C, param.ldc);
                        kernel = 
GPUInstruction.MISC_TIMER_DENSE_MATRIX_DENSE_MATRIX_LIB;
                }
                if (ConfigurationManager.isFinegrainedStatistics())

http://git-wip-us.apache.org/repos/asf/systemml/blob/bd34292d/src/test/java/org/apache/sysml/test/gpu/LstmTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/LstmTest.java b/src/test/java/org/apache/sysml/test/gpu/LstmTest.java
new file mode 100644
index 0000000..47afe3a
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/gpu/LstmTest.java
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.gpu;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.sysml.runtime.instructions.gpu.DnnGPUInstruction;
+import 
org.apache.sysml.runtime.instructions.gpu.DnnGPUInstruction.LstmOperator;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests lstm builtin function
+ */
+public class LstmTest extends GPUTests {
+
+       private final static String TEST_NAME = "LstmTests";
+       private final int seed = 42;
+       
+       private final static String builtinDML = 
"\"nn/layers/lstm_staging.dml\"";
+       private final static String nnDML = "\"nn/layers/lstm.dml\"";
+
+       @Override
+       public void setUp() {
+               super.setUp();
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_DIR, TEST_NAME);
+               getAndLoadTestConfiguration(TEST_NAME);
+       }
+
+       @Test
+       public void testLstmForward1() {
+               testLstmCuDNNWithNNBuiltinOperator(1, 1, 1, 1, "TRUE", 0.9);
+       }
+       
+       @Test
+       public void testLstmForward2() {
+               testLstmCuDNNWithNNBuiltinOperator(1, 1, 1, 1, "FALSE", 0.9);
+       }
+       
+       @Test
+       public void testLstmForward3() {
+               testLstmCuDNNWithNNBuiltinOperator(20, 13, 50, 10, "TRUE", 0.9);
+       }
+       
+       @Test
+       public void testLstmForward4() {
+               testLstmCuDNNWithNNBuiltinOperator(20, 13, 50, 10, "FALSE", 
0.9);
+       }
+       
+       @Test
+       public void testLstmForward5() {
+               testLstmCuDNNWithNNBuiltinOperator(1, 3, 5, 1, "TRUE", 0.9);
+       }
+       
+       @Test
+       public void testLstmForward6() {
+               testLstmCuDNNWithNNBuiltinOperator(1, 3, 5, 1, "FALSE", 0.9);
+       }
+       
+       @Test
+       public void testLstmForward7() {
+               testLstmCuDNNWithNNBuiltinOperator(20, 13, 50, 10, "TRUE", 0.1);
+       }
+       
+       @Test
+       public void testLstmForward8() {
+               testLstmCuDNNWithNNBuiltinOperator(20, 13, 50, 10, "FALSE", 
0.1);
+       }
+       
+       @Test
+       public void testLstmForward9() {
+               testLstmCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9);
+       }
+       
+       @Test
+       public void testLstmForward10() {
+               testLstmCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.9);
+       }
+       
+       @Test
+       public void testLstmForward11() {
+               testLstmCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.9);
+       }
+       
+       @Test
+       public void testLstmForward12() {
+               testLstmCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9);
+       }
+       
+       public void testLstmCuDNNWithNNBuiltinOperator(int N, int T, int D, int 
M, String returnSequences, double sparsity) {
+               String scriptStr = "source(" + builtinDML + ") as lstm;\n "
+                               + "[output, c] = lstm::forward(x, w, b, " + 
returnSequences + ", out0, c0)";
+               
+               HashMap<String, Object> inputs = new HashMap<>();
+               inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, 
sparsity, seed));
+               inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10, 
sparsity, seed));
+               inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10, 
sparsity, seed));
+               inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, 
sparsity, seed));
+               inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, 
sparsity, seed));
+               List<String> outputs = Arrays.asList("output", "c");
+               List<Object> outGPUWithCuDNN = null;
+               List<Object> outGPUWithNN = null;
+               synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
+                       try {
+                               DnnGPUInstruction.FORCED_LSTM_OP = 
LstmOperator.CUDNN;
+                               outGPUWithCuDNN = runOnGPU(spark, scriptStr, 
inputs, outputs);
+                               inputs = new HashMap<>();
+                               inputs.put("x", generateInputMatrix(spark, N, 
T*D, 0, 10, sparsity, seed));
+                               inputs.put("w", generateInputMatrix(spark, D+M, 
4*M, 0, 10, sparsity, seed));
+                               inputs.put("b", generateInputMatrix(spark, 1, 
4*M, 0, 10, sparsity, seed));
+                               inputs.put("out0", generateInputMatrix(spark, 
N, M, 0, 10, sparsity, seed));
+                               inputs.put("c0", generateInputMatrix(spark, N, 
M, 0, 10, sparsity, seed));
+                               DnnGPUInstruction.FORCED_LSTM_OP = 
LstmOperator.DENSE_NN;
+                               outGPUWithNN = runOnGPU(spark, scriptStr, 
inputs, outputs);
+                       }
+                       finally {
+                               DnnGPUInstruction.FORCED_LSTM_OP = 
LstmOperator.NONE;
+                       }
+               }
+               assertEqualObjects(outGPUWithCuDNN.get(0), outGPUWithNN.get(0));
+               assertEqualObjects(outGPUWithCuDNN.get(1), outGPUWithNN.get(1));
+       }
+       
+       public void testLstmCuDNNWithNNLayer(int N, int T, int D, int M, String 
returnSequences, double sparsity) {
+               String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
+                               + "[output, c] = lstm::forward(x, w, b, " + 
returnSequences + ", out0, c0)";
+               String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
+                               + "[output, c, cache_out, cache_c, cache_ifog] 
= lstm::forward(x, w, b, " 
+                               + T + ", " + D + ", " + returnSequences + ", 
out0, c0)";
+               
+               HashMap<String, Object> inputs = new HashMap<>();
+               inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, 
sparsity, seed));
+               inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10, 
sparsity, seed));
+               inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10, 
sparsity, seed));
+               inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, 
sparsity, seed));
+               inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, 
sparsity, seed));
+               List<String> outputs = Arrays.asList("output", "c");
+               List<Object> outGPUWithCuDNN = null;
+               List<Object> outCPUWithNN = null;
+               synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
+                       try {
+                               DnnGPUInstruction.FORCED_LSTM_OP = 
LstmOperator.CUDNN;
+                               outGPUWithCuDNN = runOnGPU(spark, scriptStr1, 
inputs, outputs);
+                               outCPUWithNN = runOnCPU(spark, scriptStr2, 
inputs, outputs);
+                       }
+                       finally {
+                               DnnGPUInstruction.FORCED_LSTM_OP = 
LstmOperator.NONE;
+                       }
+               }
+               assertEqualObjects(outGPUWithCuDNN.get(0), outCPUWithNN.get(0));
+               assertEqualObjects(outGPUWithCuDNN.get(1), outCPUWithNN.get(1));
+       }
+       
+       @Test
+       public void testLstmBackward1() {
+               testLstmBackwardCuDNNWithNNBuiltinOperator(1, 1, 1, 1, "TRUE", 
0.9, 0.9);
+       }
+       
+       @Test
+       public void testLstmBackward2() {
+               testLstmBackwardCuDNNWithNNBuiltinOperator(1, 1, 1, 1, "FALSE", 
0.9, 0.9);
+       }
+       
+       @Test
+       public void testLstmBackward3() {
+               testLstmBackwardCuDNNWithNNBuiltinOperator(20, 13, 50, 10, 
"TRUE", 0.9, 0.9);
+       }
+       
+       @Test
+       public void testLstmBackward4() {
+               testLstmBackwardCuDNNWithNNBuiltinOperator(20, 13, 50, 10, 
"FALSE", 0.9, 0.9);
+       }
+       
+//     @Test
+//     public void testLstmBackward5() {
+//             testLstmBackwardCuDNNWithNNBuiltinOperator(20, 13, 50, 10, 
"TRUE", 0.9, 0.1);
+//     }
+//     
+//     @Test
+//     public void testLstmBackward6() {
+//             testLstmBackwardCuDNNWithNNBuiltinOperator(20, 13, 50, 10, 
"FALSE", 0.9, 0.1);
+//     }
+       
+       
+       @Test
+       public void testLstmBackward7() {
+               testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9, 0.9);
+       }
+       
+       @Test
+       public void testLstmBackward8() {
+               testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.9, 0.9);
+       }
+       
+       @Test
+       public void testLstmBackward9() {
+               testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.9, 
0.9);
+       }
+       
+       @Test
+       public void testLstmBackward10() {
+               testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9, 
0.9);
+       }
+       
+//     @Test
+//     public void testLstmBackward11() {
+//             testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.9, 
0.1);
+//     }
+//     
+//     @Test
+//     public void testLstmBackward12() {
+//             testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9, 
0.1);
+//     }
+       
+       public void testLstmBackwardCuDNNWithNNBuiltinOperator(int N, int T, 
int D, int M, String returnSequences, double sparsity, 
+                       double weightSparsity) {
+               boolean returnSequences1 = returnSequences.equals("TRUE");
+                               
+               String scriptStr = "source(" + builtinDML + ") as lstm;\n "
+                               + "[dX, dW, db, dout0, dc0] = 
lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
+               
+               HashMap<String, Object> inputs = new HashMap<>();
+               inputs.put("dout", generateInputMatrix(spark, N, 
returnSequences1 ? T*M : M, 0, 10, sparsity, seed));
+               inputs.put("dc", generateInputMatrix(spark, N, M, 0, 10, 
sparsity, seed));
+               inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, 
sparsity, seed));
+               inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10, 
weightSparsity, seed));
+               inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10, 
sparsity, seed));
+               inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, 
sparsity, seed));
+               inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, 
sparsity, seed));
+               List<String> outputs = Arrays.asList("dX", "dW", "db", "dout0", 
"dc0");
+               List<Object> outGPUWithCuDNN = null;
+               List<Object> outGPUWithNN = null;
+               synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
+                       try {
+                               DnnGPUInstruction.FORCED_LSTM_OP = 
LstmOperator.CUDNN;
+                               outGPUWithCuDNN = runOnGPU(spark, scriptStr, 
inputs, outputs);
+                               inputs = new HashMap<>();
+                               inputs.put("dout", generateInputMatrix(spark, 
N, returnSequences1 ? T*M : M, 0, 10, sparsity, seed));
+                               inputs.put("dc", generateInputMatrix(spark, N, 
M, 0, 10, sparsity, seed));
+                               inputs.put("x", generateInputMatrix(spark, N, 
T*D, 0, 10, sparsity, seed));
+                               inputs.put("w", generateInputMatrix(spark, D+M, 
4*M, 0, 10, weightSparsity, seed));
+                               inputs.put("b", generateInputMatrix(spark, 1, 
4*M, 0, 10, sparsity, seed));
+                               inputs.put("out0", generateInputMatrix(spark, 
N, M, 0, 10, sparsity, seed));
+                               inputs.put("c0", generateInputMatrix(spark, N, 
M, 0, 10, sparsity, seed));
+                               DnnGPUInstruction.FORCED_LSTM_OP = 
LstmOperator.DENSE_NN;
+                               outGPUWithNN = runOnGPU(spark, scriptStr, 
inputs, outputs);
+                       }
+                       finally {
+                               DnnGPUInstruction.FORCED_LSTM_OP = 
LstmOperator.NONE;
+                       }
+               }
+               assertEqualObjects(outGPUWithCuDNN.get(0), outGPUWithNN.get(0));
+               assertEqualObjects(outGPUWithCuDNN.get(1), outGPUWithNN.get(1));
+               assertEqualObjects(outGPUWithCuDNN.get(2), outGPUWithNN.get(2));
+               assertEqualObjects(outGPUWithCuDNN.get(3), outGPUWithNN.get(3));
+               assertEqualObjects(outGPUWithCuDNN.get(4), outGPUWithNN.get(4));
+       }
+       
+       public void testLstmBackwardCuDNNWithNNLayer(int N, int T, int D, int 
M, String returnSequences, double sparsity,
+                       double weightSparsity) {
+               boolean returnSequences1 = returnSequences.equals("TRUE");
+               
+               String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
+                               + "[dX, dW, db, dout0, dc0] = 
lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
+               String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
+                               + "[output, c, cache_out, cache_c, cache_ifog] 
= lstm::forward(x, w, b, " 
+                               + T + ", " + D + ", " + returnSequences + ", 
out0, c0); \n"
+                               + "[dX, dW, db, dout0, dc0] = 
lstm::backward(dout, dc, x, w, b, " 
+                               + T + ", " + D + ", " + returnSequences + ", 
out0, c0, cache_out, cache_c, cache_ifog);";
+               
+               HashMap<String, Object> inputs = new HashMap<>();
+               inputs.put("dout", generateInputMatrix(spark, N, 
returnSequences1 ? T*M : M, 0, 10, sparsity, seed));
+               inputs.put("dc", generateInputMatrix(spark, N, M, 0, 10, 
sparsity, seed));
+               inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, 
sparsity, seed));
+               inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10, 
weightSparsity, seed));
+               inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10, 
sparsity, seed));
+               inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, 
sparsity, seed));
+               inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, 
sparsity, seed));
+               List<String> outputs = Arrays.asList("dX", "dW", "db", "dout0", 
"dc0");
+               List<Object> outGPUWithCuDNN = null;
+               List<Object> outCPUWithNN = null;
+               synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
+                       try {
+                               DnnGPUInstruction.FORCED_LSTM_OP = 
LstmOperator.CUDNN;
+                               outGPUWithCuDNN = runOnGPU(spark, scriptStr1, 
inputs, outputs);
+                       }
+                       finally {
+                               DnnGPUInstruction.FORCED_LSTM_OP = 
LstmOperator.NONE;
+                       }
+                       outCPUWithNN = runOnCPU(spark, scriptStr2, inputs, 
outputs);
+               }
+               assertEqualObjects(outGPUWithCuDNN.get(0), outCPUWithNN.get(0));
+               assertEqualObjects(outGPUWithCuDNN.get(1), outCPUWithNN.get(1));
+               assertEqualObjects(outGPUWithCuDNN.get(2), outCPUWithNN.get(2));
+               assertEqualObjects(outGPUWithCuDNN.get(3), outCPUWithNN.get(3));
+               assertEqualObjects(outGPUWithCuDNN.get(4), outCPUWithNN.get(4));
+       }
+}

Reply via email to