[SYSTEMML-445] Added builtin functions for efficient computation of batch 
normalization and lstm layers.

- The following builtin functions are added: lstm, batch_norm2d, and
  batch_norm2d_backward.
- The DML language documentation and the NN layers are also updated.
- Since the builtin function for lstm backward data/weights is not added
  in this commit, the nn layer for lstm is not updated. Instead, a new
  lstm_staging.dml is added, which will eventually replace lstm.dml.
- The above builtin functions are only supported on GPU via CuDNN. The CP
  and Spark implementations will be added in subsequent commits.

Closes #773.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/276065f9
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/276065f9
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/276065f9

Branch: refs/heads/master
Commit: 276065f93e5051a19d1f86d76b2844d96f7543b3
Parents: cba082e
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Fri Jun 1 10:49:46 2018 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Fri Jun 1 10:49:46 2018 -0700

----------------------------------------------------------------------
 docs/dml-language-reference.md                  |  22 +-
 scripts/nn/layers/batch_norm2d.dml              |  56 +-
 scripts/nn/layers/batch_norm2d_old.dml          | 200 ++++++
 scripts/nn/layers/lstm_staging.dml              | 216 +++++++
 src/main/cpp/kernels/SystemML.cu                | 104 ++++
 src/main/cpp/kernels/SystemML.ptx               | 622 ++++++++++++++++---
 .../sysml/api/mlcontext/ScriptExecutor.java     |   2 +-
 .../java/org/apache/sysml/conf/DMLConfig.java   |   6 +-
 .../java/org/apache/sysml/hops/FunctionOp.java  |  26 +-
 .../java/org/apache/sysml/hops/ReorgOp.java     |  19 +-
 .../sysml/parser/BuiltinFunctionExpression.java | 141 ++++-
 .../org/apache/sysml/parser/DMLTranslator.java  |  13 +-
 .../org/apache/sysml/parser/Expression.java     |   1 +
 .../controlprogram/caching/CacheableData.java   |   2 +-
 .../instructions/GPUInstructionParser.java      |   7 +-
 .../gpu/ConvolutionGPUInstruction.java          | 259 +++++++-
 .../instructions/gpu/context/CSRPointer.java    |   2 +-
 .../instructions/gpu/context/GPUContext.java    |   2 +-
 .../context/GPULazyCudaFreeMemoryManager.java   |   2 +-
 .../gpu/context/GPUMatrixMemoryManager.java     |   2 +-
 .../gpu/context/GPUMemoryManager.java           |   2 +-
 .../instructions/gpu/context/GPUObject.java     |   2 +-
 .../runtime/matrix/data/LibMatrixCUDA.java      |   5 +-
 .../runtime/matrix/data/LibMatrixCuDNN.java     | 273 +++++++-
 .../LibMatrixCuDNNConvolutionAlgorithm.java     |   2 +-
 .../data/LibMatrixCuDNNInputRowFetcher.java     |   2 +-
 .../matrix/data/LibMatrixCuDNNRnnAlgorithm.java | 283 +++++++++
 .../runtime/matrix/data/LibMatrixCuMatMult.java |   2 +-
 .../SinglePrecisionCudaSupportFunctions.java    |   2 +-
 .../org/apache/sysml/utils/GPUStatistics.java   |   2 +-
 .../org/apache/sysml/api/dl/CaffeLayer.scala    |   2 +-
 31 files changed, 2067 insertions(+), 214 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/docs/dml-language-reference.md
----------------------------------------------------------------------
diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md
index b4ed9c8..3212806 100644
--- a/docs/dml-language-reference.md
+++ b/docs/dml-language-reference.md
@@ -1511,16 +1511,18 @@ The images are assumed to be stored NCHW format, where N = batch size, C = #chan
 Hence, the images are internally represented as a matrix with dimension (N, C * H * W).
 
 
-| Function name | Input matrices | Dimension of first input matrix | Dimension of second input matrix (if applicable) | Dimension of output matrix | Input Parameters | Notes |
-|---------------------------------------------|----------------|-----------------------------------------------------------|-----------------------------------------------------------|------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------|
-| conv2d | input, filter | [batch_size X num_channels* height_image* width_image] | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Performs 2D convolution operation |
-| conv2d_backward_filter | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels_out* height_out* width_out] | [num_filters X num_channels* height_filter* width_filter] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Computes the gradients wrt filter of 2D convolution |
-| conv2d_backward_data | filter, dout | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out] | [batch_size X num_channels* height_image* width_image] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Computes the gradients wrt input of 2D convolution |
-| max_pool, avg_pool | input | [batch_size X num_channels* height_image* width_image] | | [batch_size X num_channels* height_out* width_out] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool] | Performs max/average pooling operation |
-| max_pool_backward, avg_pool_backward | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_out* width_out] | [batch_size X num_channels* height_image* width_image] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool] | Computes the gradients wrt input of 2D max pooling, average pooling |
-| bias_add | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Adds the bias (row vector of size num_channels) to input with the given num_channels |
-| bias_multiply | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Multiplies the bias (row vector of size num_channels) to input with the given num_channels |
-
+| Function name | Input matrices | Dimension of first input matrix | Dimension of second input matrix (if applicable) | Dimension of (first) output matrix | Input Parameters | Notes |
+|---------------------------------------------|--------------------------|-----------------------------------------------------------|-----------------------------------------------------------|---------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------|
+| conv2d | input, filter | [batch_size X num_channels* height_image* width_image] | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Performs 2D convolution operation |
+| conv2d_backward_filter | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels_out* height_out* width_out] | [num_filters X num_channels* height_filter* width_filter] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Computes the gradients wrt filter of 2D convolution |
+| conv2d_backward_data | filter, dout | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out] | [batch_size X num_channels* height_image* width_image] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Computes the gradients wrt input of 2D convolution |
+| max_pool, avg_pool | input | [batch_size X num_channels* height_image* width_image] | | [batch_size X num_channels* height_out* width_out] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool] | Performs max/average pooling operation |
+| max_pool_backward, avg_pool_backward | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_out* width_out] | [batch_size X num_channels* height_image* width_image] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool] | Computes the gradients wrt input of 2D max pooling, average pooling |
+| bias_add | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Adds the bias (row vector of size num_channels) to input with the given num_channels |
+| bias_multiply | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Multiplies the bias (row vector of size num_channels) to input with the given num_channels |
+| lstm | X, W, bias, out0, c0 | [batch_size X seq_length*num_features] | [num_features+hidden_size X 4*hidden_size] | [batch_size X seq_length*hidden_size] if return_sequences else [batch_size X hidden_size] | return_sequences | Perform computation for single-layer unidirectional LSTM (outputs: out, carryOut, reserveSpace) |
+| batch_norm2d | input | [batch_size X num_channels* height_image* width_image] | | [batch_size X num_channels* height_image* width_image] | scale, shift, exponentialMovingAverage_Mean, exponentialMovingAverage_Variance, mode, epsilon, momentum | Performs batch normalization operation (outputs: updated exponential moving average mean and variance, cache of the batch mean and variance) |
+| batch_norm2d_backward | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | scale, epsilon, cache_mean (from forward), cache_inv_var (from forward) | Computed backpropagation error for batch normalization operation |
 
 Examples:
 

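The dimension columns of the new lstm row amount to a simple shape rule. Below is a minimal Python sketch of that rule, offered only as a reading aid and not as code from this commit. The helper name lstm_shapes and its dictionary keys are hypothetical; the bias, out0/c0 and c shapes are taken from the docstring of lstm_staging.dml later in this diff.

```python
def lstm_shapes(batch_size, seq_length, num_features, hidden_size,
                return_sequences):
    """Expected (rows, cols) of the lstm builtin's matrices.

    Hypothetical helper: it only restates the table row and the
    lstm_staging.dml docstring; it does not call SystemML.
    """
    # Output holds every timestep when return_sequences is true,
    # otherwise only the final timestep.
    out_cols = seq_length * hidden_size if return_sequences else hidden_size
    return {
        "X": (batch_size, seq_length * num_features),
        "W": (num_features + hidden_size, 4 * hidden_size),
        "bias": (1, 4 * hidden_size),
        "out0": (batch_size, hidden_size),
        "c0": (batch_size, hidden_size),
        "out": (batch_size, out_cols),
        "c": (batch_size, hidden_size),
    }


shapes_seq = lstm_shapes(32, 10, 8, 16, return_sequences=True)
shapes_last = lstm_shapes(32, 10, 8, 16, return_sequences=False)
```

With batch_size=32, seq_length=10, num_features=8 and hidden_size=16, W is 24 x 64 and out is 32 x 160 when return_sequences is true, or 32 x 16 otherwise.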
http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/scripts/nn/layers/batch_norm2d.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/batch_norm2d.dml b/scripts/nn/layers/batch_norm2d.dml
index 8a8555f..2a98857 100644
--- a/scripts/nn/layers/batch_norm2d.dml
+++ b/scripts/nn/layers/batch_norm2d.dml
@@ -83,41 +83,8 @@ forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
    *  - cache_inv_var: Cache of the inverse variance, of shape (C, 1).
    *      Note: This is used for performance during training.
    */
-  N = nrow(X)
-
-  if (mode == 'train') {
-    # Compute channel-wise mean and variance
-    # Since we don't have tensors, we will compute the means and variances in a piece-wise fashion.
-    #  - mean of total group is mean of subgroup means
-    #  - variance is the mean of the subgroup variances + the variance of the subgroup means
-    subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
-    subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)  # uncorrected variances
-    mean = rowMeans(subgrp_means)  # shape (C, 1)
-    var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))  # shape (C, 1)
-    # Update moving averages
-    ema_mean_upd = mu*ema_mean + (1-mu)*mean
-    ema_var_upd = mu*ema_var + (1-mu)*var
-  }
-  else {
-    # Use moving averages of mean and variance during testing
-    mean = ema_mean
-    var = ema_var
-    ema_mean_upd = ema_mean
-    ema_var_upd = ema_var
-  }
-
-  # Save variable for backward pass
-  cache_mean = mean
-  cache_inv_var = 1/sqrt(var+epsilon)
-  
-  # Normalize, shift, and scale
-  # norm = (X-mean)*(var+epsilon)^(-1/2)
-  #      = (X-mean) / sqrt(var+epsilon)
-  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
-  norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
-  # out = norm*gamma + beta
-  scaled = bias_multiply(norm, gamma)  # shape (N, C*Hin*Win)
-  out = bias_add(scaled, beta)  # shape (N, C*Hin*Win)
+  out = X; ema_mean_upd = ema_mean; ema_var_upd = ema_var;  cache_mean = ema_mean;  cache_inv_var = ema_var
+  [out, ema_mean_upd, ema_var_upd, cache_mean, cache_inv_var] = batch_norm2d(X, gamma, beta, ema_mean, ema_var, mode, epsilon, mu)
 }
 
 backward = function(matrix[double] dout, 
@@ -152,24 +119,9 @@ backward = function(matrix[double] dout,
    *  - dbeta: Gradient wrt `b`, of shape (C, 1).
    *
    */
-  N = nrow(X)
-  mean = cache_mean
-  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
-  norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
   # Compute gradients during training
-  dgamma = util::channel_sums(dout*norm, C, Hin, Win)  # shape (C, 1)
-  dbeta = util::channel_sums(dout, C, Hin, Win)  # shape (C, 1)
-  dnorm = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)
-  dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm,
-                          C, Hin, Win)  # shape (C, 1)
-  dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win)
-  dmean_var_branch =  util::channel_sums((-2/(N*Hin*Win)) * centered, C, Hin, Win)
-  dmean_var_branch = dmean_var_branch * dvar  # we can't use a function within an expression yet
-  dmean = dmean_norm_branch + dmean_var_branch  # shape (C, 1)
-  dX_norm_branch = bias_multiply(dnorm, cache_inv_var)
-  dX_mean_branch = (1/(N*Hin*Win)) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean)
-  dX_var_branch = (2/(N*Hin*Win)) * bias_multiply(centered, dvar)
-  dX = dX_norm_branch + dX_mean_branch + dX_var_branch  # shape (N, C*Hin*Win)
+  dX = X; dgamma = gamma; dbeta = gamma;
+  [dX, dgamma, dbeta] = batch_norm2d_backward(X, dout, gamma, epsilon, cache_mean, cache_inv_var)
 }
 
 init = function(int C)

http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/scripts/nn/layers/batch_norm2d_old.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/batch_norm2d_old.dml b/scripts/nn/layers/batch_norm2d_old.dml
new file mode 100644
index 0000000..2aba2e6
--- /dev/null
+++ b/scripts/nn/layers/batch_norm2d_old.dml
@@ -0,0 +1,200 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * 2D (Spatial) Batch Normalization layer.
+ */
+source("nn/util.dml") as util
+
+forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
+                   int C, int Hin, int Win, string mode,
+                   matrix[double] ema_mean, matrix[double] ema_var,
+                   double mu, double epsilon)
+    return (matrix[double] out, matrix[double] ema_mean_upd, matrix[double] ema_var_upd,
+            matrix[double] cache_mean, matrix[double] cache_inv_var) {
+  /*
+   * Computes the forward pass for a 2D (spatial) batch normalization
+   * layer.  The input data has N examples, each represented as a 3D
+   * volume unrolled into a single vector.
+   *
+   * A spatial batch normalization layer uses the per-channel sample
+   * mean and per-channel uncorrected sample variance during training
+   * to normalize each channel of the input data.  Additionally, it
+   * introduces learnable parameters (gamma, beta) to control the
+   * amount of normalization.
+   *
+   *   `y = ((x-mean) / sqrt(var+eps)) * gamma + beta`
+   *
+   * This implementation maintains exponential moving averages of the
+   * mean and variance during training for use during testing.
+   *
+   * Reference:
+   *  - Batch Normalization: Accelerating Deep Network Training by
+   *    Reducing Internal Covariate Shift, S. Ioffe & C. Szegedy, 2015
+   *    - https://arxiv.org/abs/1502.03167
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - gamma: Scale parameters, of shape (C, 1).
+   *  - beta: Shift parameters, of shape (C, 1).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - mode: 'train' or 'test' to indicate if the model is currently
+   *      being trained or tested.  During training, the current batch
+   *      mean and variance will be used to normalize the inputs, while
+   *      during testing, the exponential average of the mean and
+   *      variance over all previous batches will be used.
+   *  - ema_mean: Exponential moving average of the mean, of
+   *      shape (C, 1).
+   *  - ema_var: Exponential moving average of the variance, of
+   *      shape (C, 1).
+   *  - mu: Momentum value for moving averages.
+   *      Typical values are in the range of [0.9, 0.999].
+   *  - epsilon: Smoothing term to avoid divide by zero errors.
+   *      Typical values are in the range of [1e-5, 1e-3].
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, C*Hin*Win).
+   *  - ema_mean_upd: Updated exponential moving average of the mean,
+   *      of shape (C, 1).
+   *  - ema_var_upd: Updated exponential moving average of the variance,
+   *      of shape (C, 1).
+   *  - cache_mean: Cache of the batch mean, of shape (C, 1).
+   *      Note: This is used for performance during training.
+   *  - cache_inv_var: Cache of the inverse variance, of shape (C, 1).
+   *      Note: This is used for performance during training.
+   */
+  N = nrow(X)
+
+  if (mode == 'train') {
+    # Compute channel-wise mean and variance
+    # Since we don't have tensors, we will compute the means and variances in a piece-wise fashion.
+    #  - mean of total group is mean of subgroup means
+    #  - variance is the mean of the subgroup variances + the variance of the subgroup means
+    subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
+    subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)  # uncorrected variances
+    mean = rowMeans(subgrp_means)  # shape (C, 1)
+    var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))  # shape (C, 1)
+    # Update moving averages
+    ema_mean_upd = mu*ema_mean + (1-mu)*mean
+    ema_var_upd = mu*ema_var + (1-mu)*var
+  }
+  else {
+    # Use moving averages of mean and variance during testing
+    mean = ema_mean
+    var = ema_var
+    ema_mean_upd = ema_mean
+    ema_var_upd = ema_var
+  }
+
+  # Save variable for backward pass
+  cache_mean = mean
+  cache_inv_var = 1/sqrt(var+epsilon)
+  
+  # Normalize, shift, and scale
+  # norm = (X-mean)*(var+epsilon)^(-1/2)
+  #      = (X-mean) / sqrt(var+epsilon)
+  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
+  norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
+  # out = norm*gamma + beta
+  scaled = bias_multiply(norm, gamma)  # shape (N, C*Hin*Win)
+  out = bias_add(scaled, beta)  # shape (N, C*Hin*Win)
+}
+
+backward = function(matrix[double] dout, 
+                    matrix[double] cache_mean, matrix[double] cache_inv_var,
+                    matrix[double] X, matrix[double] gamma, 
+                    int C, int Hin, int Win, double epsilon)
+      return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {
+  /*
+   * Computes the backward pass for a 2D (spatial) batch normalization
+   * layer.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of shape (N, C*Hin*Win).
+   *  - cache_mean: Cache of the batch mean from the forward pass, of
+   *      shape (C, 1).  Note: This is used for performance during
+   *      training.
+   *  - cache_inv_var: Cache of the inverse variance from the forward pass,
+   *      of shape (C, 1).  Note: This is used for performance during
+   *      training.
+   *  - X: Input data matrix to the forward pass, of
+   *      shape (N, C*Hin*Win).
+   *  - gamma: Scale parameters, of shape (C, 1).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - epsilon: Smoothing term to avoid divide by zero errors.
+   *      Typical values are in the range of [1e-5, 1e-3].
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
+   *  - dgamma: Gradient wrt `W`, of shape (C, 1).
+   *  - dbeta: Gradient wrt `b`, of shape (C, 1).
+   *
+   */
+  N = nrow(X)
+  mean = cache_mean
+  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
+  norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
+  # Compute gradients during training
+  dgamma = util::channel_sums(dout*norm, C, Hin, Win)  # shape (C, 1)
+  dbeta = util::channel_sums(dout, C, Hin, Win)  # shape (C, 1)
+  dnorm = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)
+  dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm,
+                          C, Hin, Win)  # shape (C, 1)
+  dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win)
+  dmean_var_branch =  util::channel_sums((-2/(N*Hin*Win)) * centered, C, Hin, Win)
+  dmean_var_branch = dmean_var_branch * dvar  # we can't use a function within an expression yet
+  dmean = dmean_norm_branch + dmean_var_branch  # shape (C, 1)
+  dX_norm_branch = bias_multiply(dnorm, cache_inv_var)
+  dX_mean_branch = (1/(N*Hin*Win)) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean)
+  dX_var_branch = (2/(N*Hin*Win)) * bias_multiply(centered, dvar)
+  dX = dX_norm_branch + dX_mean_branch + dX_var_branch  # shape (N, C*Hin*Win)
+}
+
+init = function(int C)
+    return (matrix[double] gamma, matrix[double] beta,
+            matrix[double] ema_mean, matrix[double] ema_var) {
+  /*
+   * Initialize the parameters of this layer.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
+   *
+   * Inputs:
+   *  - C: Number of input channels (dimensionality of input depth).
+   *
+   * Outputs:
+   *  - gamma: Scale parameters, of shape (C, 1).
+   *  - beta: Shift parameters, of shape (C, 1).
+   *  - ema_mean: Exponential moving average of the mean, of
+   *      shape (C, 1).
+   *  - ema_var: Exponential moving average of the variance, of
+   *      shape (C, 1).
+   */
+   gamma = matrix(1, rows=C, cols=1)
+   beta = matrix(0, rows=C, cols=1)
+   ema_mean = matrix(0, rows=C, cols=1)
+   ema_var = matrix(1, rows=C, cols=1)
+}
+
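
The piece-wise statistics in the forward pass of batch_norm2d_old.dml above rely on the fact that, for equally sized subgroups, the uncorrected variance of the whole group equals the mean of the subgroup variances plus the uncorrected variance of the subgroup means. The following is a small standard-library Python sketch that checks this identity on toy data. It is an illustration only, not part of the commit; the function names and the toy matrix are hypothetical.

```python
from statistics import fmean, pvariance


def channel_stats_direct(X, C, HW):
    """Per-channel mean and uncorrected variance over all N*HW values."""
    stats = []
    for c in range(C):
        vals = [row[c * HW + k] for row in X for k in range(HW)]
        stats.append((fmean(vals), pvariance(vals)))
    return stats


def channel_stats_piecewise(X, C, HW):
    """Same statistics from per-column (subgroup) means and variances,
    mirroring the colMeans/colVars + rowMeans/rowVars steps in the DML."""
    stats = []
    for c in range(C):
        cols = [[row[c * HW + k] for row in X] for k in range(HW)]
        sub_means = [fmean(col) for col in cols]
        sub_vars = [pvariance(col) for col in cols]  # uncorrected variances
        mean = fmean(sub_means)
        var = fmean(sub_vars) + pvariance(sub_means)
        stats.append((mean, var))
    return stats


# Toy input: N=3 examples, C=2 channels, Hin*Win=2 values per channel,
# laid out as rows of length C*Hin*Win like the (N, C*Hin*Win) matrix X.
X = [[1.0, 2.0, 10.0, 20.0],
     [3.0, 5.0, 30.0, 10.0],
     [8.0, 2.0, 20.0, 50.0]]
direct = channel_stats_direct(X, C=2, HW=2)
piecewise = channel_stats_piecewise(X, C=2, HW=2)
```

Both functions should agree up to floating-point rounding. Python's pvariance is already the uncorrected (population) variance, which is why the (N-1)/N and ((Hin*Win)-1)/(Hin*Win) correction factors from the DML do not appear here.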

http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/scripts/nn/layers/lstm_staging.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/lstm_staging.dml b/scripts/nn/layers/lstm_staging.dml
new file mode 100644
index 0000000..d0949d9
--- /dev/null
+++ b/scripts/nn/layers/lstm_staging.dml
@@ -0,0 +1,216 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * LSTM layer.
+ */
+source("nn/layers/sigmoid.dml") as sigmoid
+source("nn/layers/tanh.dml") as tanh
+
+forward = function(matrix[double] X, matrix[double] W, matrix[double] b, 
+                   boolean return_sequences, matrix[double] out0, matrix[double] c0)
+    return (matrix[double] out, matrix[double] c) {
+  /*
+   * Computes the forward pass for an LSTM layer with M neurons.
+   * The input data has N sequences of T examples, each with D features.
+   *
+   * In an LSTM, an internal cell state is maintained, additive
+   * interactions operate over the cell state at each timestep, and
+   * some amount of this cell state is exposed as output at each
+   * timestep.  Additionally, the output of the previous timestep is fed
+   * back in as an additional input at the current timestep.
+   *
+   * Reference:
+   *  - Long Short-Term Memory, Hochreiter, 1997
+   *    - http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, T*D).
+   *  - W: Weights, of shape (D+M, 4M).
+   *  - b: Biases, of shape (1, 4M).
+   *  - return_sequences: Whether to return `out` at all timesteps,
+   *      or just for the final timestep.
+   *  - out0: Outputs from previous timestep, of shape (N, M).
+   *      Note: This is *optional* and could just be an empty matrix.
+   *  - c0: Initial cell state, of shape (N, M).
+   *      Note: This is *optional* and could just be an empty matrix.
+   *
+   * Outputs:
+   *  - out: If `return_sequences` is True, outputs for all timesteps,
+   *      of shape (N, T*M).  Else, outputs for the final timestep, of
+   *      shape (N, M).
+   *  - c: Cell state for final timestep, of shape (N, M).
+   *  - reserveSpace: Reserve space produced by the `lstm` builtin (row vector whose size is determined at runtime); not returned by this function.
+   */
+  [out, c, reserveSpace] = lstm(X, W, b, out0, c0, return_sequences)
+}
+
+# TODO: Replace with a builtin once lstm backward data/weights is added.
+backward = function(matrix[double] dout, matrix[double] dc,
+                    matrix[double] X, matrix[double] W, matrix[double] b, int T, int D,
+                    boolean given_sequences, matrix[double] out0, matrix[double] c0,
+                    matrix[double] cache_out, matrix[double] cache_c, matrix[double] cache_ifog)
+    return (matrix[double] dX, matrix[double] dW, matrix[double] db,
+            matrix[double] dout0, matrix[double] dc0) {
+  /*
+   * Computes the backward pass for an LSTM layer with M neurons.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out`.  If `given_sequences` is `True`,
+   *      contains gradients on outputs for all timesteps, of
+   *      shape (N, T*M). Else, contains the gradient on the output
+   *      for the final timestep, of shape (N, M).
+   *  - dc: Gradient wrt `c` (from later in time), of shape (N, M).
+   *      This would come from later in time if the cell state was used
+   *      downstream as the initial cell state for another LSTM layer.
+   *      Typically, this would be used when a sequence was cut at
+   *      timestep `T` and then continued in the next batch.  If `c`
+   *      was not used downstream, then `dc` would be an empty matrix.
+   *  - X: Inputs, of shape (N, T*D).
+   *  - W: Weights, of shape (D+M, 4M).
+   *  - b: Biases, of shape (1, 4M).
+   *  - T: Length of example sequences (number of timesteps).
+   *  - D: Dimensionality of the input features.
+   *  - given_sequences: Whether `dout` is for all timesteps,
+   *      or just for the final timestep.  This is based on whether
+   *      `return_sequences` was true in the forward pass.
+   *  - out0: Outputs from previous timestep, of shape (N, M).
+   *      Note: This is *optional* and could just be an empty matrix.
+   *  - c0: Initial cell state, of shape (N, M).
+   *      Note: This is *optional* and could just be an empty matrix.
+   *  - cache_out: Cache of outputs, of shape (T, N*M).
+   *      Note: This is used for performance during training.
+   *  - cache_c: Cache of cell state, of shape (T, N*M).
+   *      Note: This is used for performance during training.
+   *  - cache_ifog: Cache of intermediate values, of shape (T, N*4*M).
+   *      Note: This is used for performance during training.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of shape (N, T*D).
+   *  - dW: Gradient wrt `W`, of shape (D+M, 4M).
+   *  - db: Gradient wrt `b`, of shape (1, 4M).
+   *  - dout0: Gradient wrt `out0`, of shape (N, M).
+   *  - dc0: Gradient wrt `c0`, of shape (N, M).
+   */
+  N = nrow(X)
+  M = as.integer(ncol(W)/4)
+  N1 = nrow(out0)
+  if(N != N1) {
+    # Allow for smaller out0 for last batch 
+    # out0 = out0[1:N,]
+    # c0 = c0[1:N,]
+    stop("Unsupported operation: The batch size of previous iteration " + N1 + " is different than the batch size of current iteration " + N)
+  }
+  dX = matrix(0, rows=N, cols=T*D)
+  dW = matrix(0, rows=D+M, cols=4*M)
+  db = matrix(0, rows=1, cols=4*M)
+  dout0 = matrix(0, rows=N, cols=M)
+  dc0 = matrix(0, rows=N, cols=M)
+  dct = dc
+  if (!given_sequences) {
+    # only given dout for output at final timestep, so prepend empty douts for all other timesteps
+    dout = cbind(matrix(0, rows=N, cols=(T-1)*M), dout)  # shape (N, T*M)
+  }
+
+  t = T
+  for (iter in 1:T) {  # each timestep in reverse order
+    X_t = X[,(t-1)*D+1:t*D]  # shape (N, D)
+    dout_t = dout[,(t-1)*M+1:t*M]  # shape (N, M)
+    out_t = matrix(cache_out[t,], rows=N, cols=M)  # shape (N, M)
+    ct = matrix(cache_c[t,], rows=N, cols=M)  # shape (N, M)
+    if (t == 1) {
+      out_prev = out0  # shape (N, M)
+      c_prev = c0  # shape (N, M)
+    }
+    else {
+      out_prev = matrix(cache_out[t-1,], rows=N, cols=M)  # shape (N, M)
+      c_prev = matrix(cache_c[t-1,], rows=N, cols=M)  # shape (N, M)
+    }
+    input = cbind(X_t, out_prev)  # shape (N, D+M)
+    ifog = matrix(cache_ifog[t,], rows=N, cols=4*M)
+    i = ifog[,1:M]  # input gate, shape (N, M)
+    f = ifog[,M+1:2*M]  # forget gate, shape (N, M)
+    o = ifog[,2*M+1:3*M]  # output gate, shape (N, M)
+    g = ifog[,3*M+1:4*M]  # g gate, shape (N, M)
+
+    dct = dct + o*tanh::backward(dout_t, ct)  # shape (N, M)
+    do = tanh::forward(ct) * dout_t  # output gate, shape (N, M)
+    df = c_prev * dct  # forget gate, shape (N, M)
+    dc_prev = f * dct  # shape (N, M)
+    di = g * dct  # input gate, shape (N, M)
+    dg = i * dct  # g gate, shape (N, M)
+
+    di_raw = i * (1-i) * di
+    df_raw = f * (1-f) * df
+    do_raw = o * (1-o) * do
+    dg_raw = (1-g^2) * dg
+    difog_raw = cbind(di_raw, df_raw, do_raw, dg_raw)  # shape (N, 4M)
+
+    dW = dW + t(input) %*% difog_raw  # shape (D+M, 4M)
+    db = db + colSums(difog_raw)  # shape (1, 4M)
+    dinput = difog_raw %*% t(W)  # shape (N, D+M)
+    dX[,(t-1)*D+1:t*D] = dinput[,1:D]
+    dout_prev = dinput[,D+1:D+M]  # shape (N, M)
+    if (t == 1) {
+      dout0 = dout_prev  # shape (N, M)
+      dc0 = dc_prev  # shape (N, M)
+    }
+    else {
+      dout[,(t-2)*M+1:(t-1)*M] = dout[,(t-2)*M+1:(t-1)*M] + dout_prev  # shape (N, M)
+      dct = dc_prev  # shape (N, M)
+    }
+    t = t - 1
+  }
+}
+
+init = function(int N, int D, int M)
+    return (matrix[double] W, matrix[double] b, matrix[double] out0, matrix[double] c0) {
+  /*
+   * Initialize the parameters of this layer.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
+   *
+   * We use the Glorot uniform heuristic which limits the magnification
+   * of inputs/gradients during forward/backward passes by scaling
+   * uniform weights by a factor of sqrt(6/(fan_in + fan_out)).
+   *  - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
+   *
+   * Inputs:
+   *  - N: Number of examples in batch.
+   *  - D: Dimensionality of the input features (number of features).
+   *  - M: Number of neurons in this layer.
+   *
+   * Outputs:
+   *  - W: Weights, of shape (D+M, 4M).
+   *  - b: Biases, of shape (1, 4M).
+   *  - out0: Empty previous timestep output matrix, of shape (N, M).
+   *  - c0: Empty initial cell state matrix, of shape (N, M).
+   */
+  fan_in = D+M
+  fan_out = 4*M
+  scale = sqrt(6/(fan_in+fan_out))
+  W = rand(rows=D+M, cols=4*M, min=-scale, max=scale, pdf="uniform")
+  b = matrix(0, rows=1, cols=4*M)
+  out0 = matrix(0, rows=N, cols=M)
+  c0 = matrix(0, rows=N, cols=M)
+}
+
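
For readers following the `backward` loop above, the per-timestep gradient algebra can be summarised as below. This is a sketch rather than part of the commit. The two forward equations are an assumption inferred from the gradient expressions (the staged `forward` delegates to the `lstm` builtin and does not spell them out), and it is also assumed that `tanh::backward(dout, X)` returns `(1 - tanh(X)^2) * dout`, since `tanh.dml` is not shown in this diff. The symbol $\odot$ denotes element-wise multiplication.

```latex
\begin{aligned}
% assumed forward step (standard LSTM cell)
c_t &= f \odot c_{t-1} + i \odot g, \qquad \mathit{out}_t = o \odot \tanh(c_t) \\[4pt]
% gradients through the cell state and the gates
dc_t &\leftarrow dc_t + o \odot \bigl(1 - \tanh^2(c_t)\bigr) \odot d\mathit{out}_t \\
do &= \tanh(c_t) \odot d\mathit{out}_t, \qquad
df = c_{t-1} \odot dc_t, \qquad
di = g \odot dc_t, \qquad
dg = i \odot dc_t \\
dc_{t-1} &= f \odot dc_t \\[4pt]
% gradients through the gate nonlinearities (sigmoid for i, f, o; tanh for g)
di_{\mathrm{raw}} &= i \odot (1 - i) \odot di, \qquad
df_{\mathrm{raw}} = f \odot (1 - f) \odot df \\
do_{\mathrm{raw}} &= o \odot (1 - o) \odot do, \qquad
dg_{\mathrm{raw}} = (1 - g^2) \odot dg \\[4pt]
% parameter and input gradients, with input = [X_t, out_{t-1}]
dW &\leftarrow dW + \mathit{input}^{\top}\, d\mathit{ifog}_{\mathrm{raw}}, \qquad
db \leftarrow db + \operatorname{colSums}\bigl(d\mathit{ifog}_{\mathrm{raw}}\bigr), \qquad
d\mathit{input} = d\mathit{ifog}_{\mathrm{raw}}\, W^{\top}
\end{aligned}
```

Here $d\mathit{ifog}_{\mathrm{raw}}$ is the column-wise concatenation of the four raw gate gradients in the order i, f, o, g, matching the `cbind` in the loop.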

http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index 55ebeaf..cc2d531 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -1962,6 +1962,87 @@ extern "C" __global__ void matrix_sigmoid_f(float *A, float *C,
   matrix_sigmoid(A, C, size);
 }
 
+template <typename T>
+__device__ void prepare_lstm_input(T* smlInput, T* cudnnInput, int N, int D, int TD, int size) {
+       int index = blockIdx.x * blockDim.x + threadIdx.x;
+       if(index < size) {
+               int n = index / TD;
+               int td = index % TD;
+               int t = td / D;
+               int d = td % D;
+               cudnnInput[t*N*D + n*D + d] = smlInput[index];
+       }
+}
+
+
+extern "C" __global__ void prepare_lstm_input_d(double* smlInput, double* cudnnInput, int N, int D, int TD, int size) {
+  prepare_lstm_input(smlInput, cudnnInput, N, D, TD, size);
+}
+
+extern "C" __global__ void prepare_lstm_input_f(float* smlInput, float* cudnnInput, int N, int D, int TD, int size) {
+  prepare_lstm_input(smlInput, cudnnInput, N, D, TD, size);
+}
+
+__device__ int swap_co(int offset) {
+  return (offset < 2) ? offset : (offset == 2 ? 3 : 2);
+}
+
+template <typename T>
+__device__ void prepare_lstm_weight(T* smlWeight, T* smlBias, T* cudnnWeight, int D, int M) {
+  int DM = D*M; int MM = M*M; int DM4 = DM*4; 
+  int M4 = M*4;
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  // input: cbind(X_t, out_prev) => [N, D+M], weight: [D+M, 4M]
+  // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnGetRNNLinLayerMatrixParams states that
+  // Elements in each weight matrix are arranged in the row-major order, but the column-major format works !!
+  // CuDNN gate order: i, f, c, o
+  // CuDNN weight order: w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o
+  // SystemML weight order: i, f, o, c; TF weight order: i, c, f, o
+  // SystemML performs (X_t %*% W + out_prev %*% R) => [N, 4*M]
+  
+  // bias layout: bi bf bc bo 0 0 0 0
+  // where W: [DxM], R: [MxM] and b: [1xM]
+  
+  // Maximum (D+M+2)*M4 threads
+  int srcIndex = -1; int destIndex;
+  if(index < DM4) {
+    // Fill w_i, w_f, w_c and w_o
+    int localIndex = index%DM;
+    int smlRowIndex = localIndex/M; 
+    int smlColIndex = swap_co(index/(DM))*M + localIndex%M;
+    // Convert index to column-major where index = (index/(DM))*DM + (localIndex/M)*M + localIndex%M
+    destIndex = (index/(DM))*DM + (localIndex%M)*D + localIndex/M;
+    srcIndex = smlRowIndex*M4+smlColIndex;
+  }
+  else if(index < (D+M)*M4) {
+    // Fill r_i, r_f, r_c and r_o
+    int tmpIndex = index-DM4;
+    int localIndex = tmpIndex % MM;
+    int smlRowIndex = D + (localIndex / M);
+    int smlColIndex = swap_co(tmpIndex/(MM))*M + localIndex%M;
+    // Convert index to column-major where index = DM4 + (tmpIndex/(MM))*MM + (localIndex/M)*M + localIndex%M
+    destIndex = DM4 + (tmpIndex/(MM))*MM + (localIndex%M)*M + localIndex/M;
+    srcIndex = smlRowIndex*M4+smlColIndex;
+  }
+  else if(index < (D+M+1)*M4) {
+       // Fill bias
+       int tmpIndex = index - (D+M)*M4;
+       int smlColIndex = swap_co(tmpIndex/(M))*M + tmpIndex%M;
+       cudnnWeight[index] = smlBias[smlColIndex];
+  }
+  // __syncthreads();
+  if(srcIndex != -1)
+       cudnnWeight[destIndex] = smlWeight[srcIndex];
+}
+
+extern "C" __global__ void prepare_lstm_weight_d(double* smlWeight, double* smlBias, double* cudnnWeight, int D, int M) {
+  prepare_lstm_weight(smlWeight, smlBias, cudnnWeight, D, M);
+}
+
+extern "C" __global__ void prepare_lstm_weight_f(float* smlWeight, float* smlBias, float* cudnnWeight, int D, int M) {
+  prepare_lstm_weight(smlWeight, smlBias, cudnnWeight, D, M);
+}
+
 // We can later fold it in our reduce method
 template <typename T>
 __device__ void compute_nnz(
@@ -2058,3 +2139,26 @@ extern "C" __global__ void compute_nnz_d(double *g_idata, double *g_odata, unsig
 extern "C" __global__ void compute_nnz_f(float *g_idata, float *g_odata, unsigned int n) {
        compute_nnz(g_idata, g_odata, n);
 }
+
+template <typename T>
+__device__ void prepare_lstm_output(T* smlInput, T* cudnnInput, int N, int T1, int M, int size) {
+       int index = blockIdx.x * blockDim.x + threadIdx.x;
+       if(index < size) {
+               int TM = T1*M;
+               int NT = T1*N;
+               int n = index / TM;
+               int tm = index % TM;
+               int t = tm / M;
+               int m = tm % M;
+               smlInput[index] = cudnnInput[t*N*M + n*M + m];
+       }
+}
+
+
+extern "C" __global__ void prepare_lstm_output_d(double* smlInput, double* cudnnInput, int N, int T, int M, int size) {
+  prepare_lstm_output(smlInput, cudnnInput, N, T, M, size);
+}
+
+extern "C" __global__ void prepare_lstm_output_f(float* smlInput, float* cudnnInput, int N, int T, int M, int size) {
+  prepare_lstm_output(smlInput, cudnnInput, N, T, M, size);
+}
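
The index arithmetic in `prepare_lstm_input` and `prepare_lstm_output` above may be easier to check on the host. The following is a minimal C++ sketch, not part of the commit, that replays the same mapping in a sequential loop: a row-major (N, T*D) matrix is rearranged into the time-major (T, N, D) buffer the kernel writes, and the inverse mapping restores it. The function names `to_time_major` and `from_time_major` are hypothetical and exist only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side mirror of prepare_lstm_input:
// row-major (N, T*D) input -> time-major (T, N, D) buffer.
std::vector<double> to_time_major(const std::vector<double>& sml,
                                  int N, int T, int D) {
  assert(sml.size() == static_cast<std::size_t>(N) * T * D);
  std::vector<double> cudnn(sml.size());
  const int TD = T * D;
  for (int index = 0; index < N * TD; ++index) {  // one iteration per GPU thread
    const int n = index / TD;   // sequence (row) index
    const int td = index % TD;  // offset within the row
    const int t = td / D;       // timestep
    const int d = td % D;       // feature
    cudnn[t * N * D + n * D + d] = sml[index];
  }
  return cudnn;
}

// Hypothetical host-side mirror of prepare_lstm_output:
// time-major (T, N, M) buffer -> row-major (N, T*M) output.
std::vector<double> from_time_major(const std::vector<double>& cudnn,
                                    int N, int T, int M) {
  assert(cudnn.size() == static_cast<std::size_t>(N) * T * M);
  std::vector<double> sml(cudnn.size());
  const int TM = T * M;
  for (int index = 0; index < N * TM; ++index) {
    const int n = index / TM;
    const int tm = index % TM;
    const int t = tm / M;
    const int m = tm % M;
    sml[index] = cudnn[t * N * M + n * M + m];
  }
  return sml;
}
```

Applying `from_time_major` to the result of `to_time_major` with the same dimensions should reproduce the original buffer, which is a cheap sanity check when modifying either kernel.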

http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx
index 1865e18..ed1a100 100644
--- a/src/main/cpp/kernels/SystemML.ptx
+++ b/src/main/cpp/kernels/SystemML.ptx
@@ -11689,6 +11689,328 @@ BB97_5:
        ret;
 }
 
+       // .globl       prepare_lstm_input_d
+.visible .entry prepare_lstm_input_d(
+       .param .u64 prepare_lstm_input_d_param_0,
+       .param .u64 prepare_lstm_input_d_param_1,
+       .param .u32 prepare_lstm_input_d_param_2,
+       .param .u32 prepare_lstm_input_d_param_3,
+       .param .u32 prepare_lstm_input_d_param_4,
+       .param .u32 prepare_lstm_input_d_param_5
+)
+{
+       .reg .pred      %p<2>;
+       .reg .b32       %r<15>;
+       .reg .f64       %fd<2>;
+       .reg .b64       %rd<9>;
+
+
+       ld.param.u64    %rd1, [prepare_lstm_input_d_param_0];
+       ld.param.u64    %rd2, [prepare_lstm_input_d_param_1];
+       ld.param.u32    %r2, [prepare_lstm_input_d_param_2];
+       ld.param.u32    %r3, [prepare_lstm_input_d_param_3];
+       ld.param.u32    %r4, [prepare_lstm_input_d_param_4];
+       ld.param.u32    %r5, [prepare_lstm_input_d_param_5];
+       mov.u32         %r6, %ctaid.x;
+       mov.u32         %r7, %ntid.x;
+       mov.u32         %r8, %tid.x;
+       mad.lo.s32      %r1, %r7, %r6, %r8;
+       setp.ge.s32     %p1, %r1, %r5;
+       @%p1 bra        BB98_2;
+
+       cvta.to.global.u64      %rd3, %rd1;
+       rem.s32         %r9, %r1, %r4;
+       div.s32         %r10, %r9, %r3;
+       rem.s32         %r11, %r9, %r3;
+       mul.wide.s32    %rd4, %r1, 8;
+       add.s64         %rd5, %rd3, %rd4;
+       ld.global.f64   %fd1, [%rd5];
+       div.s32         %r12, %r1, %r4;
+       mad.lo.s32      %r13, %r10, %r2, %r12;
+       mad.lo.s32      %r14, %r13, %r3, %r11;
+       cvta.to.global.u64      %rd6, %rd2;
+       mul.wide.s32    %rd7, %r14, 8;
+       add.s64         %rd8, %rd6, %rd7;
+       st.global.f64   [%rd8], %fd1;
+
+BB98_2:
+       ret;
+}
+
+       // .globl       prepare_lstm_input_f
+.visible .entry prepare_lstm_input_f(
+       .param .u64 prepare_lstm_input_f_param_0,
+       .param .u64 prepare_lstm_input_f_param_1,
+       .param .u32 prepare_lstm_input_f_param_2,
+       .param .u32 prepare_lstm_input_f_param_3,
+       .param .u32 prepare_lstm_input_f_param_4,
+       .param .u32 prepare_lstm_input_f_param_5
+)
+{
+       .reg .pred      %p<2>;
+       .reg .f32       %f<2>;
+       .reg .b32       %r<15>;
+       .reg .b64       %rd<9>;
+
+
+       ld.param.u64    %rd1, [prepare_lstm_input_f_param_0];
+       ld.param.u64    %rd2, [prepare_lstm_input_f_param_1];
+       ld.param.u32    %r2, [prepare_lstm_input_f_param_2];
+       ld.param.u32    %r3, [prepare_lstm_input_f_param_3];
+       ld.param.u32    %r4, [prepare_lstm_input_f_param_4];
+       ld.param.u32    %r5, [prepare_lstm_input_f_param_5];
+       mov.u32         %r6, %ctaid.x;
+       mov.u32         %r7, %ntid.x;
+       mov.u32         %r8, %tid.x;
+       mad.lo.s32      %r1, %r7, %r6, %r8;
+       setp.ge.s32     %p1, %r1, %r5;
+       @%p1 bra        BB99_2;
+
+       cvta.to.global.u64      %rd3, %rd1;
+       rem.s32         %r9, %r1, %r4;
+       div.s32         %r10, %r9, %r3;
+       rem.s32         %r11, %r9, %r3;
+       mul.wide.s32    %rd4, %r1, 4;
+       add.s64         %rd5, %rd3, %rd4;
+       ld.global.f32   %f1, [%rd5];
+       div.s32         %r12, %r1, %r4;
+       mad.lo.s32      %r13, %r10, %r2, %r12;
+       mad.lo.s32      %r14, %r13, %r3, %r11;
+       cvta.to.global.u64      %rd6, %rd2;
+       mul.wide.s32    %rd7, %r14, 4;
+       add.s64         %rd8, %rd6, %rd7;
+       st.global.f32   [%rd8], %f1;
+
+BB99_2:
+       ret;
+}
+
+       // .globl       prepare_lstm_weight_d
+.visible .entry prepare_lstm_weight_d(
+       .param .u64 prepare_lstm_weight_d_param_0,
+       .param .u64 prepare_lstm_weight_d_param_1,
+       .param .u64 prepare_lstm_weight_d_param_2,
+       .param .u32 prepare_lstm_weight_d_param_3,
+       .param .u32 prepare_lstm_weight_d_param_4
+)
+{
+       .reg .pred      %p<11>;
+       .reg .b32       %r<53>;
+       .reg .f64       %fd<3>;
+       .reg .b64       %rd<15>;
+
+
+       ld.param.u64    %rd2, [prepare_lstm_weight_d_param_0];
+       ld.param.u64    %rd3, [prepare_lstm_weight_d_param_1];
+       ld.param.u64    %rd4, [prepare_lstm_weight_d_param_2];
+       ld.param.u32    %r13, [prepare_lstm_weight_d_param_3];
+       ld.param.u32    %r14, [prepare_lstm_weight_d_param_4];
+       cvta.to.global.u64      %rd1, %rd4;
+       mul.lo.s32      %r1, %r14, %r13;
+       shl.b32         %r2, %r1, 2;
+       shl.b32         %r3, %r14, 2;
+       mov.u32         %r15, %ntid.x;
+       mov.u32         %r16, %ctaid.x;
+       mov.u32         %r17, %tid.x;
+       mad.lo.s32      %r4, %r15, %r16, %r17;
+       setp.lt.s32     %p1, %r4, %r2;
+       @%p1 bra        BB100_5;
+       bra.uni         BB100_1;
+
+BB100_5:
+       rem.s32         %r42, %r4, %r1;
+       div.s32         %r43, %r42, %r14;
+       div.s32         %r44, %r4, %r1;
+       setp.lt.s32     %p8, %r44, 2;
+       setp.eq.s32     %p9, %r44, 2;
+       selp.b32        %r45, 3, 2, %p9;
+       selp.b32        %r46, %r44, %r45, %p8;
+       rem.s32         %r47, %r42, %r14;
+       sub.s32         %r48, %r4, %r42;
+       add.s32         %r49, %r48, %r43;
+       mad.lo.s32      %r52, %r47, %r13, %r49;
+       mad.lo.s32      %r50, %r43, %r3, %r47;
+       mad.lo.s32      %r51, %r46, %r14, %r50;
+       bra.uni         BB100_6;
+
+BB100_1:
+       add.s32         %r5, %r14, %r13;
+       mul.lo.s32      %r6, %r5, %r3;
+       setp.lt.s32     %p2, %r4, %r6;
+       @%p2 bra        BB100_4;
+       bra.uni         BB100_2;
+
+BB100_4:
+       mul.lo.s32      %r30, %r14, %r14;
+       sub.s32         %r31, %r4, %r2;
+       rem.s32         %r32, %r31, %r30;
+       div.s32         %r33, %r32, %r14;
+       add.s32         %r34, %r33, %r13;
+       div.s32         %r35, %r31, %r30;
+       setp.lt.s32     %p6, %r35, 2;
+       setp.eq.s32     %p7, %r35, 2;
+       selp.b32        %r36, 3, 2, %p7;
+       selp.b32        %r37, %r35, %r36, %p6;
+       rem.s32         %r38, %r32, %r14;
+       sub.s32         %r39, %r4, %r32;
+       add.s32         %r40, %r39, %r33;
+       mad.lo.s32      %r52, %r38, %r14, %r40;
+       mad.lo.s32      %r41, %r34, %r3, %r38;
+       mad.lo.s32      %r51, %r37, %r14, %r41;
+       bra.uni         BB100_6;
+
+BB100_2:
+       add.s32         %r20, %r5, 1;
+       mul.lo.s32      %r21, %r20, %r3;
+       mov.u32         %r51, -1;
+       setp.ge.s32     %p3, %r4, %r21;
+       @%p3 bra        BB100_6;
+
+       cvta.to.global.u64      %rd5, %rd3;
+       sub.s32         %r24, %r4, %r6;
+       div.s32         %r25, %r24, %r14;
+       setp.lt.s32     %p4, %r25, 2;
+       setp.eq.s32     %p5, %r25, 2;
+       selp.b32        %r26, 3, 2, %p5;
+       selp.b32        %r27, %r25, %r26, %p4;
+       rem.s32         %r28, %r24, %r14;
+       mad.lo.s32      %r29, %r27, %r14, %r28;
+       mul.wide.s32    %rd6, %r29, 8;
+       add.s64         %rd7, %rd5, %rd6;
+       ld.global.f64   %fd1, [%rd7];
+       mul.wide.s32    %rd8, %r4, 8;
+       add.s64         %rd9, %rd1, %rd8;
+       st.global.f64   [%rd9], %fd1;
+
+BB100_6:
+       setp.eq.s32     %p10, %r51, -1;
+       @%p10 bra       BB100_8;
+
+       cvta.to.global.u64      %rd10, %rd2;
+       mul.wide.s32    %rd11, %r51, 8;
+       add.s64         %rd12, %rd10, %rd11;
+       ld.global.f64   %fd2, [%rd12];
+       mul.wide.s32    %rd13, %r52, 8;
+       add.s64         %rd14, %rd1, %rd13;
+       st.global.f64   [%rd14], %fd2;
+
+BB100_8:
+       ret;
+}
+
+       // .globl       prepare_lstm_weight_f
+.visible .entry prepare_lstm_weight_f(
+       .param .u64 prepare_lstm_weight_f_param_0,
+       .param .u64 prepare_lstm_weight_f_param_1,
+       .param .u64 prepare_lstm_weight_f_param_2,
+       .param .u32 prepare_lstm_weight_f_param_3,
+       .param .u32 prepare_lstm_weight_f_param_4
+)
+{
+       .reg .pred      %p<11>;
+       .reg .f32       %f<3>;
+       .reg .b32       %r<53>;
+       .reg .b64       %rd<15>;
+
+
+       ld.param.u64    %rd2, [prepare_lstm_weight_f_param_0];
+       ld.param.u64    %rd3, [prepare_lstm_weight_f_param_1];
+       ld.param.u64    %rd4, [prepare_lstm_weight_f_param_2];
+       ld.param.u32    %r13, [prepare_lstm_weight_f_param_3];
+       ld.param.u32    %r14, [prepare_lstm_weight_f_param_4];
+       cvta.to.global.u64      %rd1, %rd4;
+       mul.lo.s32      %r1, %r14, %r13;
+       shl.b32         %r2, %r1, 2;
+       shl.b32         %r3, %r14, 2;
+       mov.u32         %r15, %ntid.x;
+       mov.u32         %r16, %ctaid.x;
+       mov.u32         %r17, %tid.x;
+       mad.lo.s32      %r4, %r15, %r16, %r17;
+       setp.lt.s32     %p1, %r4, %r2;
+       @%p1 bra        BB101_5;
+       bra.uni         BB101_1;
+
+BB101_5:
+       rem.s32         %r42, %r4, %r1;
+       div.s32         %r43, %r42, %r14;
+       div.s32         %r44, %r4, %r1;
+       setp.lt.s32     %p8, %r44, 2;
+       setp.eq.s32     %p9, %r44, 2;
+       selp.b32        %r45, 3, 2, %p9;
+       selp.b32        %r46, %r44, %r45, %p8;
+       rem.s32         %r47, %r42, %r14;
+       sub.s32         %r48, %r4, %r42;
+       add.s32         %r49, %r48, %r43;
+       mad.lo.s32      %r52, %r47, %r13, %r49;
+       mad.lo.s32      %r50, %r43, %r3, %r47;
+       mad.lo.s32      %r51, %r46, %r14, %r50;
+       bra.uni         BB101_6;
+
+BB101_1:
+       add.s32         %r5, %r14, %r13;
+       mul.lo.s32      %r6, %r5, %r3;
+       setp.lt.s32     %p2, %r4, %r6;
+       @%p2 bra        BB101_4;
+       bra.uni         BB101_2;
+
+BB101_4:
+       mul.lo.s32      %r30, %r14, %r14;
+       sub.s32         %r31, %r4, %r2;
+       rem.s32         %r32, %r31, %r30;
+       div.s32         %r33, %r32, %r14;
+       add.s32         %r34, %r33, %r13;
+       div.s32         %r35, %r31, %r30;
+       setp.lt.s32     %p6, %r35, 2;
+       setp.eq.s32     %p7, %r35, 2;
+       selp.b32        %r36, 3, 2, %p7;
+       selp.b32        %r37, %r35, %r36, %p6;
+       rem.s32         %r38, %r32, %r14;
+       sub.s32         %r39, %r4, %r32;
+       add.s32         %r40, %r39, %r33;
+       mad.lo.s32      %r52, %r38, %r14, %r40;
+       mad.lo.s32      %r41, %r34, %r3, %r38;
+       mad.lo.s32      %r51, %r37, %r14, %r41;
+       bra.uni         BB101_6;
+
+BB101_2:
+       add.s32         %r20, %r5, 1;
+       mul.lo.s32      %r21, %r20, %r3;
+       mov.u32         %r51, -1;
+       setp.ge.s32     %p3, %r4, %r21;
+       @%p3 bra        BB101_6;
+
+       cvta.to.global.u64      %rd5, %rd3;
+       sub.s32         %r24, %r4, %r6;
+       div.s32         %r25, %r24, %r14;
+       setp.lt.s32     %p4, %r25, 2;
+       setp.eq.s32     %p5, %r25, 2;
+       selp.b32        %r26, 3, 2, %p5;
+       selp.b32        %r27, %r25, %r26, %p4;
+       rem.s32         %r28, %r24, %r14;
+       mad.lo.s32      %r29, %r27, %r14, %r28;
+       mul.wide.s32    %rd6, %r29, 4;
+       add.s64         %rd7, %rd5, %rd6;
+       ld.global.f32   %f1, [%rd7];
+       mul.wide.s32    %rd8, %r4, 4;
+       add.s64         %rd9, %rd1, %rd8;
+       st.global.f32   [%rd9], %f1;
+
+BB101_6:
+       setp.eq.s32     %p10, %r51, -1;
+       @%p10 bra       BB101_8;
+
+       cvta.to.global.u64      %rd10, %rd2;
+       mul.wide.s32    %rd11, %r51, 4;
+       add.s64         %rd12, %rd10, %rd11;
+       ld.global.f32   %f2, [%rd12];
+       mul.wide.s32    %rd13, %r52, 4;
+       add.s64         %rd14, %rd1, %rd13;
+       st.global.f32   [%rd14], %f2;
+
+BB101_8:
+       ret;
+}
+
        // .globl       compute_nnz_d
 .visible .entry compute_nnz_d(
        .param .u64 compute_nnz_d_param_0,
@@ -11712,9 +12034,9 @@ BB97_5:
        mad.lo.s32      %r35, %r9, %r10, %r7;
        mov.f64         %fd46, 0d0000000000000000;
        setp.ge.u32     %p1, %r35, %r6;
-       @%p1 bra        BB98_4;
+       @%p1 bra        BB102_4;
 
-BB98_1:
+BB102_1:
        cvta.to.global.u64      %rd3, %rd1;
        mul.wide.u32    %rd4, %r35, 8;
        add.s64         %rd5, %rd3, %rd4;
@@ -11724,7 +12046,7 @@ BB98_1:
        add.f64         %fd46, %fd46, %fd31;
        add.s32         %r3, %r35, %r10;
        setp.ge.u32     %p3, %r3, %r6;
-       @%p3 bra        BB98_3;
+       @%p3 bra        BB102_3;
 
        mul.wide.u32    %rd7, %r3, 8;
        add.s64         %rd8, %rd3, %rd7;
@@ -11733,128 +12055,128 @@ BB98_1:
        selp.f64        %fd33, 0d3FF0000000000000, 0d0000000000000000, %p4;
        add.f64         %fd46, %fd46, %fd33;
 
-BB98_3:
+BB102_3:
        shl.b32         %r13, %r10, 1;
        mov.u32         %r14, %nctaid.x;
        mad.lo.s32      %r35, %r13, %r14, %r35;
        setp.lt.u32     %p5, %r35, %r6;
-       @%p5 bra        BB98_1;
+       @%p5 bra        BB102_1;
 
-BB98_4:
+BB102_4:
        shl.b32         %r16, %r7, 3;
        mov.u32         %r17, my_sdata;
        add.s32         %r5, %r17, %r16;
        st.shared.f64   [%r5], %fd46;
        bar.sync        0;
        setp.lt.u32     %p6, %r10, 1024;
-       @%p6 bra        BB98_8;
+       @%p6 bra        BB102_8;
 
        setp.gt.u32     %p7, %r7, 511;
-       @%p7 bra        BB98_7;
+       @%p7 bra        BB102_7;
 
        ld.shared.f64   %fd34, [%r5+4096];
        add.f64         %fd46, %fd46, %fd34;
        st.shared.f64   [%r5], %fd46;
 
-BB98_7:
+BB102_7:
        bar.sync        0;
 
-BB98_8:
+BB102_8:
        setp.lt.u32     %p8, %r10, 512;
-       @%p8 bra        BB98_12;
+       @%p8 bra        BB102_12;
 
        setp.gt.u32     %p9, %r7, 255;
-       @%p9 bra        BB98_11;
+       @%p9 bra        BB102_11;
 
        ld.shared.f64   %fd35, [%r5+2048];
        add.f64         %fd46, %fd46, %fd35;
        st.shared.f64   [%r5], %fd46;
 
-BB98_11:
+BB102_11:
        bar.sync        0;
 
-BB98_12:
+BB102_12:
        setp.lt.u32     %p10, %r10, 256;
-       @%p10 bra       BB98_16;
+       @%p10 bra       BB102_16;
 
        setp.gt.u32     %p11, %r7, 127;
-       @%p11 bra       BB98_15;
+       @%p11 bra       BB102_15;
 
        ld.shared.f64   %fd36, [%r5+1024];
        add.f64         %fd46, %fd46, %fd36;
        st.shared.f64   [%r5], %fd46;
 
-BB98_15:
+BB102_15:
        bar.sync        0;
 
-BB98_16:
+BB102_16:
        setp.lt.u32     %p12, %r10, 128;
-       @%p12 bra       BB98_20;
+       @%p12 bra       BB102_20;
 
        setp.gt.u32     %p13, %r7, 63;
-       @%p13 bra       BB98_19;
+       @%p13 bra       BB102_19;
 
        ld.shared.f64   %fd37, [%r5+512];
        add.f64         %fd46, %fd46, %fd37;
        st.shared.f64   [%r5], %fd46;
 
-BB98_19:
+BB102_19:
        bar.sync        0;
 
-BB98_20:
+BB102_20:
        setp.gt.u32     %p14, %r7, 31;
-       @%p14 bra       BB98_33;
+       @%p14 bra       BB102_33;
 
        setp.lt.u32     %p15, %r10, 64;
-       @%p15 bra       BB98_23;
+       @%p15 bra       BB102_23;
 
        ld.volatile.shared.f64  %fd38, [%r5+256];
        add.f64         %fd46, %fd46, %fd38;
        st.volatile.shared.f64  [%r5], %fd46;
 
-BB98_23:
+BB102_23:
        setp.lt.u32     %p16, %r10, 32;
-       @%p16 bra       BB98_25;
+       @%p16 bra       BB102_25;
 
        ld.volatile.shared.f64  %fd39, [%r5+128];
        add.f64         %fd46, %fd46, %fd39;
        st.volatile.shared.f64  [%r5], %fd46;
 
-BB98_25:
+BB102_25:
        setp.lt.u32     %p17, %r10, 16;
-       @%p17 bra       BB98_27;
+       @%p17 bra       BB102_27;
 
        ld.volatile.shared.f64  %fd40, [%r5+64];
        add.f64         %fd46, %fd46, %fd40;
        st.volatile.shared.f64  [%r5], %fd46;
 
-BB98_27:
+BB102_27:
        setp.lt.u32     %p18, %r10, 8;
-       @%p18 bra       BB98_29;
+       @%p18 bra       BB102_29;
 
        ld.volatile.shared.f64  %fd41, [%r5+32];
        add.f64         %fd46, %fd46, %fd41;
        st.volatile.shared.f64  [%r5], %fd46;
 
-BB98_29:
+BB102_29:
        setp.lt.u32     %p19, %r10, 4;
-       @%p19 bra       BB98_31;
+       @%p19 bra       BB102_31;
 
        ld.volatile.shared.f64  %fd42, [%r5+16];
        add.f64         %fd46, %fd46, %fd42;
        st.volatile.shared.f64  [%r5], %fd46;
 
-BB98_31:
+BB102_31:
        setp.lt.u32     %p20, %r10, 2;
-       @%p20 bra       BB98_33;
+       @%p20 bra       BB102_33;
 
        ld.volatile.shared.f64  %fd43, [%r5+8];
        add.f64         %fd44, %fd46, %fd43;
        st.volatile.shared.f64  [%r5], %fd44;
 
-BB98_33:
+BB102_33:
        setp.ne.s32     %p21, %r7, 0;
-       @%p21 bra       BB98_35;
+       @%p21 bra       BB102_35;
 
        ld.shared.f64   %fd45, [my_sdata];
        cvta.to.global.u64      %rd9, %rd2;
@@ -11862,7 +12184,7 @@ BB98_33:
        add.s64         %rd11, %rd9, %rd10;
        st.global.f64   [%rd11], %fd45;
 
-BB98_35:
+BB102_35:
        ret;
 }
 
@@ -11889,9 +12211,9 @@ BB98_35:
        mad.lo.s32      %r35, %r9, %r10, %r7;
        mov.f32         %f46, 0f00000000;
        setp.ge.u32     %p1, %r35, %r6;
-       @%p1 bra        BB99_4;
+       @%p1 bra        BB103_4;
 
-BB99_1:
+BB103_1:
        cvta.to.global.u64      %rd3, %rd1;
        mul.wide.u32    %rd4, %r35, 4;
        add.s64         %rd5, %rd3, %rd4;
@@ -11901,7 +12223,7 @@ BB99_1:
        add.f32         %f46, %f46, %f31;
        add.s32         %r3, %r35, %r10;
        setp.ge.u32     %p3, %r3, %r6;
-       @%p3 bra        BB99_3;
+       @%p3 bra        BB103_3;
 
        mul.wide.u32    %rd7, %r3, 4;
        add.s64         %rd8, %rd3, %rd7;
@@ -11910,128 +12232,128 @@ BB99_1:
        selp.f32        %f33, 0f3F800000, 0f00000000, %p4;
        add.f32         %f46, %f46, %f33;
 
-BB99_3:
+BB103_3:
        shl.b32         %r13, %r10, 1;
        mov.u32         %r14, %nctaid.x;
        mad.lo.s32      %r35, %r13, %r14, %r35;
        setp.lt.u32     %p5, %r35, %r6;
-       @%p5 bra        BB99_1;
+       @%p5 bra        BB103_1;
 
-BB99_4:
+BB103_4:
        shl.b32         %r16, %r7, 2;
        mov.u32         %r17, my_sdata;
        add.s32         %r5, %r17, %r16;
        st.shared.f32   [%r5], %f46;
        bar.sync        0;
        setp.lt.u32     %p6, %r10, 1024;
-       @%p6 bra        BB99_8;
+       @%p6 bra        BB103_8;
 
        setp.gt.u32     %p7, %r7, 511;
-       @%p7 bra        BB99_7;
+       @%p7 bra        BB103_7;
 
        ld.shared.f32   %f34, [%r5+2048];
        add.f32         %f46, %f46, %f34;
        st.shared.f32   [%r5], %f46;
 
-BB99_7:
+BB103_7:
        bar.sync        0;
 
-BB99_8:
+BB103_8:
        setp.lt.u32     %p8, %r10, 512;
-       @%p8 bra        BB99_12;
+       @%p8 bra        BB103_12;
 
        setp.gt.u32     %p9, %r7, 255;
-       @%p9 bra        BB99_11;
+       @%p9 bra        BB103_11;
 
        ld.shared.f32   %f35, [%r5+1024];
        add.f32         %f46, %f46, %f35;
        st.shared.f32   [%r5], %f46;
 
-BB99_11:
+BB103_11:
        bar.sync        0;
 
-BB99_12:
+BB103_12:
        setp.lt.u32     %p10, %r10, 256;
-       @%p10 bra       BB99_16;
+       @%p10 bra       BB103_16;
 
        setp.gt.u32     %p11, %r7, 127;
-       @%p11 bra       BB99_15;
+       @%p11 bra       BB103_15;
 
        ld.shared.f32   %f36, [%r5+512];
        add.f32         %f46, %f46, %f36;
        st.shared.f32   [%r5], %f46;
 
-BB99_15:
+BB103_15:
        bar.sync        0;
 
-BB99_16:
+BB103_16:
        setp.lt.u32     %p12, %r10, 128;
-       @%p12 bra       BB99_20;
+       @%p12 bra       BB103_20;
 
        setp.gt.u32     %p13, %r7, 63;
-       @%p13 bra       BB99_19;
+       @%p13 bra       BB103_19;
 
        ld.shared.f32   %f37, [%r5+256];
        add.f32         %f46, %f46, %f37;
        st.shared.f32   [%r5], %f46;
 
-BB99_19:
+BB103_19:
        bar.sync        0;
 
-BB99_20:
+BB103_20:
        setp.gt.u32     %p14, %r7, 31;
-       @%p14 bra       BB99_33;
+       @%p14 bra       BB103_33;
 
        setp.lt.u32     %p15, %r10, 64;
-       @%p15 bra       BB99_23;
+       @%p15 bra       BB103_23;
 
        ld.volatile.shared.f32  %f38, [%r5+128];
        add.f32         %f46, %f46, %f38;
        st.volatile.shared.f32  [%r5], %f46;
 
-BB99_23:
+BB103_23:
        setp.lt.u32     %p16, %r10, 32;
-       @%p16 bra       BB99_25;
+       @%p16 bra       BB103_25;
 
        ld.volatile.shared.f32  %f39, [%r5+64];
        add.f32         %f46, %f46, %f39;
        st.volatile.shared.f32  [%r5], %f46;
 
-BB99_25:
+BB103_25:
        setp.lt.u32     %p17, %r10, 16;
-       @%p17 bra       BB99_27;
+       @%p17 bra       BB103_27;
 
        ld.volatile.shared.f32  %f40, [%r5+32];
        add.f32         %f46, %f46, %f40;
        st.volatile.shared.f32  [%r5], %f46;
 
-BB99_27:
+BB103_27:
        setp.lt.u32     %p18, %r10, 8;
-       @%p18 bra       BB99_29;
+       @%p18 bra       BB103_29;
 
        ld.volatile.shared.f32  %f41, [%r5+16];
        add.f32         %f46, %f46, %f41;
        st.volatile.shared.f32  [%r5], %f46;
 
-BB99_29:
+BB103_29:
        setp.lt.u32     %p19, %r10, 4;
-       @%p19 bra       BB99_31;
+       @%p19 bra       BB103_31;
 
        ld.volatile.shared.f32  %f42, [%r5+8];
        add.f32         %f46, %f46, %f42;
        st.volatile.shared.f32  [%r5], %f46;
 
-BB99_31:
+BB103_31:
        setp.lt.u32     %p20, %r10, 2;
-       @%p20 bra       BB99_33;
+       @%p20 bra       BB103_33;
 
        ld.volatile.shared.f32  %f43, [%r5+4];
        add.f32         %f44, %f46, %f43;
        st.volatile.shared.f32  [%r5], %f44;
 
-BB99_33:
+BB103_33:
        setp.ne.s32     %p21, %r7, 0;
-       @%p21 bra       BB99_35;
+       @%p21 bra       BB103_35;
 
        ld.shared.f32   %f45, [my_sdata];
        cvta.to.global.u64      %rd9, %rd2;
@@ -12039,7 +12361,105 @@ BB99_33:
        add.s64         %rd11, %rd9, %rd10;
        st.global.f32   [%rd11], %f45;
 
-BB99_35:
+BB103_35:
+       ret;
+}
+
+       // .globl       prepare_lstm_output_d
+.visible .entry prepare_lstm_output_d(
+       .param .u64 prepare_lstm_output_d_param_0,
+       .param .u64 prepare_lstm_output_d_param_1,
+       .param .u32 prepare_lstm_output_d_param_2,
+       .param .u32 prepare_lstm_output_d_param_3,
+       .param .u32 prepare_lstm_output_d_param_4,
+       .param .u32 prepare_lstm_output_d_param_5
+)
+{
+       .reg .pred      %p<2>;
+       .reg .b32       %r<16>;
+       .reg .f64       %fd<2>;
+       .reg .b64       %rd<9>;
+
+
+       ld.param.u64    %rd1, [prepare_lstm_output_d_param_0];
+       ld.param.u64    %rd2, [prepare_lstm_output_d_param_1];
+       ld.param.u32    %r2, [prepare_lstm_output_d_param_2];
+       ld.param.u32    %r3, [prepare_lstm_output_d_param_3];
+       ld.param.u32    %r4, [prepare_lstm_output_d_param_4];
+       ld.param.u32    %r5, [prepare_lstm_output_d_param_5];
+       mov.u32         %r6, %ctaid.x;
+       mov.u32         %r7, %ntid.x;
+       mov.u32         %r8, %tid.x;
+       mad.lo.s32      %r1, %r7, %r6, %r8;
+       setp.ge.s32     %p1, %r1, %r5;
+       @%p1 bra        BB104_2;
+
+       cvta.to.global.u64      %rd3, %rd2;
+       mul.lo.s32      %r9, %r4, %r3;
+       div.s32         %r10, %r1, %r9;
+       rem.s32         %r11, %r1, %r9;
+       div.s32         %r12, %r11, %r4;
+       rem.s32         %r13, %r11, %r4;
+       mad.lo.s32      %r14, %r12, %r2, %r10;
+       mad.lo.s32      %r15, %r14, %r4, %r13;
+       mul.wide.s32    %rd4, %r15, 8;
+       add.s64         %rd5, %rd3, %rd4;
+       ld.global.f64   %fd1, [%rd5];
+       cvta.to.global.u64      %rd6, %rd1;
+       mul.wide.s32    %rd7, %r1, 8;
+       add.s64         %rd8, %rd6, %rd7;
+       st.global.f64   [%rd8], %fd1;
+
+BB104_2:
+       ret;
+}
+
+       // .globl       prepare_lstm_output_f
+.visible .entry prepare_lstm_output_f(
+       .param .u64 prepare_lstm_output_f_param_0,
+       .param .u64 prepare_lstm_output_f_param_1,
+       .param .u32 prepare_lstm_output_f_param_2,
+       .param .u32 prepare_lstm_output_f_param_3,
+       .param .u32 prepare_lstm_output_f_param_4,
+       .param .u32 prepare_lstm_output_f_param_5
+)
+{
+       .reg .pred      %p<2>;
+       .reg .f32       %f<2>;
+       .reg .b32       %r<16>;
+       .reg .b64       %rd<9>;
+
+
+       ld.param.u64    %rd1, [prepare_lstm_output_f_param_0];
+       ld.param.u64    %rd2, [prepare_lstm_output_f_param_1];
+       ld.param.u32    %r2, [prepare_lstm_output_f_param_2];
+       ld.param.u32    %r3, [prepare_lstm_output_f_param_3];
+       ld.param.u32    %r4, [prepare_lstm_output_f_param_4];
+       ld.param.u32    %r5, [prepare_lstm_output_f_param_5];
+       mov.u32         %r6, %ctaid.x;
+       mov.u32         %r7, %ntid.x;
+       mov.u32         %r8, %tid.x;
+       mad.lo.s32      %r1, %r7, %r6, %r8;
+       setp.ge.s32     %p1, %r1, %r5;
+       @%p1 bra        BB105_2;
+
+       cvta.to.global.u64      %rd3, %rd2;
+       mul.lo.s32      %r9, %r4, %r3;
+       div.s32         %r10, %r1, %r9;
+       rem.s32         %r11, %r1, %r9;
+       div.s32         %r12, %r11, %r4;
+       rem.s32         %r13, %r11, %r4;
+       mad.lo.s32      %r14, %r12, %r2, %r10;
+       mad.lo.s32      %r15, %r14, %r4, %r13;
+       mul.wide.s32    %rd4, %r15, 4;
+       add.s64         %rd5, %rd3, %rd4;
+       ld.global.f32   %f1, [%rd5];
+       cvta.to.global.u64      %rd6, %rd1;
+       mul.wide.s32    %rd7, %r1, 4;
+       add.s64         %rd8, %rd6, %rd7;
+       st.global.f32   [%rd8], %f1;
+
+BB105_2:
        ret;
 }
 
@@ -12048,7 +12468,7 @@ BB99_35:
        .param .b64 __internal_trig_reduction_slowpathd_param_1
 )
 {
-       .local .align 8 .b8     __local_depot100[40];
+       .local .align 8 .b8     __local_depot106[40];
        .reg .b64       %SP;
        .reg .b64       %SPL;
        .reg .pred      %p<9>;
@@ -12057,7 +12477,7 @@ BB99_35:
        .reg .b64       %rd<102>;
 
 
-       mov.u64         %rd101, __local_depot100;
+       mov.u64         %rd101, __local_depot106;
        cvta.local.u64  %SP, %rd101;
        ld.param.f64    %fd4, [__internal_trig_reduction_slowpathd_param_0];
        ld.param.u64    %rd37, [__internal_trig_reduction_slowpathd_param_1];
@@ -12071,7 +12491,7 @@ BB99_35:
        shr.u32         %r3, %r1, 20;
        bfe.u32         %r4, %r1, 20, 11;
        setp.eq.s32     %p1, %r4, 2047;
-       @%p1 bra        BB100_13;
+       @%p1 bra        BB106_13;
 
        add.s32         %r15, %r4, -1024;
        shr.u32         %r16, %r15, 6;
@@ -12084,7 +12504,7 @@ BB99_35:
        mov.u64         %rd94, 0;
        setp.ge.s32     %p2, %r5, %r6;
        mov.u64         %rd93, %rd1;
-       @%p2 bra        BB100_4;
+       @%p2 bra        BB106_4;
 
        mov.b64          %rd41, %fd4;
        shl.b64         %rd42, %rd41, 11;
@@ -12101,7 +12521,7 @@ BB99_35:
        mov.u64         %rd91, %rd1;
        mov.u32         %r39, %r5;
 
-BB100_3:
+BB106_3:
        .pragma "nounroll";
        ld.const.u64    %rd47, [%rd89];
        // inline asm
@@ -12131,15 +12551,15 @@ BB100_3:
        add.s64         %rd93, %rd93, 8;
        add.s64         %rd89, %rd89, 8;
        setp.lt.s32     %p3, %r39, %r6;
-       @%p3 bra        BB100_3;
+       @%p3 bra        BB106_3;
 
-BB100_4:
+BB106_4:
        st.local.u64    [%rd93], %rd94;
        ld.local.u64    %rd95, [%rd1+16];
        ld.local.u64    %rd96, [%rd1+24];
        and.b32         %r9, %r3, 63;
        setp.eq.s32     %p4, %r9, 0;
-       @%p4 bra        BB100_6;
+       @%p4 bra        BB106_6;
 
        mov.u32         %r27, 64;
        sub.s32         %r28, %r27, %r9;
@@ -12151,7 +12571,7 @@ BB100_4:
        shr.u64         %rd55, %rd54, %r28;
        or.b64          %rd95, %rd55, %rd53;
 
-BB100_6:
+BB106_6:
        cvta.to.local.u64       %rd56, %rd37;
        shr.u64         %rd57, %rd96, 62;
        cvt.u32.u64     %r29, %rd57;
@@ -12168,7 +12588,7 @@ BB100_6:
        selp.b32        %r34, %r32, %r33, %p5;
        st.local.u32    [%rd56], %r34;
        setp.eq.s32     %p6, %r31, 0;
-       @%p6 bra        BB100_8;
+       @%p6 bra        BB106_8;
 
        mov.u64         %rd64, 0;
        // inline asm
@@ -12188,10 +12608,10 @@ BB100_6:
        // inline asm
        xor.b32         %r40, %r40, -2147483648;
 
-BB100_8:
+BB106_8:
        clz.b64         %r41, %rd98;
        setp.eq.s32     %p7, %r41, 0;
-       @%p7 bra        BB100_10;
+       @%p7 bra        BB106_10;
 
        shl.b64         %rd67, %rd98, %r41;
        mov.u32         %r35, 64;
@@ -12199,7 +12619,7 @@ BB100_8:
        shr.u64         %rd68, %rd97, %r36;
        or.b64          %rd98, %rd68, %rd67;
 
-BB100_10:
+BB106_10:
        mov.u64         %rd72, -3958705157555305931;
        // inline asm
        {
@@ -12220,7 +12640,7 @@ BB100_10:
        }
        // inline asm
        setp.lt.s64     %p8, %rd100, 1;
-       @%p8 bra        BB100_12;
+       @%p8 bra        BB106_12;
 
        // inline asm
        {
@@ -12239,7 +12659,7 @@ BB100_10:
        // inline asm
        add.s32         %r41, %r41, 1;
 
-BB100_12:
+BB106_12:
        cvt.u64.u32     %rd79, %r40;
        shl.b64         %rd80, %rd79, 32;
        mov.u32         %r37, 1022;
@@ -12254,7 +12674,7 @@ BB100_12:
        or.b64          %rd88, %rd87, %rd80;
        mov.b64          %fd4, %rd88;
 
-BB100_13:
+BB106_13:
        st.param.f64    [func_retval0+0], %fd4;
        ret;
 }
@@ -12282,7 +12702,7 @@ BB100_13:
        }
        shr.u32         %r51, %r50, 20;
        setp.ne.s32     %p1, %r51, 0;
-       @%p1 bra        BB101_2;
+       @%p1 bra        BB107_2;
 
        mul.f64         %fd14, %fd12, 0d4350000000000000;
        {
@@ -12296,13 +12716,13 @@ BB100_13:
        shr.u32         %r16, %r50, 20;
        add.s32         %r51, %r16, -54;
 
-BB101_2:
+BB107_2:
        add.s32         %r52, %r51, -1023;
        and.b32         %r17, %r50, -2146435073;
        or.b32          %r18, %r17, 1072693248;
        mov.b64         %fd135, {%r49, %r18};
        setp.lt.u32     %p2, %r18, 1073127583;
-       @%p2 bra        BB101_4;
+       @%p2 bra        BB107_4;
 
        {
        .reg .b32 %temp; 
@@ -12316,7 +12736,7 @@ BB101_2:
        mov.b64         %fd135, {%r19, %r21};
        add.s32         %r52, %r51, -1022;
 
-BB101_4:
+BB107_4:
        add.f64         %fd15, %fd135, 0d3FF0000000000000;
        rcp.approx.ftz.f64      %fd16, %fd15;
        neg.f64         %fd17, %fd15;
@@ -12479,13 +12899,13 @@ BB101_4:
        mov.b32          %f2, %r35;
        abs.f32         %f1, %f2;
        setp.lt.f32     %p4, %f1, 0f4086232B;
-       @%p4 bra        BB101_7;
+       @%p4 bra        BB107_7;
 
        setp.lt.f64     %p5, %fd4, 0d0000000000000000;
        add.f64         %fd129, %fd4, 0d7FF0000000000000;
        selp.f64        %fd136, 0d0000000000000000, %fd129, %p5;
        setp.geu.f32    %p6, %f1, 0f40874800;
-       @%p6 bra        BB101_7;
+       @%p6 bra        BB107_7;
 
        mov.f64         %fd134, 0d4338000000000000;
        mov.f64         %fd133, 0d3FF71547652B82FE;
@@ -12507,26 +12927,26 @@ BB101_4:
        mov.b64         %fd131, {%r44, %r43};
        mul.f64         %fd136, %fd130, %fd131;
 
-BB101_7:
+BB107_7:
        {
        .reg .b32 %temp; 
        mov.b64         {%temp, %r45}, %fd136;
        }
        and.b32         %r46, %r45, 2147483647;
        setp.ne.s32     %p7, %r46, 2146435072;
-       @%p7 bra        BB101_9;
+       @%p7 bra        BB107_9;
 
        {
        .reg .b32 %temp; 
        mov.b64         {%r47, %temp}, %fd136;
        }
        setp.eq.s32     %p8, %r47, 0;
-       @%p8 bra        BB101_10;
+       @%p8 bra        BB107_10;
 
-BB101_9:
+BB107_9:
        fma.rn.f64      %fd136, %fd136, %fd5, %fd136;
 
-BB101_10:
+BB107_10:
        st.param.f64    [func_retval0+0], %fd136;
        ret;
 }
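
Note on the two new kernels above (prepare_lstm_output_d / prepare_lstm_output_f):
read from the PTX alone, each thread copies one element from param_1 to param_0
after an index permutation. The following is a minimal CPU sketch in Java of that
arithmetic. It assumes param_2, param_3 and param_4 are a batch size N, a
sequence length T and a hidden size M, and that param_5 equals N*T*M. Those
names, and the class and method names, are assumptions made for illustration;
only the arithmetic is taken from the PTX.

```java
import java.util.Arrays;

class LstmOutputLayout {
    // CPU equivalent of the per-thread index arithmetic in the PTX above.
    // n, t, m are assumed meanings for param_2, param_3, param_4.
    static double[] prepareLstmOutput(double[] in, int n, int t, int m) {
        int size = n * t * m;                 // assumed to match param_5
        double[] out = new double[size];
        int tm = t * m;                       // mul.lo.s32 %r9, %r4, %r3
        for (int idx = 0; idx < size; idx++) {
            int row = idx / tm;               // div.s32 %r10, %r1, %r9
            int rem = idx % tm;               // rem.s32 %r11, %r1, %r9
            int step = rem / m;               // div.s32 %r12, %r11, %r4
            int col = rem % m;                // rem.s32 %r13, %r11, %r4
            // mad.lo.s32 %r14 and %r15: source offset in a [T, N, M] layout
            int src = (step * n + row) * m + col;
            out[idx] = in[src];
        }
        return out;
    }

    public static void main(String[] args) {
        double[] in = new double[12];
        for (int i = 0; i < in.length; i++) in[i] = i;
        System.out.println(Arrays.toString(prepareLstmOutput(in, 2, 3, 2)));
    }
}
```

Under these assumptions the kernel turns a time-major [T, N, M] buffer into a
row-major matrix with N rows and T*M columns.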

http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java 
b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
index a405fd9..f3303de 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
@@ -248,7 +248,7 @@ public class ScriptExecutor {
 
                // set the GPUs to use for this process (a range, all GPUs, 
comma separated list or a specific GPU)
                GPUContextPool.AVAILABLE_GPUS = 
config.getTextValue(DMLConfig.AVAILABLE_GPUS);
-
+               
                String evictionPolicy = 
config.getTextValue(DMLConfig.GPU_EVICTION_POLICY).toUpperCase();
                try {
                        DMLScript.GPU_EVICTION_POLICY = 
EvictionPolicy.valueOf(evictionPolicy);

http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/java/org/apache/sysml/conf/DMLConfig.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/conf/DMLConfig.java 
b/src/main/java/org/apache/sysml/conf/DMLConfig.java
index 0a896a3..9f08c3c 100644
--- a/src/main/java/org/apache/sysml/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysml/conf/DMLConfig.java
@@ -82,19 +82,19 @@ public class DMLConfig
        public static final String CODEGEN_PLANCACHE    = 
"sysml.codegen.plancache"; //boolean
        public static final String CODEGEN_LITERALS     = 
"sysml.codegen.literals"; //1..heuristic, 2..always
        public static final String CACHING_BUFFER_SIZE  = 
"sysml.caching.bufferSize"; //double: default:0.15
-       
        public static final String EXTRA_FINEGRAINED_STATS = 
"sysml.stats.finegrained"; //boolean
        public static final String STATS_MAX_WRAP_LEN   = 
"sysml.stats.maxWrapLength"; //int
        public static final String AVAILABLE_GPUS       = 
"sysml.gpu.availableGPUs"; // String to specify which GPUs to use (a range, all 
GPUs, comma separated list or a specific GPU)
        public static final String SYNCHRONIZE_GPU      = 
"sysml.gpu.sync.postProcess"; // boolean: whether to synchronize GPUs after 
every instruction 
        public static final String EAGER_CUDA_FREE              = 
"sysml.gpu.eager.cudaFree"; // boolean: whether to perform eager CUDA free on 
rmvar
        public static final String GPU_EVICTION_POLICY  = 
"sysml.gpu.eviction.policy"; // string: can be lru, lfu, min_evict
+       
        // Fraction of available memory to use. The available memory is 
computed when the GPUContext is created
        // to handle the tradeoff on calling cudaMemGetInfo too often.
        public static final String GPU_MEMORY_UTILIZATION_FACTOR = 
"sysml.gpu.memory.util.factor";
        public static final String FLOATING_POINT_PRECISION = 
"sysml.floating.point.precision"; // String to specify the datatype to use 
internally: supported values are double, single
        public static final String PRINT_GPU_MEMORY_INFO = 
"sysml.gpu.print.memoryInfo";
-       
+
        // supported prefixes for custom map/reduce configurations
        public static final String PREFIX_MAPRED = "mapred";
        public static final String PREFIX_MAPREDUCE = "mapreduce";
@@ -135,13 +135,13 @@ public class DMLConfig
                _defaultVals.put(NATIVE_BLAS,            "none" );
                _defaultVals.put(NATIVE_BLAS_DIR,        "none" );
                _defaultVals.put(EXTRA_FINEGRAINED_STATS,"false" );
+               _defaultVals.put(PRINT_GPU_MEMORY_INFO,  "false" );
                _defaultVals.put(STATS_MAX_WRAP_LEN,     "30" );
                _defaultVals.put(GPU_MEMORY_UTILIZATION_FACTOR,      "0.9" );
                _defaultVals.put(AVAILABLE_GPUS,         "-1");
                _defaultVals.put(GPU_EVICTION_POLICY,    "align_memory");
                _defaultVals.put(SYNCHRONIZE_GPU,        "false" );
                _defaultVals.put(CACHING_BUFFER_SIZE,    "0.15" );
-               _defaultVals.put(SYNCHRONIZE_GPU,        "true" );
                _defaultVals.put(EAGER_CUDA_FREE,        "false" );
                _defaultVals.put(FLOATING_POINT_PRECISION,               
"double" );
                _defaultVals.put(PRINT_GPU_MEMORY_INFO,  "false");
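
Note on the second DMLConfig hunk above: SYNCHRONIZE_GPU was previously
registered twice, first with "false" and then with "true", and this commit
removes the second call. Assuming _defaultVals is a java.util.HashMap filled in
the order shown (its declaration is not part of this diff), the later put would
have won, so the effective default likely changes from "true" to "false". A
minimal sketch of that behaviour, with a hypothetical class name:

```java
import java.util.HashMap;
import java.util.Map;

class DefaultValsSketch {
    // Key strings copied from the DMLConfig constants shown in the diff.
    static final String SYNCHRONIZE_GPU = "sysml.gpu.sync.postProcess";
    static final String CACHING_BUFFER_SIZE = "sysml.caching.bufferSize";

    // Registration order before the commit: the second put overwrites the first.
    static Map<String, String> before() {
        Map<String, String> vals = new HashMap<>();
        vals.put(SYNCHRONIZE_GPU, "false");
        vals.put(CACHING_BUFFER_SIZE, "0.15");
        vals.put(SYNCHRONIZE_GPU, "true");
        return vals;
    }

    // Registration order after the commit: only the "false" entry remains.
    static Map<String, String> after() {
        Map<String, String> vals = new HashMap<>();
        vals.put(SYNCHRONIZE_GPU, "false");
        vals.put(CACHING_BUFFER_SIZE, "0.15");
        return vals;
    }

    public static void main(String[] args) {
        System.out.println("before: " + before().get(SYNCHRONIZE_GPU));
        System.out.println("after:  " + after().get(SYNCHRONIZE_GPU));
    }
}
```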

http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/java/org/apache/sysml/hops/FunctionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java 
b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index 3631f00..ec2fda8 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -168,6 +168,10 @@ public class FunctionOp extends Hop
                                long outputValues = 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 1, 1.0);
                                return outputVectors+outputValues; 
                        }
+                       else if ( getFunctionName().equalsIgnoreCase("lstm") || 
getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ) {
+                               // TODO: To allow for initial version to always 
run on the GPU
+                               return 0; 
+                       }
                        else if ( getFunctionName().equalsIgnoreCase("svd") ) {
                                long outputU = 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), 
getOutputs().get(0).getDim2(), 1.0);
                                long outputSigma = 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 
getOutputs().get(1).getDim2(), 1.0);
@@ -198,6 +202,10 @@ public class FunctionOp extends Hop
                                return 
OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 
getInput().get(0).getDim2(), 1.0) 
                                                + 
3*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 1, 
1.0); 
                        }
+                       else if ( getFunctionName().equalsIgnoreCase("lstm") || 
getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
+                               // TODO: To allow for initial version to always 
run on the GPU
+                               return 0; 
+                       }
                        else if ( getFunctionName().equalsIgnoreCase("svd")) {
                                double interOutput = 
OptimizerUtils.estimateSizeExactSparsity(1, getInput().get(0).getDim2(), 1.0);
                                return interOutput;
@@ -215,7 +223,10 @@ public class FunctionOp extends Hop
        
        @Override
        public boolean isGPUEnabled() {
-               return false;
+               if(getFunctionName().equalsIgnoreCase("lstm") || 
getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) 
+                       return true;
+               else
+                       return false;
        }
        
        @Override
@@ -253,6 +264,7 @@ public class FunctionOp extends Hop
        protected ExecType optFindExecType() 
        {
                checkAndSetForcedPlatform();
+               ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? 
ExecType.SPARK : ExecType.MR;
                
                if ( getFunctionType() == FunctionType.MULTIRETURN_BUILTIN ) {
                        
@@ -261,7 +273,17 @@ public class FunctionOp extends Hop
                                _etype = ((_etypeForced==ExecType.SPARK 
                                        || (getMemEstimate() >= 
OptimizerUtils.getLocalMemBudget()
                                                && 
OptimizerUtils.isSparkExecutionMode())) ? ExecType.SPARK : ExecType.CP);
-                       }       
+                       }
+                       else if( getFunctionName().equalsIgnoreCase("lstm") || 
getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
+//                             if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
+//                                     _etype = findExecTypeByMemEstimate();
+//                             }
+//                             else {
+//                                     _etype = ExecType.CP;
+//                             }
+//                             _etype = _etype == REMOTE ?  ExecType.CP : 
_etype; // lstm not supported on Spark
+                               _etype = ExecType.GPU;
+                       }
                        else {
                                // Since the memory estimate is only 
conservative, do not throw
                                // exception if the estimated memory is larger 
than the budget
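
Note on the FunctionOp hunks above: the same three-name test (lstm,
batch_norm2d, batch_norm2d_backward) is repeated in four places. One possible
way to keep those call sites in sync is a small lookup helper, sketched below
outside the real class. GpuOnlyBuiltins and isGpuOnlyBuiltin are hypothetical
names and are not part of this commit.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;

class GpuOnlyBuiltins {
    // Function names taken from the equalsIgnoreCase checks in the diff.
    private static final Set<String> NAMES = new HashSet<>(
        Arrays.asList("lstm", "batch_norm2d", "batch_norm2d_backward"));

    // Case-insensitive membership test; for these ASCII names it matches
    // what the chained equalsIgnoreCase calls would return.
    static boolean isGpuOnlyBuiltin(String functionName) {
        return functionName != null
            && NAMES.contains(functionName.toLowerCase(Locale.ROOT));
    }

    public static void main(String[] args) {
        System.out.println(isGpuOnlyBuiltin("LSTM"));
        System.out.println(isGpuOnlyBuiltin("svd"));
    }
}
```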

http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/java/org/apache/sysml/hops/ReorgOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ReorgOp.java 
b/src/main/java/org/apache/sysml/hops/ReorgOp.java
index 70fb797..d01ed09 100644
--- a/src/main/java/org/apache/sysml/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysml/hops/ReorgOp.java
@@ -27,7 +27,6 @@ import org.apache.sysml.hops.rewrite.HopRewriteUtils;
 import org.apache.sysml.lops.Aggregate;
 import org.apache.sysml.lops.Group;
 import org.apache.sysml.lops.Lop;
-import org.apache.sysml.lops.LopsException;
 import org.apache.sysml.lops.SortKeys;
 import org.apache.sysml.lops.Transform;
 import org.apache.sysml.lops.LopProperties.ExecType;
@@ -132,18 +131,18 @@ public class ReorgOp extends Hop implements 
MultiThreadedHop
                        return false;
                switch( op ) {
                        case TRANS: {
-                               Lop lin;
-                               try {
-                                       lin = getInput().get(0).constructLops();
-                               } catch (HopsException | LopsException e) {
-                                       throw new RuntimeException("Unable to 
create child lop", e);
+                               if( getDim1()==1 && getDim2()==1 ) {
+                                       return false; //if input of size 1x1, 
avoid unnecessary transpose
                                }
-                               if( lin instanceof Transform && 
((Transform)lin).getOperationType()==OperationTypes.Transpose )
+                               else if( getInput().get(0) instanceof ReorgOp 
&&  ((ReorgOp) getInput().get(0)).getOp() == ReOrgOp.TRANS) {
+                                       // The following check causes a 
stack overflow:
+                                       // lin = 
getInput().get(0).constructLops();
+                                       // lin instanceof Transform && 
((Transform)lin).getOperationType()==OperationTypes.Transpose
                                        return false; //if input is already a 
transpose, avoid redundant transpose ops
-                               else if( getDim1()==1 && getDim2()==1 )
-                                       return false; //if input of size 1x1, 
avoid unnecessary transpose
-                               else
+                               }
+                               else {
                                        return true;
+                               }
                        }
                        case DIAG:
                        case REV:
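
Note on the ReorgOp hunk above: the TRANS case no longer builds the child lop
to find out whether the input is a transpose. It inspects the input hop
directly, which the added comment says avoids a stack overflow. A toy sketch of
that decision order follows, using a simplified node type. Node, Kind and
transposeCheck are hypothetical stand-ins for Hop, ReOrgOp and the enclosing
method, whose name is not shown in the hunk.

```java
class TransposeCheckSketch {
    enum Kind { DATA, TRANS }

    // Minimal stand-in for a hop: its kind, its single input and its dimensions.
    static final class Node {
        final Kind kind;
        final Node input;
        final long rows;
        final long cols;

        Node(Kind kind, Node input, long rows, long cols) {
            this.kind = kind;
            this.input = input;
            this.rows = rows;
            this.cols = cols;
        }
    }

    // Mirrors the order of the new checks for a TRANS node: 1x1 first, then
    // "input is already a transpose", using only fields of the node graph
    // (nothing is constructed, so the check cannot recurse).
    static boolean transposeCheck(Node n) {
        if (n.rows == 1 && n.cols == 1) {
            return false;
        } else if (n.input != null && n.input.kind == Kind.TRANS) {
            return false;
        } else {
            return true;
        }
    }

    public static void main(String[] args) {
        Node x = new Node(Kind.DATA, null, 4, 3);
        Node tx = new Node(Kind.TRANS, x, 3, 4);
        Node ttx = new Node(Kind.TRANS, tx, 4, 3);
        System.out.println(transposeCheck(tx) + " " + transposeCheck(ttx));
    }
}
```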

http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index ea51bd1..8a0a6d8 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -89,6 +89,18 @@ public class BuiltinFunctionExpression extends DataIdentifier
        public Expression getThirdExpr() {
                return (_args.length >= 3 ? _args[2] : null);
        }
+       
+       public Expression getFourthExpr() {
+               return (_args.length >= 4 ? _args[3] : null);
+       }
+       
+       public Expression getFifthExpr() {
+               return (_args.length >= 5 ? _args[4] : null);
+       }
+       
+       public Expression getSixthExpr() {
+               return (_args.length >= 6 ? _args[5] : null);
+       }
 
        public Expression[] getAllExpr(){
                return _args;
@@ -106,17 +118,14 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                }
                
                this.getFirstExpr().validateExpression(ids, constVars, 
conditional);
-               if (getSecondExpr() != null){
-                       if (this.getSecondExpr() instanceof 
FunctionCallIdentifier){
-                               raiseValidateError("UDF function call not 
supported as parameter to built-in function call", false);
-                       }
-                       getSecondExpr().validateExpression(ids, constVars, 
conditional);
-               }
-               if (getThirdExpr() != null) {
-                       if (this.getThirdExpr() instanceof 
FunctionCallIdentifier){
-                               raiseValidateError("UDF function call not 
supported as parameter to built-in function call", false);
+               Expression [] expr = getAllExpr();
+               if(expr != null && expr.length > 1) {
+                       for(int i = 1; i < expr.length; i++) {
+                               if (expr[i] instanceof FunctionCallIdentifier){
+                                       raiseValidateError("UDF function call 
not supported as parameter to built-in function call", false);
+                               }
+                               expr[i].validateExpression(ids, constVars, 
conditional);
                        }
-                       getThirdExpr().validateExpression(ids, constVars, 
conditional);
                }
                _outputs = new Identifier[stmt.getTargetList().size()];
                int count = 0;
@@ -189,6 +198,97 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        
                        break;
 
+               case LSTM:
+               {
+                       // X,  W, bias, out0, c0, return_sequences
+                       checkNumParameters(6);
+                       checkMatrixParam(getFirstExpr());
+                       checkMatrixParam(getSecondExpr());
+                       checkMatrixParam(getThirdExpr());
+                       checkMatrixParam(getFourthExpr());
+                       checkMatrixParam(getFifthExpr());
+                       
+                       // setup output properties
+                       if(getOutputs() == null || getOutputs().length != 3) {
+                               int numOutputs = getOutputs() == null ? 0 : 
getOutputs().length;
+                               raiseValidateError("The builtin function lstm 
has three outputs, but instead found: " + numOutputs, conditional);
+                       }
+                       DataIdentifier out = (DataIdentifier) getOutputs()[0];
+                       DataIdentifier cy = (DataIdentifier) getOutputs()[1];
+                       DataIdentifier reserveSpace = (DataIdentifier) 
getOutputs()[2];
+                       
+                       // Output1 - out: If `return_sequences` is True, 
outputs for all timesteps, else outputs for the final timestep.
+                       out.setDataType(DataType.MATRIX);
+                       out.setValueType(ValueType.DOUBLE);
+                       out.setDimensions(-1, -1);
+                       
out.setBlockDimensions(getFirstExpr().getOutput().getRowsInBlock(), 
getFirstExpr().getOutput().getColumnsInBlock());
+                       
+                       // Output2 - Cell state for final timestep.
+                       cy.setDataType(DataType.MATRIX);
+                       cy.setValueType(ValueType.DOUBLE);
+                       cy.setDimensions(getExpr(4).getOutput().getDim1(), 
getExpr(4).getOutput().getDim2());
+                       
cy.setBlockDimensions(getExpr(4).getOutput().getRowsInBlock(), 
getExpr(4).getOutput().getColumnsInBlock());
+                       
+                       // Output3 - reserve space.
+                       reserveSpace.setDataType(DataType.MATRIX);
+                       reserveSpace.setValueType(ValueType.DOUBLE);
+                       reserveSpace.setDimensions(-1, -1);
+                       
reserveSpace.setBlockDimensions(getFirstExpr().getOutput().getRowsInBlock(), 
getFirstExpr().getOutput().getColumnsInBlock());
+                       
+                       break;
+               }
+               case BATCH_NORM2D:
+               {
+                       // Input: image, scale, bias, runningMean, runningVar, 
mode, epsilon, exponentialAverageFactor
+                       checkNumParameters(8);
+                       checkMatrixParam(getFirstExpr());
+                       checkMatrixParam(getSecondExpr());
+                       checkMatrixParam(getThirdExpr());
+                       checkMatrixParam(getFourthExpr());
+                       checkMatrixParam(getFifthExpr());
+                       
+                       // Output: ret, retRunningMean, retRunningVar, 
resultSaveMean, resultSaveInvVariance
+                       // setup output properties
+                       if(getOutputs().length != 5)
+                               raiseValidateError("batch_norm2d has 5 
outputs", false);
+                        
+                       DataIdentifier ret = (DataIdentifier) getOutputs()[0];
+                       DataIdentifier retRunningMean = (DataIdentifier) 
getOutputs()[1];
+                       DataIdentifier retRunningVar = (DataIdentifier) 
getOutputs()[2];
+                       DataIdentifier resultSaveMean = (DataIdentifier) 
getOutputs()[3];
+                       DataIdentifier resultSaveInvVariance = (DataIdentifier) 
getOutputs()[4];
+                       
+                       setDimensions(ret, getFirstExpr());
+                       setDimensions(retRunningMean, getFourthExpr());
+                       setDimensions(retRunningVar, getFourthExpr());
+                       setDimensions(resultSaveMean, getFourthExpr());
+                       setDimensions(resultSaveInvVariance, getFourthExpr());
+                       break;
+               }
+               case BATCH_NORM2D_BACKWARD:
+               {
+                       // Input: image, dout, scale, epsilon, savedMean, 
savedInvVariance
+                       checkNumParameters(6);
+                       checkMatrixParam(getFirstExpr());
+                       checkMatrixParam(getSecondExpr());
+                       checkMatrixParam(getThirdExpr());
+                       checkMatrixParam(getFifthExpr());
+                       checkMatrixParam(getSixthExpr());
+                       
+                       // Output: dX, dScale, dBias 
+                       // setup output properties
+                       if(getOutputs().length != 3)
+                               raiseValidateError("batch_norm2d_backward has 3 outputs", false);
+                       
+                       DataIdentifier dX = (DataIdentifier) getOutputs()[0];
+                       DataIdentifier dScale = (DataIdentifier) getOutputs()[1];
+                       DataIdentifier dBias = (DataIdentifier) getOutputs()[2];
+                       
+                       setDimensions(dX, getFirstExpr());
+                       setDimensions(dScale, getThirdExpr());
+                       setDimensions(dBias, getThirdExpr());
+                       break;
+               }
                case EIGEN:
                        checkNumParameters(1);
                        checkMatrixParam(getFirstExpr());
@@ -251,6 +351,13 @@ public class BuiltinFunctionExpression extends DataIdentifier
                }
        }
        
+       private static void setDimensions(DataIdentifier out, Expression exp) {
+               out.setDataType(DataType.MATRIX);
+               out.setValueType(ValueType.DOUBLE);
+               out.setDimensions(exp.getOutput().getDim1(), exp.getOutput().getDim2());
+               out.setBlockDimensions(exp.getOutput().getRowsInBlock(), exp.getOutput().getColumnsInBlock());
+       }
+       
        private static ArrayList<ParameterExpression> orderConvolutionParams(ArrayList<ParameterExpression> paramExpression, int skip) {
                ArrayList<ParameterExpression> newParams = new ArrayList<>();
 
@@ -1300,7 +1407,8 @@ public class BuiltinFunctionExpression extends DataIdentifier
                        else {
                                // always unconditional (because unsupported operation)
                                BuiltinFunctionOp op = getOpCode();
-                               if( op==BuiltinFunctionOp.EIGEN || op==BuiltinFunctionOp.LU || op==BuiltinFunctionOp.QR || op==BuiltinFunctionOp.SVD)
+                               if( op==BuiltinFunctionOp.EIGEN || op==BuiltinFunctionOp.LU || op==BuiltinFunctionOp.QR || op==BuiltinFunctionOp.SVD 
+                                               || op==BuiltinFunctionOp.LSTM || op==BuiltinFunctionOp.BATCH_NORM2D || op==BuiltinFunctionOp.BATCH_NORM2D_BACKWARD)
                                        raiseValidateError("Function "+op+" needs to be called with multi-return assignment.", false, LanguageErrorCodes.INVALID_PARAMETERS);
                                else
                                        raiseValidateError("Unsupported function "+op, false, LanguageErrorCodes.INVALID_PARAMETERS);
@@ -1362,6 +1470,9 @@ public class BuiltinFunctionExpression extends DataIdentifier
                case QR:
                case LU:
                case EIGEN:
+               case LSTM:
+               case BATCH_NORM2D:
+               case BATCH_NORM2D_BACKWARD:
                case SVD:
                        return true;
                default:
@@ -1503,7 +1614,7 @@ public class BuiltinFunctionExpression extends DataIdentifier
        protected void checkMatrixParam(Expression e) {
                if (e.getOutput().getDataType() != DataType.MATRIX) {
                        raiseValidateError(
-                                       "Expecting matrix argument for function " + this.getOpCode().toString().toLowerCase() + "().",
+                                       "Expected " + e.getText() + " to be a matrix argument for function " + this.getOpCode().toString().toLowerCase() + "().",
                                        false);
                }
        }
@@ -1771,6 +1882,12 @@ public class BuiltinFunctionExpression extends DataIdentifier
                        bifop = Expression.BuiltinFunctionOp.LU;
                else if (functionName.equals("eigen"))
                        bifop = Expression.BuiltinFunctionOp.EIGEN;
+               else if (functionName.equals("lstm"))
+                       bifop = Expression.BuiltinFunctionOp.LSTM;
+               else if (functionName.equals("batch_norm2d"))
+                       bifop = Expression.BuiltinFunctionOp.BATCH_NORM2D;
+               else if (functionName.equals("batch_norm2d_backward"))
+                       bifop = Expression.BuiltinFunctionOp.BATCH_NORM2D_BACKWARD;
                else if (functionName.equals("conv2d"))
                         bifop = Expression.BuiltinFunctionOp.CONV2D;
                else if (functionName.equals("bias_add"))
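
Note on the validation hunks above: for batch_norm2d_backward, dX is given the dimensions of the first argument (image), while dScale and dBias are given the dimensions of the third argument (scale); the fourth argument (epsilon) is not checked as a matrix. The rule can be sketched in isolation as follows. This is only an illustration, not SystemML code: the Shape and BatchNormBackwardShapes classes and the example sizes are hypothetical stand-ins for DataIdentifier and its dimension metadata.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in for the (rows, cols) metadata copied by setDimensions.
final class Shape {
    final long rows;
    final long cols;

    Shape(long rows, long cols) {
        this.rows = rows;
        this.cols = cols;
    }

    @Override
    public String toString() {
        return rows + "x" + cols;
    }
}

final class BatchNormBackwardShapes {
    // Mirrors the rule in the diff: dX takes the shape of the first argument
    // (image); dScale and dBias take the shape of the third argument (scale).
    static Map<String, Shape> infer(Shape image, Shape scale) {
        Map<String, Shape> out = new LinkedHashMap<>();
        out.put("dX", image);
        out.put("dScale", scale);
        out.put("dBias", scale);
        return out;
    }

    public static void main(String[] args) {
        // Made-up sizes: 64 rows, one per image, and a 3x1 scale vector.
        Shape image = new Shape(64, 3 * 32 * 32);
        Shape scale = new Shape(3, 1);
        System.out.println(infer(image, scale));
    }
}
```

The same pattern applies to batch_norm2d in the hunk above, where ret follows the first argument and the four remaining outputs follow the fourth.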

http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index ff02012..fd0b47b 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -2212,10 +2212,12 @@ public class DMLTranslator
                // Construct Hops for all inputs
                ArrayList<Hop> inputs = new ArrayList<>();
                inputs.add( processExpression(source.getFirstExpr(), null, hops) );
-               if ( source.getSecondExpr() != null )
-                       inputs.add( processExpression(source.getSecondExpr(), null, hops) );
-               if ( source.getThirdExpr() != null )
-                       inputs.add( processExpression(source.getThirdExpr(), null, hops) );
+               Expression[] expr = source.getAllExpr();
+               if(expr != null && expr.length > 1) {
+                       for(int i = 1; i < expr.length; i++) {
+                               inputs.add( processExpression(expr[i], null, hops) );
+                       }
+               }
                
                FunctionType ftype = FunctionType.MULTIRETURN_BUILTIN;
                String nameSpace = DMLProgram.INTERNAL_NAMESPACE;
@@ -2230,6 +2232,9 @@ public class DMLTranslator
                case QR:
                case LU:
                case EIGEN:
+               case LSTM:
+               case BATCH_NORM2D:
+               case BATCH_NORM2D_BACKWARD:
                case SVD:
                        
                        // Number of outputs = size of targetList = #of identifiers in source.getOutputs
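
Note on the first DMLTranslator hunk: the removed lines forwarded at most the first three arguments of a multi-return builtin, which would not cover a six-argument call such as batch_norm2d_backward; the added loop forwards every argument. A simplified sketch of the difference is below. It assumes plain strings in place of Expression and Hop objects, and the CollectInputs class and its method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

final class CollectInputs {
    // Approximates the removed code: the first argument is always taken,
    // then only the second and third if they are present.
    static List<String> firstThreeOnly(String[] expr) {
        List<String> inputs = new ArrayList<>();
        for (int i = 0; i < expr.length && i < 3; i++) {
            inputs.add(expr[i]);
        }
        return inputs;
    }

    // Approximates the added code: the first argument, then all the rest.
    static List<String> all(String[] expr) {
        List<String> inputs = new ArrayList<>();
        inputs.add(expr[0]);
        if (expr.length > 1) {
            for (int i = 1; i < expr.length; i++) {
                inputs.add(expr[i]);
            }
        }
        return inputs;
    }

    public static void main(String[] args) {
        // Argument names taken from the batch_norm2d_backward input comment.
        String[] expr = {"image", "dout", "scale", "epsilon", "savedMean", "savedInvVariance"};
        System.out.println("old: " + firstThreeOnly(expr));
        System.out.println("new: " + all(expr));
    }
}
```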

http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index 9a6ea64..4355bd5 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -92,6 +92,7 @@ public abstract class Expression implements ParseInfo
                EXISTS,
                CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, BIAS_ADD, BIAS_MULTIPLY,
                MAX_POOL, AVG_POOL, MAX_POOL_BACKWARD, AVG_POOL_BACKWARD,
+               LSTM, BATCH_NORM2D, BATCH_NORM2D_BACKWARD,
                EXP,
                FLOOR,
                IFELSE,
