This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new 91467c1  [SYSTEMML-540] Improve the performance of GPU lstm backward operator by passing the state
91467c1 is described below

commit 91467c164202f70c5a85ba7e0f7f9fcd16ddca1b
Author: Niketan Pansare <npan...@us.ibm.com>
AuthorDate: Tue Mar 19 12:30:01 2019 -0700

    [SYSTEMML-540] Improve the performance of GPU lstm backward operator by passing the state

    - The lstm builtin function is extended to return the state: [out, c, state] = lstm(X, W, b, out0, c0, return_sequences)
    - The lstm_backward builtin function is extended to accept the state: [dX, dW, db, dout0, dc0] = lstm_backward(X, W, b, out0, c0, given_sequences, dout, dc, state)
    - Updated the DML documentation to reflect this change.
    - Updated the release documentation.
    
    Closes #856.
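
A minimal DML sketch of the extended builtin calls described above, using the shape conventions of this change (N = batch size, T = sequence length, D = number of input features, M = hidden size); the concrete sizes and zero-valued inputs are illustrative only:

    N = 32; T = 10; D = 64; M = 128
    X    = matrix(0, rows=N, cols=T*D)     # input sequences
    W    = matrix(0, rows=D+M, cols=4*M)   # packed weights
    b    = matrix(0, rows=1, cols=4*M)     # packed biases
    out0 = matrix(0, rows=N, cols=M)       # initial hidden state
    c0   = matrix(0, rows=N, cols=M)       # initial cell state
    # forward now additionally returns the opaque intermediate state
    [out, c, state] = lstm(X, W, b, out0, c0, TRUE)   # out: (N, T*M) since return_sequences=TRUE
    # backward reuses that state instead of redundantly recomputing the forward pass
    dout = matrix(0, rows=N, cols=T*M)     # gradient wrt out
    dc   = matrix(0, rows=N, cols=M)       # gradient wrt c
    [dX, dW, db, dout0, dc0] = lstm_backward(X, W, b, out0, c0, TRUE, dout, dc, state)
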
---
 conf/SystemML-config.xml.template                  |   3 +
 docs/dml-language-reference.md                     |  21 +-
 docs/release-process.md                            |  25 +-
 scripts/nn/layers/lstm_staging.dml                 |  12 +-
 src/main/java/org/apache/sysml/conf/DMLConfig.java |   4 +-
 .../sysml/parser/BuiltinFunctionExpression.java    |  14 +-
 .../org/apache/sysml/parser/StatementBlock.java    |  13 +-
 .../controlprogram/caching/CacheableData.java      |  10 +
 .../runtime/instructions/cp/DnnCPInstruction.java  |  72 +++++-
 .../instructions/gpu/DnnGPUInstruction.java        | 278 ++++++++++++++++-----
 .../instructions/gpu/context/GPUObject.java        |   2 +-
 .../sysml/runtime/matrix/data/LibMatrixCuDNN.java  | 163 ++++++------
 .../matrix/data/LibMatrixCuDNNRnnAlgorithm.java    |  19 +-
 .../runtime/matrix/data/LibMatrixCuMatMult.java    |   3 +
 .../org/apache/sysml/test/gpu/LstmCPUTest.java     |   5 +-
 .../java/org/apache/sysml/test/gpu/LstmTest.java   |  10 +-
 16 files changed, 443 insertions(+), 211 deletions(-)

diff --git a/conf/SystemML-config.xml.template 
b/conf/SystemML-config.xml.template
index b9189b1..17cc2cc 100644
--- a/conf/SystemML-config.xml.template
+++ b/conf/SystemML-config.xml.template
@@ -118,4 +118,7 @@
    <!-- Should perform recomputation of activations such as ReLU to reduce 
memory consumption. Set this to true
    when performing inference or for training very large networks (default: 
false) -->
    <sysml.gpu.recompute.activations>false</sysml.gpu.recompute.activations>
+   
+   <!-- Should the SystemML runtime force the lstm builtin functions to use the CuDNN kernels (default: true) -->
+   <sysml.gpu.lstm.force.cudnn>true</sysml.gpu.lstm.force.cudnn>
 </root>
\ No newline at end of file
diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md
index 6f1c854..f64b6ea 100644
--- a/docs/dml-language-reference.md
+++ b/docs/dml-language-reference.md
@@ -1521,16 +1521,17 @@ The images are assumed to be stored NCHW format, where 
N = batch size, C = #chan
 Hence, the images are internally represented as a matrix with dimension (N, C 
* H * W).
 
 
-| Function name                               | Input matrices           | 
Dimension of first input matrix                           | Dimension of second 
input matrix (if applicable)          | Dimension of (first) output matrix      
                                                    | Input Parameters          
                                                                                
                                                                                
    | Notes       [...]
-|---------------------------------------------|--------------------------|-----------------------------------------------------------|-----------------------------------------------------------|---------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------
 [...]
-| conv2d                                      | input, filter            | 
[batch_size X num_channels* height_image* width_image]    | [num_filters X 
num_channels* height_filter* width_filter] | [batch_size X num_channels_out* 
height_out* width_out]                                      | stride=[stride_h, 
stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, 
height_image, width_image], filter_shape=[num_filters, num_channels, 
height_filter, width_filter] | Performs 2D [...]
-| conv2d_backward_filter                      | input, dout              | 
[batch_size X num_channels* height_image* width_image]    | [batch_size X 
num_channels_out* height_out* width_out]    | [num_filters X num_channels* 
height_filter* width_filter]                                   | 
stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, 
num_channels, height_image, width_image], filter_shape=[num_filters, 
num_channels, height_filter, width_filter] | Computes th [...]
-| conv2d_backward_data                        | filter, dout             | 
[num_filters X num_channels* height_filter* width_filter] | [batch_size X 
num_channels_out* height_out* width_out]    | [batch_size X num_channels* 
height_image* width_image]                                      | 
stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, 
num_channels, height_image, width_image], filter_shape=[num_filters, 
num_channels, height_filter, width_filter] | Computes th [...]
-| max_pool, avg_pool                          | input                    | 
[batch_size X num_channels* height_image* width_image]    |                     
                                      | [batch_size X num_channels* height_out* 
width_out]                                          | stride=[stride_h, 
stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, 
height_image, width_image], pool_size=[height_pool, width_pool]                 
                  | Performs ma [...]
-| max_pool_backward, avg_pool_backward        | input, dout              | 
[batch_size X num_channels* height_image* width_image]    | [batch_size X 
num_channels* height_out* width_out]        | [batch_size X num_channels* 
height_image* width_image]                                      | 
stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, 
num_channels, height_image, width_image], pool_size=[height_pool, width_pool]   
                                | Computes th [...]
-| bias_add                                    | input, bias              | 
[batch_size X num_channels* height_image* width_image]    | [num_channels X 1]  
                                      | [batch_size X num_channels* 
height_image* width_image]                                      |               
                                                                                
                                                                                
                | Adds the bi [...]
-| bias_multiply                               | input, bias              | 
[batch_size X num_channels* height_image* width_image]    | [num_channels X 1]  
                                      | [batch_size X num_channels* 
height_image* width_image]                                      |               
                                                                                
                                                                                
                | Multiplies  [...]
-| lstm                                        | X,  W, bias, out0, c0    | 
[batch_size X seq_length*num_features]                    | 
[num_features+hidden_size X 4*hidden_size]                | [batch_size X 
seq_length*hidden_size] if return_sequences else  [batch_size X hidden_size]  | 
return_sequences                                                                
                                                                                
                              | Perform com [...]
+| Function name                               | Input matrices                 
                     | Dimension of first input matrix                          
 | Dimension of second input matrix (if applicable)          | Dimension of 
(first) output matrix                                                          
| Input Parameters                                                              
                                                                                
                   [...]
+|---------------------------------------------|-----------------------------------------------------|-----------------------------------------------------------|-----------------------------------------------------------|---------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 [...]
+| conv2d                                      | input, filter                  
                     | [batch_size X num_channels* height_image* width_image]   
 | [num_filters X num_channels* height_filter* width_filter] | [batch_size X 
num_channels_out* height_out* width_out]                                      | 
stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, 
num_channels, height_image, width_image], filter_shape=[num_filters, 
num_channels, height_filter,  [...]
+| conv2d_backward_filter                      | input, dout                    
                     | [batch_size X num_channels* height_image* width_image]   
 | [batch_size X num_channels_out* height_out* width_out]    | [num_filters X 
num_channels* height_filter* width_filter]                                   | 
stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, 
num_channels, height_image, width_image], filter_shape=[num_filters, 
num_channels, height_filter,  [...]
+| conv2d_backward_data                        | filter, dout                   
                     | [num_filters X num_channels* height_filter* 
width_filter] | [batch_size X num_channels_out* height_out* width_out]    | 
[batch_size X num_channels* height_image* width_image]                          
            | stride=[stride_h, stride_w], padding=[pad_h, pad_w], 
input_shape=[batch_size, num_channels, height_image, width_image], 
filter_shape=[num_filters, num_channels, height_filter,  [...]
+| max_pool, avg_pool                          | input                          
                     | [batch_size X num_channels* height_image* width_image]   
 |                                                           | [batch_size X 
num_channels* height_out* width_out]                                          | 
stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, 
num_channels, height_image, width_image], pool_size=[height_pool, width_pool]   
                   [...]
+| max_pool_backward, avg_pool_backward        | input, dout                    
                     | [batch_size X num_channels* height_image* width_image]   
 | [batch_size X num_channels* height_out* width_out]        | [batch_size X 
num_channels* height_image* width_image]                                      | 
stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, 
num_channels, height_image, width_image], pool_size=[height_pool, width_pool]   
                   [...]
+| bias_add                                    | input, bias                    
                     | [batch_size X num_channels* height_image* width_image]   
 | [num_channels X 1]                                        | [batch_size X 
num_channels* height_image* width_image]                                      | 
                                                                                
                                                                                
                 [...]
+| bias_multiply                               | input, bias                    
                     | [batch_size X num_channels* height_image* width_image]   
 | [num_channels X 1]                                        | [batch_size X 
num_channels* height_image* width_image]                                      | 
                                                                                
                                                                                
                 [...]
+| lstm                                        | X,  W, bias, out0, c0          
                     | [N X T*D]                                                
 | [D+M X 4M]                                                | [N X T*M] if 
given_sequences is true else [ N X M ]                                         
| return_sequences                                                              
                                                                                
                   [...]
+| lstm_backward                               | X, W, b, out0, c0, 
given_sequences, dout, dc, state | [N X T*M] if given_sequences is true else [ 
N X M]        | [N X M]                                                   | [N 
X T*D]                                                                          
         | return_sequences                                                     
                                                                                
                            [...]
 
 Note: the builtin functions `batch_norm2d` and `batch_norm2d_backward` are deprecated and will be removed in the next release. The `lstm` builtin function is in an experimental phase and is only supported for the GPU backend.
 
diff --git a/docs/release-process.md b/docs/release-process.md
index 3798ec7..dec6b15 100644
--- a/docs/release-process.md
+++ b/docs/release-process.md
@@ -255,22 +255,19 @@ this OS X example.
 
 ## Python Tests
 
-For Spark 1.*, the Python tests at (`src/main/python/tests`) can be executed in the following manner:
+Compile SystemML distribution:
 
-       PYSPARK_PYTHON=python3 pyspark --driver-class-path SystemML.jar test_matrix_agg_fn.py
-       PYSPARK_PYTHON=python3 pyspark --driver-class-path SystemML.jar test_matrix_binary_op.py
-       PYSPARK_PYTHON=python3 pyspark --driver-class-path SystemML.jar test_mlcontext.py
-       PYSPARK_PYTHON=python3 pyspark --driver-class-path SystemML.jar test_mllearn_df.py
-       PYSPARK_PYTHON=python3 pyspark --driver-class-path SystemML.jar test_mllearn_numpy.py
+       mvn package -P distribution
+       cd src/main/python/tests/
 
-For Spark 2.*, pyspark can't be used to run the Python tests, so they can be executed using
-spark-submit:
+For Spark 2.*, the Python tests at (`src/main/python/tests`) can be executed in the following manner:
 
-       spark-submit --driver-class-path SystemML.jar test_matrix_agg_fn.py
-       spark-submit --driver-class-path SystemML.jar test_matrix_binary_op.py
-       spark-submit --driver-class-path SystemML.jar test_mlcontext.py
-       spark-submit --driver-class-path SystemML.jar test_mllearn_df.py
-       spark-submit --driver-class-path SystemML.jar test_mllearn_numpy.py
+       PYSPARK_PYTHON=python3 spark-submit --driver-class-path ../../../../target/SystemML.jar,../../../../target/systemml-*-SNAPSHOT-extra.jar test_matrix_agg_fn.py
+       PYSPARK_PYTHON=python3 spark-submit --driver-class-path ../../../../target/SystemML.jar,../../../../target/systemml-*-SNAPSHOT-extra.jar test_matrix_binary_op.py
+       PYSPARK_PYTHON=python3 spark-submit --driver-class-path ../../../../target/SystemML.jar,../../../../target/systemml-*-SNAPSHOT-extra.jar test_mlcontext.py
+       PYSPARK_PYTHON=python3 spark-submit --driver-class-path ../../../../target/SystemML.jar,../../../../target/systemml-*-SNAPSHOT-extra.jar test_mllearn_df.py
+       PYSPARK_PYTHON=python3 spark-submit --driver-class-path ../../../../target/SystemML.jar,../../../../target/systemml-*-SNAPSHOT-extra.jar test_mllearn_numpy.py
+       PYSPARK_PYTHON=python3 spark-submit --driver-class-path ../../../../target/SystemML.jar,../../../../target/systemml-*-SNAPSHOT-extra.jar test_nn_numpy.py
 
 
 ## Check LICENSE and NOTICE Files
@@ -385,7 +382,7 @@ file and remove all the `@Ignore` annotations from all the 
tests. Then run the N
 # Run other GPU Unit Tests 
 
        rm result.txt
-       for t in AggregateUnaryOpTests  BinaryOpTests  MatrixMatrixElementWiseOpTests  RightIndexingTests AppendTest  MatrixMultiplicationOpTest ReorgOpTests ScalarMatrixElementwiseOpTests UnaryOpTests
+       for t in AggregateUnaryOpTests  BinaryOpTests  MatrixMatrixElementWiseOpTests  RightIndexingTests AppendTest  MatrixMultiplicationOpTest ReorgOpTests ScalarMatrixElementwiseOpTests UnaryOpTests LstmTest LstmCPUTest
        do
                mvn -Dit.test="org.apache.sysml.test.gpu."$t verify -PgpuTests 
&> tmp.txt
                SUCCESS=`grep "BUILD SUCCESS" tmp.txt`
diff --git a/scripts/nn/layers/lstm_staging.dml 
b/scripts/nn/layers/lstm_staging.dml
index 2f71f22..f1934da 100644
--- a/scripts/nn/layers/lstm_staging.dml
+++ b/scripts/nn/layers/lstm_staging.dml
@@ -27,7 +27,7 @@ source("nn/layers/tanh.dml") as tanh
 
 forward = function(matrix[double] X, matrix[double] W, matrix[double] b, 
                    boolean return_sequences, matrix[double] out0, 
matrix[double] c0)
-    return (matrix[double] out, matrix[double] c) {
+    return (matrix[double] out, matrix[double] c, matrix[double] state) {
   /*
    * Computes the forward pass for an LSTM layer with M neurons.
    * The input data has N sequences of T examples, each with D features.
@@ -58,14 +58,15 @@ forward = function(matrix[double] X, matrix[double] W, 
matrix[double] b,
    *      of shape (N, T*M).  Else, outputs for the final timestep, of
    *      shape (N, M).
    *  - c: Cell state for final timestep, of shape (N, M). 
+   *  - state: Intermediate state of unknown dimensions used for performance.
    */
-  out = 0; c = c0;
-  [out, c] = lstm(X, W, b, out0, c0, return_sequences)
+  out = 0; c = c0; state = c0;
+  [out, c, state] = lstm(X, W, b, out0, c0, return_sequences)
 }
 
 backward = function(matrix[double] dout, matrix[double] dc,
                     matrix[double] X, matrix[double] W, matrix[double] b,
-                    boolean given_sequences, matrix[double] out0, matrix[double] c0)
+                    boolean given_sequences, matrix[double] out0, matrix[double] c0, matrix[double] state)
     return (matrix[double] dX, matrix[double] dW, matrix[double] db,
             matrix[double] dout0, matrix[double] dc0) {
   /*
@@ -92,6 +93,7 @@ backward = function(matrix[double] dout, matrix[double] dc,
    *      Note: This is *optional* and could just be an empty matrix.
    *  - c0: Initial cell state, of shape (N, M).
    *      Note: This is *optional* and could just be an empty matrix.
+   *  - state: State generated by the forward call.
    *
    * Outputs:
    *  - dX: Gradient wrt `X`, of shape (N, T*D).
@@ -101,7 +103,7 @@ backward = function(matrix[double] dout, matrix[double] dc,
    *  - dc0: Gradient wrt `c0`, of shape (N, M).
    */
   dX = X; dW = W; db = b; dout0 = out0; dc0 = c0
-  [dX, dW, db, dout0, dc0] = lstm_backward(X, W, b, out0, c0, given_sequences, dout, dc)
+  [dX, dW, db, dout0, dc0] = lstm_backward(X, W, b, out0, c0, given_sequences, dout, dc, state)
 }
 
 init = function(int N, int D, int M)
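
A corresponding usage sketch for the staging layer above, assuming the script is reachable on the nn source path as `nn/layers/lstm_staging.dml` (sizes and zero-valued placeholders are illustrative only):

    source("nn/layers/lstm_staging.dml") as lstm_staging

    N = 16; T = 8; D = 32; M = 64
    X    = matrix(0, rows=N, cols=T*D)
    W    = matrix(0, rows=D+M, cols=4*M)
    b    = matrix(0, rows=1, cols=4*M)
    out0 = matrix(0, rows=N, cols=M)
    c0   = matrix(0, rows=N, cols=M)

    # forward now also yields the state alongside out and c
    [out, c, state] = lstm_staging::forward(X, W, b, TRUE, out0, c0)

    # backward consumes that state as its last argument
    dout = matrix(0, rows=N, cols=T*M)
    dc   = matrix(0, rows=N, cols=M)
    [dX, dW, db, dout0, dc0] = lstm_staging::backward(dout, dc, X, W, b, TRUE, out0, c0, state)
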
diff --git a/src/main/java/org/apache/sysml/conf/DMLConfig.java 
b/src/main/java/org/apache/sysml/conf/DMLConfig.java
index 8459fd4..0b5ed78 100644
--- a/src/main/java/org/apache/sysml/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysml/conf/DMLConfig.java
@@ -88,6 +88,7 @@ public class DMLConfig
        public static final String SYNCHRONIZE_GPU      = 
"sysml.gpu.sync.postProcess"; // boolean: whether to synchronize GPUs after 
every instruction 
        public static final String EAGER_CUDA_FREE              = 
"sysml.gpu.eager.cudaFree"; // boolean: whether to perform eager CUDA free on 
rmvar
        public static final String GPU_EVICTION_POLICY  = 
"sysml.gpu.eviction.policy"; // string: can be lru, lfu, min_evict
+       public static final String FORCE_LSTM_CUDNN             = 
"sysml.gpu.lstm.force.cudnn"; // boolean: should we force a cudnn operator for 
LSTM
        
        // Fraction of available memory to use. The available memory is 
computer when the GPUContext is created
        // to handle the tradeoff on calling cudaMemGetInfo too often.
@@ -148,6 +149,7 @@ public class DMLConfig
                _defaultVals.put(SYNCHRONIZE_GPU,        "false" );
                _defaultVals.put(CACHING_BUFFER_SIZE,    "0.15" );
                _defaultVals.put(EAGER_CUDA_FREE,        "false" );
+               _defaultVals.put(FORCE_LSTM_CUDNN,               "true" );
                _defaultVals.put(GPU_RECOMPUTE_ACTIVATIONS, "false" );
                _defaultVals.put(FLOATING_POINT_PRECISION,               
"double" );
        }
@@ -432,7 +434,7 @@ public class DMLConfig
                                CODEGEN, CODEGEN_COMPILER, CODEGEN_OPTIMIZER, 
CODEGEN_PLANCACHE, CODEGEN_LITERALS,
                                EXTRA_FINEGRAINED_STATS, STATS_MAX_WRAP_LEN, 
PRINT_GPU_MEMORY_INFO, CACHING_BUFFER_SIZE,
                                AVAILABLE_GPUS, SYNCHRONIZE_GPU, 
EAGER_CUDA_FREE, FLOATING_POINT_PRECISION, GPU_EVICTION_POLICY, 
EVICTION_SHADOW_BUFFERSIZE,
-                               GPU_MEMORY_ALLOCATOR, 
GPU_MEMORY_UTILIZATION_FACTOR, GPU_RECOMPUTE_ACTIVATIONS
+                               GPU_MEMORY_ALLOCATOR, 
GPU_MEMORY_UTILIZATION_FACTOR, GPU_RECOMPUTE_ACTIVATIONS, FORCE_LSTM_CUDNN
                }; 
                
                StringBuilder sb = new StringBuilder();
diff --git 
a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index f27958f..325107c 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -219,12 +219,13 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        checkMatrixParam(getFifthExpr());
                        
                        // setup output properties
-                       if(getOutputs() == null || getOutputs().length != 2) {
+                       if(getOutputs() == null || getOutputs().length != 3) {
                                int numOutputs = getOutputs() == null ? 0 : 
getOutputs().length;
-                               raiseValidateError("The builtin function lstm 
has two outputs, but instead found: " + numOutputs, conditional);
+                               raiseValidateError("The builtin function lstm 
has three outputs, but instead found: " + numOutputs, conditional);
                        }
                        DataIdentifier out = (DataIdentifier) getOutputs()[0];
                        DataIdentifier cy = (DataIdentifier) getOutputs()[1];
+                       DataIdentifier cache = (DataIdentifier) getOutputs()[2];
                        
                        // Output1 - out: If `return_sequences` is True, 
outputs for all timesteps, else outputs for the final timestep.
                        out.setDataType(DataType.MATRIX);
@@ -238,12 +239,17 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        cy.setDimensions(getExpr(4).getOutput().getDim1(), 
getExpr(4).getOutput().getDim2());
                        
cy.setBlockDimensions(getExpr(4).getOutput().getRowsInBlock(), 
getExpr(4).getOutput().getColumnsInBlock());
                        
+                       cache.setDataType(DataType.MATRIX);
+                       cache.setValueType(ValueType.DOUBLE);
+                       cache.setDimensions(1, 1); // Use dummy dimension for 
now. 
+                       
cache.setBlockDimensions(getFirstExpr().getOutput().getRowsInBlock(), 
getFirstExpr().getOutput().getColumnsInBlock());
+                       
                        break;
                }
                case LSTM_BACKWARD:
                {
-                       // Input: X, W, b, out0, c0, return_sequences, dout, cy
-                       checkNumParameters(8);
+                       // Input: X, W, b, out0, c0, return_sequences, dout, 
cy, cache
+                       checkNumParameters(9);
                        checkMatrixParam(getFirstExpr());
                        checkMatrixParam(getSecondExpr());
                        checkMatrixParam(getThirdExpr());
diff --git a/src/main/java/org/apache/sysml/parser/StatementBlock.java 
b/src/main/java/org/apache/sysml/parser/StatementBlock.java
index 3988a7f..fdd2025 100644
--- a/src/main/java/org/apache/sysml/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysml/parser/StatementBlock.java
@@ -976,12 +976,17 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                        throw new LanguageException("Unexpected error.");
                
                if ( source instanceof FunctionCallIdentifier ) {
+                       // set target properties (based on type info in 
function call statement return params)
+                       FunctionCallIdentifier fci = 
(FunctionCallIdentifier)source;
+                       FunctionStatement fstmt = (FunctionStatement)_dmlProg
+                               .getFunctionStatementBlock(fci.getNamespace(), 
fci.getName()).getStatement(0);
+                       if(targetList.size() != fstmt.getOutputParams().size()) 
{
+                               // throws a controlled error if the builtin 
functions are used incorrectly
+                               fci.raiseValidateError("Incorrect number of 
outputs for the function " + fci.getNamespace() + "::" +  fci.getName() 
+                                       + ":" + targetList.size() + " != " + 
fstmt.getOutputParams().size(), conditional);
+                       }
                        for (int j =0; j< targetList.size(); j++) {
                                DataIdentifier target = targetList.get(j);
-                               // set target properties (based on type info in 
function call statement return params)
-                               FunctionCallIdentifier fci = 
(FunctionCallIdentifier)source;
-                               FunctionStatement fstmt = 
(FunctionStatement)_dmlProg
-                                       
.getFunctionStatementBlock(fci.getNamespace(), fci.getName()).getStatement(0);
                                if (fstmt == null){
                                        fci.raiseValidateError(" function " + 
fci.getName() 
                                                + " is undefined in namespace " 
+ fci.getNamespace(), conditional);
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
index f0a44f7..2fa1274 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
@@ -190,6 +190,16 @@ public abstract class CacheableData<T extends CacheBlock> 
extends Data
        private boolean _requiresLocalWrite = false; //flag if local write for 
read obj
        private boolean _isAcquireFromEmpty = false; //flag if read from status 
empty 
        
+       // If the cacheable data is an intermediate cache, then this value is set to identify the type of operator that created this cache.
+       // This avoids unnecessary GPU stalling as well as supports hybrid forward/backward calls.
+       private int     _intermediateCacheType = -1;
+       public void setIntermediateCacheType(int newValue) {
+               _intermediateCacheType = newValue;
+       }
+       public int getIntermediateCacheType() {
+               return _intermediateCacheType;
+       }
+       
        //spark-specific handles
        //note: we use the abstraction of LineageObjects for two reasons: (1) 
to keep track of cleanup
        //for lazily evaluated RDDs, and (2) as abstraction for environments 
that do not necessarily have spark libraries available
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
index 35ac5b6..9167e8c 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
@@ -39,6 +39,25 @@ public class DnnCPInstruction extends UnaryCPInstruction {
        private static final Log LOG = 
LogFactory.getLog(DnnCPInstruction.class.getName());
        private static boolean warnedUnderUtilitization = false;
        
+       public static enum LSTM_CACHE_TYPE {
+               CP_NN,
+               GPU_CUDNN,
+               GPU_NN;
+               
+               public static LSTM_CACHE_TYPE fromInteger(int x) {
+                       switch(x) {
+                               case 0:
+                                       return CP_NN;
+                               case 1:
+                                       return GPU_CUDNN;
+                               case 2:
+                                       return GPU_NN;
+                               default:
+                                       throw new 
DMLRuntimeException("Unsupported value:" + x);
+                       }
+               }
+       }
+       
        private final CPOperand _in2;
        private final CPOperand _in3;
        private final CPOperand _in4;
@@ -46,6 +65,7 @@ public class DnnCPInstruction extends UnaryCPInstruction {
        private final CPOperand _in6;
        private final CPOperand _in7;
        private final CPOperand _in8;
+       private final CPOperand _in9;
        private final CPOperand _out2;
        private final CPOperand _out3;
        private final CPOperand _out4;
@@ -63,7 +83,7 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                super(CPType.Dnn, null, in, out, opcode, istr);
                _in2 = in2;
                _in3 = in3;
-               _in4 = null; _in5 = null; _in6 = null; _in7 = null; _in8 = null;
+               _in4 = null; _in5 = null; _in6 = null; _in7 = null; _in8 = 
null; _in9 = null;
                _out2 = null; _out3 = null; _out4 = null; _out5 = null;
                _stride = stride;
                _padding = padding;
@@ -112,6 +132,32 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                _in6 = in6;
                _in7 = in7;
                _in8 = in8;
+               _in9 = null;
+               _out2 = out2;
+               _out3 = out3;
+               _out4 = out4;
+               _out5 = out5;
+               _stride = null;
+               _padding = null;
+               _input_shape = null;
+               _filter_shape = null;
+               _numThreads = numThreads;
+               _intermediateMemoryBudget = intermediateMemoryBudget;
+       }
+       
+       public DnnCPInstruction(CPOperand in1, CPOperand in2, CPOperand in3, 
CPOperand in4, CPOperand in5,
+                       CPOperand in6, CPOperand in7, CPOperand in8, CPOperand 
in9,
+                       CPOperand out, CPOperand out2, CPOperand out3, 
CPOperand out4, CPOperand out5, String opcode, String istr, 
+                       double intermediateMemoryBudget, int numThreads) throws 
DMLRuntimeException {
+               super(CPType.Dnn, null, in1, out, opcode, istr);
+               _in2 = in2;
+               _in3 = in3;
+               _in4 = in4;
+               _in5 = in5;
+               _in6 = in6;
+               _in7 = in7;
+               _in8 = in8;
+               _in9 = in9;
                _out2 = out2;
                _out3 = out3;
                _out4 = out4;
@@ -262,7 +308,7 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                        return new DnnCPInstruction(in1, in2, in3, in4, in5, 
in6, null, null, out, out2, out3, null, null, opcode, str, 0, 0);
                }
                else if (opcode.equalsIgnoreCase("lstm")) {
-                       InstructionUtils.checkNumFields(parts, 9);
+                       InstructionUtils.checkNumFields(parts, 10);
                        CPOperand in1 = new CPOperand(parts[1]); // X
                        CPOperand in2 = new CPOperand(parts[2]); // W
                        CPOperand in3 = new CPOperand(parts[3]); // b
@@ -271,11 +317,12 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                        CPOperand in6 = new CPOperand(parts[6]); // return_seq
                        CPOperand out = new CPOperand(parts[7]);  // out
                        CPOperand out2 = new CPOperand(parts[8]); // c
-                       int numThreads = Integer.parseInt(parts[9]);
-                       return new DnnCPInstruction(in1, in2, in3, in4, in5, 
in6, null, null, out, out2, null, null, null, opcode, str, 0, numThreads);
+                       CPOperand out3 = new CPOperand(parts[9]); // cache
+                       int numThreads = Integer.parseInt(parts[10]);
+                       return new DnnCPInstruction(in1, in2, in3, in4, in5, 
in6, null, null, out, out2, out3, null, null, opcode, str, 0, numThreads);
                }
                else if (opcode.equalsIgnoreCase("lstm_backward")) {
-                       InstructionUtils.checkNumFields(parts, 14);
+                       InstructionUtils.checkNumFields(parts, 15);
                        CPOperand in1 = new CPOperand(parts[1]); // X
                        CPOperand in2 = new CPOperand(parts[2]); // W
                        CPOperand in3 = new CPOperand(parts[3]); // b
@@ -284,13 +331,14 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                        CPOperand in6 = new CPOperand(parts[6]); // return_seq
                        CPOperand in7 = new CPOperand(parts[7]); // dout
                        CPOperand in8 = new CPOperand(parts[8]); // dc
-                       CPOperand out = new CPOperand(parts[9]);  // dX
-                       CPOperand out2 = new CPOperand(parts[10]); // dW
-                       CPOperand out3 = new CPOperand(parts[11]); // db
-                       CPOperand out4 = new CPOperand(parts[12]); // dout0
-                       CPOperand out5 = new CPOperand(parts[13]); // dc0
-                       int numThreads = Integer.parseInt(parts[14]);
-                       return new DnnCPInstruction(in1, in2, in3, in4, in5, 
in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0, numThreads);
+                       CPOperand in9 = new CPOperand(parts[9]); // cache
+                       CPOperand out = new CPOperand(parts[10]);  // dX
+                       CPOperand out2 = new CPOperand(parts[11]); // dW
+                       CPOperand out3 = new CPOperand(parts[12]); // db
+                       CPOperand out4 = new CPOperand(parts[13]); // dout0
+                       CPOperand out5 = new CPOperand(parts[14]); // dc0
+                       int numThreads = Integer.parseInt(parts[15]);
+                       return new DnnCPInstruction(in1, in2, in3, in4, in5, 
in6, in7, in8, in9, out, out2, out3, out4, out5, opcode, str, 0, numThreads);
                }
                else {
                        throw new DMLRuntimeException("Unknown opcode while 
parsing a DnnCPInstruction: " + str);
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index fbe7c9d..b243b64 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -22,16 +22,22 @@ import java.util.ArrayList;
 import jcuda.Pointer;
 import jcuda.jcudnn.JCudnn;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.conf.DMLConfig;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.functionobjects.SwapIndex;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.DnnCPInstruction;
 import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNN;
+import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNNRnnAlgorithm;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
 import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysml.runtime.util.DnnUtils;
@@ -45,7 +51,16 @@ public class DnnGPUInstruction extends GPUInstruction {
                NONE
        }
        
-       public static LstmOperator FORCED_LSTM_OP = LstmOperator.NONE;
+       public static LstmOperator FORCED_LSTM_OP; 
+       static {
+               
if(ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.FORCE_LSTM_CUDNN))
 {
+                       FORCED_LSTM_OP = LstmOperator.CUDNN;
+               }
+               else {
+                       FORCED_LSTM_OP = LstmOperator.NONE;
+               }
+       }
+       private static final Log LOG = 
LogFactory.getLog(DnnGPUInstruction.class.getName());
        
        private CPOperand _input1;
        private CPOperand _input2;
@@ -55,6 +70,7 @@ public class DnnGPUInstruction extends GPUInstruction {
        private CPOperand _input6;
        private CPOperand _input7;
        private CPOperand _input8;
+       private CPOperand _input9;
        private CPOperand _output;
        private CPOperand _output2;
        private CPOperand _output3;
@@ -97,6 +113,23 @@ public class DnnGPUInstruction extends GPUInstruction {
                _intermediateMemoryBudget = intermediateMemoryBudget;
        }
        
+       public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, 
CPOperand in4, CPOperand in5, CPOperand in6, 
+                       CPOperand out, CPOperand out2, CPOperand out3, String 
opcode, String istr, 
+                       double intermediateMemoryBudget) throws 
DMLRuntimeException {
+               super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), 
opcode, istr);
+               _input1 = in1;
+               _input2 = in2;
+               _input3 = in3;
+               _input4 = in4;
+               _input5 = in5;
+               _input6 = in6;
+               _gputype = GPUINSTRUCTION_TYPE.Dnn;
+               _output = out;
+               _output2 = out2;
+               _output3 = out3;
+               _intermediateMemoryBudget = intermediateMemoryBudget;
+       }
+       
        public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, 
CPOperand in4, CPOperand in5,
                        CPOperand in6, CPOperand in7, CPOperand in8,
                        CPOperand out, CPOperand out2, CPOperand out3, 
CPOperand out4, CPOperand out5, String opcode, String istr, 
@@ -119,6 +152,29 @@ public class DnnGPUInstruction extends GPUInstruction {
                _intermediateMemoryBudget = intermediateMemoryBudget;
        }
        
+       public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, 
CPOperand in4, CPOperand in5,
+                       CPOperand in6, CPOperand in7, CPOperand in8, CPOperand 
in9,
+                       CPOperand out, CPOperand out2, CPOperand out3, 
CPOperand out4, CPOperand out5, String opcode, String istr, 
+                       double intermediateMemoryBudget) throws 
DMLRuntimeException {
+               super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), 
opcode, istr);
+               _input1 = in1;
+               _input2 = in2;
+               _input3 = in3;
+               _input4 = in4;
+               _input5 = in5;
+               _input6 = in6;
+               _input7 = in7;
+               _input8 = in8;
+               _input9 = in9;
+               _gputype = GPUINSTRUCTION_TYPE.Dnn;
+               _output = out;
+               _output2 = out2;
+               _output3 = out3;
+               _output4 = out4;
+               _output5 = out5;
+               _intermediateMemoryBudget = intermediateMemoryBudget;
+       }
+       
        public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, 
CPOperand out, String opcode, String istr, 
                        double intermediateMemoryBudget) throws 
DMLRuntimeException {
                super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), 
opcode, istr);
@@ -360,7 +416,7 @@ public class DnnGPUInstruction extends GPUInstruction {
                        return new DnnGPUInstruction(in, in2, in3, in4, out, 
opcode, str, 0);
                }
                else if (opcode.equalsIgnoreCase("lstm")) {
-                       InstructionUtils.checkNumFields(parts, 8);
+                       InstructionUtils.checkNumFields(parts, 9);
                        CPOperand in1 = new CPOperand(parts[1]); // X
                        CPOperand in2 = new CPOperand(parts[2]); // W
                        CPOperand in3 = new CPOperand(parts[3]); // b
@@ -369,10 +425,11 @@ public class DnnGPUInstruction extends GPUInstruction {
                        CPOperand in6 = new CPOperand(parts[6]); // return_seq
                        CPOperand out = new CPOperand(parts[7]); // out
                        CPOperand out2 = new CPOperand(parts[8]); // c
-                       return new DnnGPUInstruction(in1, in2, in3, in4, in5, 
in6, out, out2, opcode, str, 0);
+                       CPOperand out3 = new CPOperand(parts[9]); // cache
+                       return new DnnGPUInstruction(in1, in2, in3, in4, in5, 
in6, out, out2, out3, opcode, str, 0);
                }
                else if (opcode.equalsIgnoreCase("lstm_backward")) {
-                       InstructionUtils.checkNumFields(parts, 13);
+                       InstructionUtils.checkNumFields(parts, 14);
                        CPOperand in1 = new CPOperand(parts[1]); // X
                        CPOperand in2 = new CPOperand(parts[2]); // W
                        CPOperand in3 = new CPOperand(parts[3]); // b
@@ -381,12 +438,13 @@ public class DnnGPUInstruction extends GPUInstruction {
                        CPOperand in6 = new CPOperand(parts[6]); // return_seq
                        CPOperand in7 = new CPOperand(parts[7]); // dout
                        CPOperand in8 = new CPOperand(parts[8]); // dc
-                       CPOperand out = new CPOperand(parts[9]);  // dX
-                       CPOperand out2 = new CPOperand(parts[10]); // dW
-                       CPOperand out3 = new CPOperand(parts[11]); // db
-                       CPOperand out4 = new CPOperand(parts[12]); // dout0
-                       CPOperand out5 = new CPOperand(parts[13]); // dc0
-                       return new DnnGPUInstruction(in1, in2, in3, in4, in5, 
in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
+                       CPOperand cache = new CPOperand(parts[9]); // cache
+                       CPOperand out = new CPOperand(parts[10]);  // dX
+                       CPOperand out2 = new CPOperand(parts[11]); // dW
+                       CPOperand out3 = new CPOperand(parts[12]); // db
+                       CPOperand out4 = new CPOperand(parts[13]); // dout0
+                       CPOperand out5 = new CPOperand(parts[14]); // dc0
+                       return new DnnGPUInstruction(in1, in2, in3, in4, in5, 
in6, in7, in8, cache, out, out2, out3, out4, out5, opcode, str, 0);
                }
                else if (opcode.equalsIgnoreCase("batch_norm2d_test")) {
                        InstructionUtils.checkNumFields(parts, 7);
@@ -661,6 +719,25 @@ public class DnnGPUInstruction extends GPUInstruction {
                return (long)memRequired;
        }
        
+       private int getNumRowsLSTMTempCache(LibMatrixCuDNNRnnAlgorithm algo, 
long N, long T, long D, long M) {
+               return  toInt(
+                               // reserve space size
+                               ((long)Math.ceil( 
((double)algo.reserveSpaceSizeInBytes) / LibMatrixCUDA.sizeOfDataType )) + 
+                               // cudnnW
+                               (D+M+2)*(4*M) + 
+                               // cudnnInput
+                               (N*T*D));
+               
+       }
+       
+       private Pointer getCudnnWPointer(Pointer cachePointer, 
LibMatrixCuDNNRnnAlgorithm algo) {
+               return 
cachePointer.withByteOffset(algo.reserveSpaceSizeInBytes);
+       }
+       
+       private Pointer getCudnnInputPointer(Pointer cachePointer, 
LibMatrixCuDNNRnnAlgorithm algo, long N, long T, long D, long M) {
+               return cachePointer.withByteOffset(algo.reserveSpaceSizeInBytes 
+ ((D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType));
+       }
+       
        private void processLstmBackwardInstruction(ExecutionContext ec) throws 
DMLRuntimeException {
                MatrixObject out0 = getMatrixInputForGPUInstruction(ec, 
_input4.getName());
                long M = out0.getNumColumns(); // hiddenSize .. since out0: (N, 
M)
@@ -676,8 +753,6 @@ public class DnnGPUInstruction extends GPUInstruction {
                long numColsX = X.getNumColumns();
                int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength
                boolean return_sequences = ec.getScalarInput(_input6.getName(), 
_input6.getValueType(), _input6.isLiteral()).getBooleanValue();
-               
-               // long memRequired = getMemRequiredForCuDNNLSTMBackward(N, T, 
M, D, return_sequences);
                 
                String dxName = _output.getName();
                String dwName = _output2.getName();
@@ -689,37 +764,86 @@ public class DnnGPUInstruction extends GPUInstruction {
                
                long memRequired = getMemRequiredForCuDNNLSTMBackward(N, T, M, 
D, return_sequences);
                
-               boolean isWSparse = LibMatrixCUDA.isInSparseFormat(gCtx, W);
-               
-               
                
                if(FORCED_LSTM_OP == LstmOperator.CUDNN || 
                        N != N1 || // Use CuDNN operator when batch size of 
previous iteration is different that current iteration
-                       (!isWSparse && // Don't use CuDNN kernel when w is 
sparse.
+                       (
+                       // 
----------------------------------------------------------------------------------
+                       // Skip sparse check
+                       // !LibMatrixCUDA.isInSparseFormat(gCtx, W) && // Don't 
use CuDNN kernel when w is sparse.
+                       // 
----------------------------------------------------------------------------------
                        // When an operator is not forced, then prefer CuDNN 
kernel if it can fit in the GPU memory
-                       FORCED_LSTM_OP == LstmOperator.NONE && 
gCtx.getMemoryManager().canAllocate(instName, memRequired))) {
+                       FORCED_LSTM_OP == LstmOperator.NONE && 
gCtx.getMemoryManager().canAllocate(instName, 
+                                       memRequired - getSizeOnDevice(new 
MatrixObject[] {out0, W, bias, X})))) {
                        // Use CuDNN LSTM kernel
-                       Pointer sysmlWPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
-                       Pointer sysmlBiasPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
-                       Pointer cudnnWPointer = gCtx.allocate(instName, 
(D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
-                       
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
-                                       
ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
-                                       sysmlWPointer, sysmlBiasPointer, 
cudnnWPointer, D, M);
-                       
ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-                       
ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-                       Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, 
X, instName); 
-                       Pointer cudnnInput = gCtx.allocate(instName, 
(N*T*D)*LibMatrixCUDA.sizeOfDataType);
-                       
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
-                                       
ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
-                                       xPointer, cudnnInput, N, D, T*D, N*T*D);
-                       
ec.releaseMatrixInputForGPUInstruction(_input1.getName());
                        Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, 
getMatrixInputForGPUInstruction(ec, _input5.getName()), instName);
-                       LibMatrixCuDNN.cuDNNLstmBackward(ec, gCtx, instName, 
-                                       cudnnInput, out0Pointer, c0Pointer, 
cudnnWPointer, doutName, dcyName,  // input
-                                       dxName, dwName, dbName, dhxName, 
dcxName, // output 
-                                       return_sequences, N, M, D, T);
-                       gCtx.cudaFreeHelper(instName, cudnnWPointer, 
gCtx.EAGER_CUDA_FREE);
-                       gCtx.cudaFreeHelper(instName, cudnnInput, 
gCtx.EAGER_CUDA_FREE);
+                       try(LibMatrixCuDNNRnnAlgorithm algo = 
+                                       new LibMatrixCuDNNRnnAlgorithm(ec, 
gCtx, instName, "lstm", toInt(N), toInt(T), toInt(M), toInt(D), true)) {
+                               Pointer cachePtr = null;
+                               try {
+                                       
switch(DnnCPInstruction.LSTM_CACHE_TYPE.fromInteger(ec.getMatrixObject(_input9).getIntermediateCacheType()))
 {
+                                               case GPU_CUDNN:
+                                                       cachePtr = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, 
+                                                                       
getMatrixInputForGPUInstruction(ec, _input9.getName()), instName, 
+                                                                       
getNumRowsLSTMTempCache(algo, N, T, D, M), 1);
+                                                       break;
+                                               case CP_NN:
+                                                       LOG.warn("Invoking 
CuDNN lstm backward operator, but the intermediate state was generated by CP 
lstm nn operator");
+                                                       break;
+                                               case GPU_NN:
+                                                       LOG.warn("Invoking 
CuDNN lstm backward operator, but the intermediate state was generated by GPU 
lstm nn operator");
+                                                       break;
+                                               default:
+                                                       LOG.warn("Invoking 
CuDNN lstm forward redundantly in the backward operator. Found unknown cache 
type.");
+                                                       break;
+                                       }
+                               }
+                               catch(DMLRuntimeException e) {
+                                       LOG.warn("Invoking CuDNN lstm forward 
redundantly in the backward operator");
+                               }
+                               if (algo.reserveSpaceSizeInBytes != 0) {
+                                       algo.reserveSpace = cachePtr;
+                               }
+                               else {
+                                       algo.reserveSpace = new Pointer();
+                               }
+                               
+                               Pointer cudnnWPointer = null;
+                               Pointer cudnnInput = null;
+                               if(cachePtr != null) {
+                                       cudnnWPointer = 
getCudnnWPointer(cachePtr, algo);
+                                       cudnnInput = 
getCudnnInputPointer(cachePtr, algo, N, T, D, M);
+                               }
+                               else {
+                                       Pointer sysmlWPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
+                                       Pointer sysmlBiasPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
+                                       cudnnWPointer = gCtx.allocate(instName, 
(D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
+                                       
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
+                                                       
ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
+                                                       sysmlWPointer, 
sysmlBiasPointer, cudnnWPointer, D, M);
+                                       
ec.releaseMatrixInputForGPUInstruction(_input2.getName());
+                                       
ec.releaseMatrixInputForGPUInstruction(_input3.getName());
+                                       Pointer xPointer = 
LibMatrixCUDA.getDensePointer(gCtx, X, instName); 
+                                       cudnnInput = gCtx.allocate(instName, 
(N*T*D)*LibMatrixCUDA.sizeOfDataType);
+                                       
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
+                                                       
ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
+                                                       xPointer, cudnnInput, 
N, D, T*D, N*T*D);
+                                       
ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+                               }
+                               
+                               LibMatrixCuDNN.cuDNNLstmBackward(ec, gCtx, 
instName, 
+                                               cudnnInput, out0Pointer, 
c0Pointer, cudnnWPointer, doutName, dcyName,  // input
+                                               dxName, dwName, dbName, 
dhxName, dcxName, // output 
+                                               return_sequences, N, M, D, T, 
algo);
+                               if(cachePtr != null) {
+                                       
ec.releaseMatrixInputForGPUInstruction(_input9.getName());
+                               }
+                               else {
+                                       gCtx.cudaFreeHelper(instName, 
cudnnWPointer, gCtx.EAGER_CUDA_FREE);
+                                       gCtx.cudaFreeHelper(instName, 
cudnnInput, gCtx.EAGER_CUDA_FREE);
+                               }
+                       }
+                       
                }
                else {
                        if(N != N1) {
@@ -727,6 +851,8 @@ public class DnnGPUInstruction extends GPUInstruction {
                                                " is different than the batch 
size of current iteration " + N);
                        }
                        
+                       LOG.info("Switching to gpu lstm nn backward operator. 
(CuDNN memory requirement=" + String.format("%.3f", memRequired*1e-6) + " MB.");
+                       
                        Pointer sysmlBiasPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
                        Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, 
X, instName); 
                        Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, 
getMatrixInputForGPUInstruction(ec, _input5.getName()), instName);
@@ -781,6 +907,14 @@ public class DnnGPUInstruction extends GPUInstruction {
                ec.releaseMatrixInputForGPUInstruction(_input5.getName());
        }
        
+       private long getSizeOnDevice(MatrixObject[] mObjects) {
+               long ret = 0;
+               for(MatrixObject mo : mObjects) {
+                       ret += mo.getGPUObject(gCtx).getSizeOnDevice();
+               }
+               return ret;
+       }
+       
        private void processLstmInstruction(ExecutionContext ec) throws 
DMLRuntimeException {
                // batchSize=N, seqLength=T, numFeatures=D and hiddenSize=M
                // input  X:(N, T*D),   ==> (T, D, N)
@@ -801,42 +935,62 @@ public class DnnGPUInstruction extends GPUInstruction {
                long numColsX = X.getNumColumns();
                long T = numColsX/D; // since X:(N, T*D) ... seqLength
                boolean return_sequences = ec.getScalarInput(_input6.getName(), 
_input6.getValueType(), _input6.isLiteral()).getBooleanValue();
-               
+                               
                long memRequired = getMemRequiredForCuDNNLSTMBackward(N, T, M, 
D, return_sequences);
                
-               boolean isWSparse = LibMatrixCUDA.isInSparseFormat(gCtx, W);
-               
                if(FORCED_LSTM_OP == LstmOperator.CUDNN || 
                        N != N1 || // Use CuDNN operator when batch size of 
previous iteration is different than current iteration
-                       (!isWSparse && // Don't use CuDNN kernel when w is 
sparse.
+                       (
+                       // 
----------------------------------------------------------------------------------
+                       // Skip sparse check
+                       // !LibMatrixCUDA.isInSparseFormat(gCtx, W) && // Don't 
use CuDNN kernel when w is sparse.
+                       // 
----------------------------------------------------------------------------------
                        // When an operator is not forced, then prefer CuDNN 
kernel if it can fit in the GPU memory
-                       FORCED_LSTM_OP == LstmOperator.NONE && 
gCtx.getMemoryManager().canAllocate(instName, memRequired))) {
-                       Pointer sysmlWPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
-                       Pointer sysmlBiasPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
-                       Pointer cudnnWPointer = gCtx.allocate(instName, 
(D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
-                       
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
-                                       
ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
-                                       sysmlWPointer, sysmlBiasPointer, 
cudnnWPointer, toInt(D), toInt(M));
-                       
ec.releaseMatrixInputForGPUInstruction(_input2.getName()); // W
-                       
ec.releaseMatrixInputForGPUInstruction(_input3.getName()); // bias
-                       // Beause the matrices are released immediately, the 
output for transpose need not be taken into account
-                       Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, 
X, instName); 
-                       Pointer cudnnInput = gCtx.allocate(instName, 
(N*T*D)*LibMatrixCUDA.sizeOfDataType);
-                       
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
-                                       
ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
-                                       xPointer, cudnnInput, toInt(N), 
toInt(D), toInt(T*D), toInt(N*T*D));
-                       
ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-                       Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, 
getMatrixInputForGPUInstruction(ec, _input5.getName()), instName); 
-                       LibMatrixCuDNN.cuDNNLstm(ec, gCtx, instName, 
cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, 
_output.getName(), _output2.getName(), 
-                                       toInt(N), toInt(M), toInt(D), toInt(T));
-                       gCtx.cudaFreeHelper(instName, cudnnWPointer, 
gCtx.EAGER_CUDA_FREE);
-                       gCtx.cudaFreeHelper(instName, cudnnInput, 
gCtx.EAGER_CUDA_FREE);
+                       FORCED_LSTM_OP == LstmOperator.NONE && 
gCtx.getMemoryManager().canAllocate(instName, 
+                                       memRequired - getSizeOnDevice(new 
MatrixObject[] {out0, W, bias, X})))) {
+                       Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, 
getMatrixInputForGPUInstruction(ec, _input5.getName()), instName);
+                       try(LibMatrixCuDNNRnnAlgorithm algo = new 
LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, "lstm", toInt(N), toInt(T), 
toInt(M), toInt(D), true)) {
+                               int numRows = getNumRowsLSTMTempCache(algo, N, 
T, D, M);
+                               ec.setMetaData(_output3.getName(), numRows, 1);
+                               Pointer cachePtr = 
LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName,  _output3.getName(), 
numRows, 1);
+                               if(algo.reserveSpaceSizeInBytes != 0) {
+                                       algo.reserveSpace = cachePtr;
+                               }
+                               else {
+                                       algo.reserveSpace = new Pointer();
+                               }
+                               
+                               // Compute cudnnWPointer
+                               Pointer sysmlWPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
+                               Pointer sysmlBiasPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
+                               Pointer cudnnWPointer = 
getCudnnWPointer(cachePtr, algo); 
+                               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
+                                               
ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
+                                               sysmlWPointer, 
sysmlBiasPointer, cudnnWPointer, toInt(D), toInt(M));
+                               
ec.releaseMatrixInputForGPUInstruction(_input2.getName()); // W
+                               
ec.releaseMatrixInputForGPUInstruction(_input3.getName()); // bias
+                               
+                               // Compute cudnnInput
+                               // Because the matrices are released 
immediately, the output for transpose need not be taken into account
+                               Pointer xPointer = 
LibMatrixCUDA.getDensePointer(gCtx, X, instName); 
+                               Pointer cudnnInput = 
getCudnnInputPointer(cachePtr, algo, N, T, D, M);
+                               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
+                                               
ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
+                                               xPointer, cudnnInput, toInt(N), 
toInt(D), toInt(T*D), toInt(N*T*D));
+                               
ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+                               LibMatrixCuDNN.cuDNNLstm(ec, gCtx, instName, 
cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, 
_output.getName(), _output2.getName(), 
+                                               N, M, D, T, algo);
+                               
ec.getMatrixObject(_output3.getName()).setIntermediateCacheType(DnnCPInstruction.LSTM_CACHE_TYPE.GPU_CUDNN.ordinal());
+                               
ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
+                       }
+                       
                }
                else {
                        if(N != N1) {
                                throw new DMLRuntimeException("Unsupported 
operation: The batch size of previous iteration " + N1 + 
                                                " is different than the batch 
size of current iteration " + N);
                        }
+                       LOG.info("Switching to GPU lstm nn operator (CuDNN 
memory requirement=" + String.format("%.3f", memRequired*1e-6) + " MB).");
                        
                        Pointer sysmlBiasPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
                        Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, 
X, instName); 
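
For readers tracing the CuDNN-vs-fallback decision in the hunks above, the following is a minimal, self-contained sketch of the memory check only. It is a plain-Java model: LstmPathChooser, freeDeviceBytes and residentInputBytes are illustrative stand-ins for the gCtx.getMemoryManager().canAllocate(...) and getSizeOnDevice(...) calls shown in the diff, not SystemML APIs.

    // Sketch: prefer the CuDNN kernel only when the *additional* memory it needs
    // (total requirement minus inputs already resident on the device) fits.
    public class LstmPathChooser {
        enum Path { CUDNN, NN_FALLBACK }

        static Path choose(long memRequiredBytes, long[] residentInputBytes, long freeDeviceBytes) {
            long alreadyOnDevice = 0;
            for (long b : residentInputBytes)
                alreadyOnDevice += b;                  // analogous to summing getSizeOnDevice over out0, W, bias, X
            long additional = memRequiredBytes - alreadyOnDevice;
            return additional <= freeDeviceBytes ? Path.CUDNN : Path.NN_FALLBACK;
        }

        public static void main(String[] args) {
            // e.g. 800 MB required, 300 MB of inputs already on the GPU, 600 MB free
            System.out.println(choose(800_000_000L, new long[] {200_000_000L, 100_000_000L}, 600_000_000L));
        }
    }

Subtracting the already-resident inputs is what lets the CuDNN path be chosen more often than the previous check, which compared the full requirement against free memory.
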
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
index 04af229..9d263aa 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
@@ -797,7 +797,7 @@ public class GPUObject {
                setSparseMatrixCudaPointer(tmp);
        }
 
-       protected long getSizeOnDevice() {
+       public long getSizeOnDevice() {
                long rlen = mat.getNumRows();
                long clen = mat.getNumColumns();
                long nnz = mat.getNnz();
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index e496ddb..4151234 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -1032,16 +1032,15 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                                        
LibMatrixCuMatMult.denseSparseMatMult(gCtx.getCusparseHandle(), instName, 
dinput, difog_raw, wSparsePointer, param2);
                                }
                        }
-                       else
+                       else {
                                
LibMatrixCuMatMult.denseDenseMatMult(gCtx.getCublasHandle(), instName, dinput, 
difog_raw, wDensePointer, param2);
+                       }
                        
                        // db = db + colSums(difog_raw)  # shape (1, 4M)
-                       reduceCol(gCtx, instName, "reduce_col_sum", difog_raw, 
tmpDb, 1, toInt(4*M));
+                       reduceCol(gCtx, instName, "reduce_col_sum", difog_raw, 
tmpDb, toInt(N), toInt(4*M));
                        matrixMatrixOp(gCtx, instName, tmpDb, db, 1, 
toInt(4*M), VectorShape.NONE.code(), VectorShape.NONE.code(), db, 
                                        new 
BinaryOperator(Plus.getPlusFnObject()));
                        
-                       // jcuda.runtime.JCuda.cudaDeviceSynchronize();
-                       
                        int size = toInt(Math.max(N*D, N*M));
                        
getCudaKernels(gCtx).launchKernel("postProcessNNLstmBackward",
                                        
ExecutionConfig.getConfigForSimpleVectorOperations(size),
@@ -1177,18 +1176,20 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
         * @param M hidden size
         * @param D number of features
         * @param T sequence length
+        * @param algo rnn algorithm
         * @throws DMLRuntimeException if error
         */
        public static void cuDNNLstm(ExecutionContext ec, GPUContext gCtx, 
String instName,
                        Pointer X,  Pointer wPointer, Pointer out0, Pointer c0, 
boolean return_sequences,
-                       String outputName, String cyName, int N, int M, int D, 
int T) throws DMLRuntimeException {
-               cuDNNSingleLayerUnidirectionalRNNForward(ec, gCtx, instName, X, 
out0, c0, wPointer, outputName, cyName, "lstm", return_sequences, N, M, D, T);
+                       String outputName, String cyName, long N, long M, long 
D, long T, LibMatrixCuDNNRnnAlgorithm algo) throws DMLRuntimeException {
+               cuDNNSingleLayerUnidirectionalRNNForward(ec, gCtx, instName, X, 
out0, c0, wPointer, outputName, cyName, "lstm", return_sequences, N, M, D, T, 
algo);
        }
        
        private static void 
cuDNNSingleLayerUnidirectionalRNNForward(ExecutionContext ec, GPUContext gCtx, 
String instName,
                        Pointer x, Pointer hx, Pointer cx, Pointer wPointer,  
// input
                        String outputName, String cyName,                       
                 // output
-                       String rnnMode, boolean return_sequences, int N, int M, 
int D, int T) throws DMLRuntimeException {
+                       String rnnMode, boolean return_sequences, long N, long 
M, long D, long T,
+                       LibMatrixCuDNNRnnAlgorithm algo) throws 
DMLRuntimeException {
                boolean hasCarry = rnnMode.equalsIgnoreCase("lstm");
                if(LOG.isDebugEnabled()) {
                        long memRequired = (N*T*M + 2*N*M + 
N*T*M)*sizeOfDataType;
@@ -1201,25 +1202,23 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                Pointer cyPointer = hasCarry ? getDenseOutputPointer(ec, gCtx, 
instName, cyName, N, M) : new Pointer();
                // Pointer wPointer = getDensePointerForCuDNN(gCtx, w, 
instName, D+M+2, 4*M);
                
-               try(LibMatrixCuDNNRnnAlgorithm algo = new 
LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, rnnMode, N, T, M, D, true, 
wPointer)) {
-                       JCudnn.cudnnRNNForwardTraining(gCtx.getCudnnHandle(), 
algo.rnnDesc, T, 
-                                       algo.xDesc, x, 
-                                       algo.hxDesc, hx, 
-                                       algo.cxDesc, cx, 
-                                       algo.wDesc, wPointer, 
-                                       algo.yDesc, cudnnYPointer, 
-                                       algo.hyDesc, hyPointer, 
-                                       algo.cyDesc, cyPointer, 
-                                       algo.workSpace, algo.sizeInBytes, 
-                                       algo.reserveSpace, 
algo.reserveSpaceSizeInBytes);
-               }
+               JCudnn.cudnnRNNForwardTraining(gCtx.getCudnnHandle(), 
algo.rnnDesc, toInt(T), 
+                               algo.xDesc, x, 
+                               algo.hxDesc, hx, 
+                               algo.cxDesc, cx, 
+                               algo.wDesc, wPointer, 
+                               algo.yDesc, cudnnYPointer, 
+                               algo.hyDesc, hyPointer, 
+                               algo.cyDesc, cyPointer, 
+                               algo.workSpace, algo.sizeInBytes, 
+                               algo.reserveSpace, 
algo.reserveSpaceSizeInBytes);
                
                if(return_sequences) {
                        gCtx.cudaFreeHelper(instName, hyPointer, 
gCtx.EAGER_CUDA_FREE);
                        Pointer sysmlYPointer = getDenseOutputPointer(ec, gCtx, 
instName, outputName, N, T*M);
                        
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_output",
-                                       
ExecutionConfig.getConfigForSimpleVectorOperations(N*T*M),
-                                       sysmlYPointer, cudnnYPointer, N, T, M, 
N*T*M);
+                                       
ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*M)),
+                                       sysmlYPointer, cudnnYPointer, toInt(N), 
toInt(T), toInt(M), toInt(N*T*M));
                }
                gCtx.cudaFreeHelper(instName, cudnnYPointer, 
gCtx.EAGER_CUDA_FREE);
        }
@@ -1227,7 +1226,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
        public static void cuDNNLstmBackward(ExecutionContext ec, GPUContext 
gCtx, String instName,
                        Pointer x, Pointer hx, Pointer cx, Pointer wPointer, 
String doutName, String dcyName,  // input
                        String dxName, String dwName, String dbName, String 
dhxName, String dcxName,    // output
-                       boolean return_sequences, long N, long M, long D, long 
T) throws DMLRuntimeException {
+                       boolean return_sequences, long N, long M, long D, long 
T, LibMatrixCuDNNRnnAlgorithm algo) throws DMLRuntimeException {
                
                if(LOG.isDebugEnabled()) {
                        long memRequired = (D+M)*4*M // sysmlWPointer
@@ -1252,8 +1251,11 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                                
                // Allocate intermediate pointers computed by forward
                Pointer yPointer = gCtx.allocate(instName, 
N*T*M*sizeOfDataType);
-               try(LibMatrixCuDNNRnnAlgorithm algo = new 
LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, "lstm", toInt(N), toInt(T), 
-                               toInt(M), toInt(D), true, wPointer)) {
+               
+               boolean freeReserveSpace = false;
+               if(algo.reserveSpace == null) {
+                       freeReserveSpace = true;
+                       algo.reserveSpace = gCtx.allocate(instName, 
algo.reserveSpaceSizeInBytes);
                        JCudnn.cudnnRNNForwardTraining(gCtx.getCudnnHandle(), 
algo.rnnDesc, toInt(T), 
                                        algo.xDesc, x, 
                                        algo.hxDesc, hx, 
@@ -1264,59 +1266,68 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                                        algo.cyDesc, new Pointer(), 
                                        algo.workSpace, algo.sizeInBytes, 
                                        algo.reserveSpace, 
algo.reserveSpaceSizeInBytes);
-                       
-                       Pointer cudnnDx = gCtx.allocate(instName, 
N*T*D*LibMatrixCUDA.sizeOfDataType);
-                       JCudnn.cudnnRNNBackwardData(gCtx.getCudnnHandle(), 
algo.rnnDesc, toInt(T), 
-                                       algo.yDesc, yPointer,
-                                       // ----------------------
-                                       // Additional inputs:
-                                       algo.dyDesc, dy, 
-                                       algo.dhyDesc, new Pointer(), 
-                                       algo.dcyDesc, getDenseInputPointer(ec, 
gCtx, instName, dcyName, N, M),
-                                       // ----------------------
-                                       algo.wDesc, wPointer, 
-                                       algo.hxDesc, hx,
-                                       algo.cxDesc, cx,
-                                       // ----------------------
-                                       // Output:
-                                       algo.dxDesc, cudnnDx, 
-                                       algo.dhxDesc, getDenseOutputPointer(ec, 
gCtx, instName, dhxName, N, M), 
-                                       algo.dcxDesc, getDenseOutputPointer(ec, 
gCtx, instName, dcxName, N, M),
-                                       // ----------------------
-                                       algo.workSpace, algo.sizeInBytes, 
-                                       algo.reserveSpace, 
algo.reserveSpaceSizeInBytes);
-                       gCtx.cudaFreeHelper(instName, dy, gCtx.EAGER_CUDA_FREE);
-                       ec.releaseMatrixInputForGPUInstruction(dcyName);
-                       ec.releaseMatrixOutputForGPUInstruction(dhxName);
-                       ec.releaseMatrixOutputForGPUInstruction(dcxName);
-                       
-                       Pointer smlDx = getDenseOutputPointer(ec, gCtx, 
instName, dxName, N, T*D);
-                       
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dinput",
-                                       
ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
-                                       smlDx, cudnnDx, N, D, T*D, N*T*D);
-                       ec.releaseMatrixOutputForGPUInstruction(dxName);
-                       gCtx.cudaFreeHelper(instName, cudnnDx, 
gCtx.EAGER_CUDA_FREE);
-                       
-                       // 
-------------------------------------------------------------------------------------------
-                       Pointer cudnnDwPointer = gCtx.allocate(instName, 
(D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
-                       JCudnn.cudnnRNNBackwardWeights(gCtx.getCudnnHandle(), 
algo.rnnDesc, toInt(T), 
-                                       algo.xDesc, x, 
-                                       algo.hxDesc, hx, 
-                                       algo.yDesc, yPointer, 
-                                       algo.workSpace, algo.sizeInBytes, 
-                                       algo.dwDesc, cudnnDwPointer, 
-                                       algo.reserveSpace, 
algo.reserveSpaceSizeInBytes);
-                       
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dweight",
-                                       
ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
-                                       getDenseOutputPointer(ec, gCtx, 
instName, dwName, D+M, 4*M), 
-                                       getDenseOutputPointer(ec, gCtx, 
instName, dbName, 1, 4*M), cudnnDwPointer, D, M);
-                       gCtx.cudaFreeHelper(instName, cudnnDwPointer, 
gCtx.EAGER_CUDA_FREE);
-                       ec.releaseMatrixOutputForGPUInstruction(dwName);
-                       ec.releaseMatrixOutputForGPUInstruction(dbName);
-                       // 
-------------------------------------------------------------------------------------------
-                       
-                       gCtx.cudaFreeHelper(instName, yPointer, 
gCtx.EAGER_CUDA_FREE);
                }
+               else {
+                       if(LOG.isDebugEnabled())
+                               LOG.debug("Skipping cudnnRNNForwardTraining 
call");
+               }
+               
+               Pointer cudnnDx = gCtx.allocate(instName, 
N*T*D*LibMatrixCUDA.sizeOfDataType);
+               JCudnn.cudnnRNNBackwardData(gCtx.getCudnnHandle(), 
algo.rnnDesc, toInt(T), 
+                               algo.yDesc, yPointer,
+                               // ----------------------
+                               // Additional inputs:
+                               algo.dyDesc, dy, 
+                               algo.dhyDesc, new Pointer(), 
+                               algo.dcyDesc, getDenseInputPointer(ec, gCtx, 
instName, dcyName, N, M),
+                               // ----------------------
+                               algo.wDesc, wPointer, 
+                               algo.hxDesc, hx,
+                               algo.cxDesc, cx,
+                               // ----------------------
+                               // Output:
+                               algo.dxDesc, cudnnDx, 
+                               algo.dhxDesc, getDenseOutputPointer(ec, gCtx, 
instName, dhxName, N, M), 
+                               algo.dcxDesc, getDenseOutputPointer(ec, gCtx, 
instName, dcxName, N, M),
+                               // ----------------------
+                               algo.workSpace, algo.sizeInBytes, 
+                               algo.reserveSpace, 
algo.reserveSpaceSizeInBytes);
+               gCtx.cudaFreeHelper(instName, dy, gCtx.EAGER_CUDA_FREE);
+               ec.releaseMatrixInputForGPUInstruction(dcyName);
+               ec.releaseMatrixOutputForGPUInstruction(dhxName);
+               ec.releaseMatrixOutputForGPUInstruction(dcxName);
+               
+               Pointer smlDx = getDenseOutputPointer(ec, gCtx, instName, 
dxName, N, T*D);
+               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dinput",
+                               
ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
+                               smlDx, cudnnDx, N, D, T*D, N*T*D);
+               ec.releaseMatrixOutputForGPUInstruction(dxName);
+               gCtx.cudaFreeHelper(instName, cudnnDx, gCtx.EAGER_CUDA_FREE);
+               
+               // 
-------------------------------------------------------------------------------------------
+               Pointer cudnnDwPointer = gCtx.allocate(instName, 
(D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
+               JCudnn.cudnnRNNBackwardWeights(gCtx.getCudnnHandle(), 
algo.rnnDesc, toInt(T), 
+                               algo.xDesc, x, 
+                               algo.hxDesc, hx, 
+                               algo.yDesc, yPointer, 
+                               algo.workSpace, algo.sizeInBytes, 
+                               algo.dwDesc, cudnnDwPointer, 
+                               algo.reserveSpace, 
algo.reserveSpaceSizeInBytes);
+               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dweight",
+                               
ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
+                               getDenseOutputPointer(ec, gCtx, instName, 
dwName, D+M, 4*M), 
+                               getDenseOutputPointer(ec, gCtx, instName, 
dbName, 1, 4*M), cudnnDwPointer, D, M);
+               gCtx.cudaFreeHelper(instName, cudnnDwPointer, 
gCtx.EAGER_CUDA_FREE);
+               ec.releaseMatrixOutputForGPUInstruction(dwName);
+               ec.releaseMatrixOutputForGPUInstruction(dbName);
+               // 
-------------------------------------------------------------------------------------------
+               
+               gCtx.cudaFreeHelper(instName, yPointer, gCtx.EAGER_CUDA_FREE);
+               if(freeReserveSpace) {
+                       gCtx.cudaFreeHelper(instName, algo.reserveSpace, 
gCtx.EAGER_CUDA_FREE);
+                       algo.reserveSpace = null;
+               }
+               
        }
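
The key behavioral change in cuDNNLstmBackward above is that the forward training pass is re-executed only when no reserve space was handed in from the forward call. Below is a small, self-contained sketch of that ownership rule; ReserveSpaceReuse, runForward and runBackward are illustrative placeholders, not the JCudnn or SystemML calls.

    // Sketch: reuse a caller-supplied reserve space when available; otherwise
    // allocate one, redo the forward pass to fill it, and release it afterwards.
    public class ReserveSpaceReuse {
        static byte[] runForward(int reserveBytes) { return new byte[reserveBytes]; } // fills the reserve space
        static void runBackward(byte[] reserveSpace) { /* consumes the reserve space */ }

        static void backward(byte[] cachedReserveSpace, int reserveBytes) {
            boolean ownsReserve = (cachedReserveSpace == null);
            // Re-run the forward pass only when no state was passed from forward.
            byte[] reserve = ownsReserve ? runForward(reserveBytes) : cachedReserveSpace;
            runBackward(reserve);
            if (ownsReserve) {
                reserve = null; // free only what this call allocated itself
            }
        }

        public static void main(String[] args) {
            backward(null, 1024);            // legacy path: the forward work is redone
            backward(new byte[1024], 1024);  // new path: the cached state is reused
        }
    }
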
        
        
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java
index a1d799d..4c2a844 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java
@@ -56,10 +56,10 @@ public class LibMatrixCuDNNRnnAlgorithm implements 
java.lang.AutoCloseable {
        cudnnFilterDescriptor wDesc;
        cudnnFilterDescriptor dwDesc;
        long sizeInBytes; Pointer workSpace;
-       long reserveSpaceSizeInBytes; Pointer reserveSpace;
+       public long reserveSpaceSizeInBytes; public Pointer reserveSpace;
        long dropOutSizeInBytes; Pointer dropOutStateSpace;
        public LibMatrixCuDNNRnnAlgorithm(ExecutionContext ec, GPUContext gCtx, 
String instName, 
-                       String rnnMode, int N, int T, int M, int D, boolean 
isTraining, Pointer w) throws DMLRuntimeException {
+                       String rnnMode, int N, int T, int M, int D, boolean 
isTraining) throws DMLRuntimeException {
                this.gCtx = gCtx;
                this.instName = instName;
                
@@ -113,7 +113,7 @@ public class LibMatrixCuDNNRnnAlgorithm implements 
java.lang.AutoCloseable {
                dwDesc = allocateFilterDescriptor(expectedNumWeights);
                
                // Setup workspace
-               workSpace = new Pointer(); reserveSpace = new Pointer();
+               workSpace = new Pointer();
                sizeInBytes = getWorkspaceSize(T);
                if(sizeInBytes != 0) {
                        if(LOG.isDebugEnabled()) 
@@ -123,11 +123,6 @@ public class LibMatrixCuDNNRnnAlgorithm implements 
java.lang.AutoCloseable {
                reserveSpaceSizeInBytes = 0;
                if(isTraining) {
                        reserveSpaceSizeInBytes = getReservespaceSize(T);
-                       if (reserveSpaceSizeInBytes != 0) {
-                               if(LOG.isDebugEnabled()) 
-                                       LOG.debug("Allocating " +  
reserveSpaceSizeInBytes + " bytes for lstm reserve space.");
-                               reserveSpace = gCtx.allocate(instName, 
reserveSpaceSizeInBytes);
-                       }
                }
        }
        
@@ -277,14 +272,6 @@ public class LibMatrixCuDNNRnnAlgorithm implements 
java.lang.AutoCloseable {
                        }
                }
                workSpace = null;
-               if(reserveSpaceSizeInBytes != 0) {
-                       try {
-                               gCtx.cudaFreeHelper(instName, reserveSpace, 
gCtx.EAGER_CUDA_FREE);
-                       } catch (DMLRuntimeException e) {
-                               throw new RuntimeException(e);
-                       }
-               }       
-               reserveSpace = null;
                if(dropOutSizeInBytes != 0) {
                        try {
                                gCtx.cudaFreeHelper(instName, 
dropOutStateSpace, gCtx.EAGER_CUDA_FREE);
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuMatMult.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuMatMult.java
index 6dacf28..5d3b527 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuMatMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuMatMult.java
@@ -369,6 +369,9 @@ public class LibMatrixCuMatMult extends LibMatrixCUDA {
         */
        static void denseDenseMatMult(cublasHandle handle, String instName, 
Pointer C, Pointer A, Pointer B,
                        CuMatMultParameters param) {
+               if(A == null || B == null || C == null) {
+                       throw new DMLRuntimeException("One or more of the input/output 
pointers passed to denseDenseMatMult are not allocated.");
+               }
                long t0 = ConfigurationManager.isFinegrainedStatistics() ? 
System.nanoTime() : 0;
                String kernel = null;
                param.rowToColumnMajor();
diff --git a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java 
b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
index 4c4ab74..828a809 100644
--- a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
@@ -117,7 +117,7 @@ public class LstmCPUTest extends GPUTests {
        
        public void testLstmCuDNNWithNNLayer(int N, int T, int D, int M, String 
returnSequences, double sparsity) {
                String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
-                               + "[output, c] = lstm::forward(x, w, b, " + 
returnSequences + ", out0, c0)";
+                               + "[output, c, cache] = lstm::forward(x, w, b, 
" + returnSequences + ", out0, c0)";
                String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
                                + "[output, c, cache_out, cache_c, cache_ifog] 
= lstm::forward(x, w, b, " 
                                + T + ", " + D + ", " + returnSequences + ", 
out0, c0)";
@@ -242,7 +242,8 @@ public class LstmCPUTest extends GPUTests {
                boolean returnSequences1 = returnSequences.equals("TRUE");
                
                String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
-                               + "[dX, dW, db, dout0, dc0] = 
lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
+                               + "[output, c, cache] = lstm::forward(x, w, b, 
" + returnSequences + ", out0, c0); \n"
+                               + "[dX, dW, db, dout0, dc0] = 
lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0, cache);";
                String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
                                + "[output, c, cache_out, cache_c, cache_ifog] 
= lstm::forward(x, w, b, " 
                                + T + ", " + D + ", " + returnSequences + ", 
out0, c0); \n"
diff --git a/src/test/java/org/apache/sysml/test/gpu/LstmTest.java 
b/src/test/java/org/apache/sysml/test/gpu/LstmTest.java
index 47afe3a..996b12a 100644
--- a/src/test/java/org/apache/sysml/test/gpu/LstmTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/LstmTest.java
@@ -109,7 +109,7 @@ public class LstmTest extends GPUTests {
        
        public void testLstmCuDNNWithNNBuiltinOperator(int N, int T, int D, int 
M, String returnSequences, double sparsity) {
                String scriptStr = "source(" + builtinDML + ") as lstm;\n "
-                               + "[output, c] = lstm::forward(x, w, b, " + 
returnSequences + ", out0, c0)";
+                               + "[output, c, cache] = lstm::forward(x, w, b, 
" + returnSequences + ", out0, c0)";
                
                HashMap<String, Object> inputs = new HashMap<>();
                inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, 
sparsity, seed));
@@ -143,7 +143,7 @@ public class LstmTest extends GPUTests {
        
        public void testLstmCuDNNWithNNLayer(int N, int T, int D, int M, String 
returnSequences, double sparsity) {
                String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
-                               + "[output, c] = lstm::forward(x, w, b, " + 
returnSequences + ", out0, c0)";
+                               + "[output, c, cache] = lstm::forward(x, w, b, 
" + returnSequences + ", out0, c0)";
                String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
                                + "[output, c, cache_out, cache_c, cache_ifog] 
= lstm::forward(x, w, b, " 
                                + T + ", " + D + ", " + returnSequences + ", 
out0, c0)";
@@ -237,7 +237,8 @@ public class LstmTest extends GPUTests {
                boolean returnSequences1 = returnSequences.equals("TRUE");
                                
                String scriptStr = "source(" + builtinDML + ") as lstm;\n "
-                               + "[dX, dW, db, dout0, dc0] = 
lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
+                               + "[output, c, cache] = lstm::forward(x, w, b, 
" + returnSequences + ", out0, c0); \n"
+                               + "[dX, dW, db, dout0, dc0] = 
lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0, cache);";
                
                HashMap<String, Object> inputs = new HashMap<>();
                inputs.put("dout", generateInputMatrix(spark, N, 
returnSequences1 ? T*M : M, 0, 10, sparsity, seed));
@@ -281,7 +282,8 @@ public class LstmTest extends GPUTests {
                boolean returnSequences1 = returnSequences.equals("TRUE");
                
                String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
-                               + "[dX, dW, db, dout0, dc0] = 
lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
+                               + "[output, c, cache] = lstm::forward(x, w, b, 
" + returnSequences + ", out0, c0); \n"
+                               + "[dX, dW, db, dout0, dc0] = 
lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0, cache);";
                String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
                                + "[output, c, cache_out, cache_c, cache_ifog] 
= lstm::forward(x, w, b, " 
                                + T + ", " + D + ", " + returnSequences + ", 
out0, c0); \n"
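
The updated tests above illustrate the end-user pattern: forward now returns a third cache output that must be threaded into backward as its last argument. A hedged sketch of such a test script string follows, mirroring the snippets above; the variable names x, w, b, out0, c0, dout and dc are the ones the tests already bind, and the path assigned to builtinDML here is only a placeholder (the tests resolve it themselves).

    // Sketch: build the DML snippet the way LstmTest/LstmCPUTest do, with the
    // cache produced by forward passed as the last argument of backward.
    public class LstmCacheScriptSketch {
        public static void main(String[] args) {
            String builtinDML = "\"lstm.dml\""; // placeholder path, not the tests' actual value
            String returnSequences = "TRUE";
            String scriptStr = "source(" + builtinDML + ") as lstm;\n "
                    + "[output, c, cache] = lstm::forward(x, w, b, " + returnSequences + ", out0, c0); \n"
                    + "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, "
                    + returnSequences + ", out0, c0, cache);";
            System.out.println(scriptStr);
        }
    }
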
