[SYSTEMML-445] Added builtin functions for efficient computation of batch normalization and LSTM layers.
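For context, here is a minimal DML sketch of how the new builtins are invoked (the function signatures match the updated layer scripts and documentation in this diff; the sizes and hyperparameters are hypothetical, and a GPU backend with CuDNN is assumed, since CP and Spark support arrive in later commits):

    # Hypothetical sizes: batch N, timesteps T, input features D, neurons M
    N = 32; T = 10; D = 16; M = 64
    X = rand(rows=N, cols=T*D)            # inputs, shape (N, T*D)
    W = rand(rows=D+M, cols=4*M)          # weights, shape (D+M, 4M)
    b = matrix(0, rows=1, cols=4*M)       # biases, shape (1, 4M)
    out0 = matrix(0, rows=N, cols=M)      # output of previous timestep, shape (N, M)
    c0 = matrix(0, rows=N, cols=M)        # initial cell state, shape (N, M)
    # return_sequences=TRUE: out has shape (N, T*M); reserveSpace is an opaque
    # row vector whose size is determined at runtime
    [out, c, reserveSpace] = lstm(X, W, b, out0, c0, TRUE)

    # batch_norm2d forward in 'train' mode, then backward using the cached statistics
    C = 3; Hin = 8; Win = 8
    img = rand(rows=N, cols=C*Hin*Win)    # images in NCHW layout, shape (N, C*Hin*Win)
    gamma = matrix(1, rows=C, cols=1); beta = matrix(0, rows=C, cols=1)
    ema_mean = matrix(0, rows=C, cols=1); ema_var = matrix(1, rows=C, cols=1)
    [outBN, ema_mean_upd, ema_var_upd, cache_mean, cache_inv_var] = batch_norm2d(img, gamma, beta, ema_mean, ema_var, "train", 1e-5, 0.9)
    dout = rand(rows=N, cols=C*Hin*Win)   # gradient from upstream
    [dX, dgamma, dbeta] = batch_norm2d_backward(img, dout, gamma, 1e-5, cache_mean, cache_inv_var)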
- The following builtin functions are added: lstm, batch_norm2d and batch_norm2d_backward.
- The DML language documentation and the NN layers are also updated.
- Since the builtin function for lstm backward data/weights is not added in this commit, the nn layer for lstm is not updated. Instead, a new lstm_staging.dml is added, which will eventually replace lstm.dml.
- The above builtin functions are only supported on GPU via CuDNN. The CP and Spark implementations will be added in subsequent commits.

Closes #773.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/276065f9
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/276065f9
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/276065f9

Branch: refs/heads/master
Commit: 276065f93e5051a19d1f86d76b2844d96f7543b3
Parents: cba082e
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Fri Jun 1 10:49:46 2018 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Fri Jun 1 10:49:46 2018 -0700

----------------------------------------------------------------------
 docs/dml-language-reference.md                  |   22 +-
 scripts/nn/layers/batch_norm2d.dml              |   56 +-
 scripts/nn/layers/batch_norm2d_old.dml          |  200 ++++
 scripts/nn/layers/lstm_staging.dml              |  216 +++++++
 src/main/cpp/kernels/SystemML.cu                |  104 ++++
 src/main/cpp/kernels/SystemML.ptx               |  622 ++++++++++++++++---
 .../sysml/api/mlcontext/ScriptExecutor.java     |    2 +-
 .../java/org/apache/sysml/conf/DMLConfig.java   |    6 +-
 .../java/org/apache/sysml/hops/FunctionOp.java  |   26 +-
 .../java/org/apache/sysml/hops/ReorgOp.java     |   19 +-
 .../sysml/parser/BuiltinFunctionExpression.java |  141 ++++-
 .../org/apache/sysml/parser/DMLTranslator.java  |   13 +-
 .../org/apache/sysml/parser/Expression.java     |    1 +
 .../controlprogram/caching/CacheableData.java   |    2 +-
 .../instructions/GPUInstructionParser.java      |    7 +-
 .../gpu/ConvolutionGPUInstruction.java          |  259 +++++++-
 .../instructions/gpu/context/CSRPointer.java    |    2 +-
 .../instructions/gpu/context/GPUContext.java    |    2 +-
 .../context/GPULazyCudaFreeMemoryManager.java   |    2 +-
 .../gpu/context/GPUMatrixMemoryManager.java     |    2 +-
 .../gpu/context/GPUMemoryManager.java           |    2 +-
 .../instructions/gpu/context/GPUObject.java     |    2 +-
 .../runtime/matrix/data/LibMatrixCUDA.java      |    5 +-
 .../runtime/matrix/data/LibMatrixCuDNN.java     |  273 +++++++-
 .../LibMatrixCuDNNConvolutionAlgorithm.java     |    2 +-
 .../data/LibMatrixCuDNNInputRowFetcher.java     |    2 +-
 .../matrix/data/LibMatrixCuDNNRnnAlgorithm.java |  283 +++++++++
 .../runtime/matrix/data/LibMatrixCuMatMult.java |    2 +-
 .../SinglePrecisionCudaSupportFunctions.java    |    2 +-
 .../org/apache/sysml/utils/GPUStatistics.java   |    2 +-
 .../org/apache/sysml/api/dl/CaffeLayer.scala    |    2 +-
 31 files changed, 2067 insertions(+), 214 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/docs/dml-language-reference.md
----------------------------------------------------------------------
diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md
index b4ed9c8..3212806 100644
--- a/docs/dml-language-reference.md
+++ b/docs/dml-language-reference.md
@@ -1511,16 +1511,18 @@
 The images are assumed to be stored in NCHW format, where N = batch size, C = #channels, H = height of image and W = width of image.
 Hence, the images are internally represented as a matrix with dimension (N, C * H * W).
-| Function name | Input matrices | Dimension of first input matrix | Dimension of second input matrix (if applicable) | Dimension of output matrix | Input Parameters | Notes |
-|---|---|---|---|---|---|---|
-| conv2d | input, filter | [batch_size X num_channels* height_image* width_image] | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Performs 2D convolution operation |
-| conv2d_backward_filter | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels_out* height_out* width_out] | [num_filters X num_channels* height_filter* width_filter] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Computes the gradients wrt filter of 2D convolution |
-| conv2d_backward_data | filter, dout | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out] | [batch_size X num_channels* height_image* width_image] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Computes the gradients wrt input of 2D convolution |
-| max_pool, avg_pool | input | [batch_size X num_channels* height_image* width_image] | | [batch_size X num_channels* height_out* width_out] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool] | Performs max/average pooling operation |
-| max_pool_backward, avg_pool_backward | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_out* width_out] | [batch_size X num_channels* height_image* width_image] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool] | Computes the gradients wrt input of 2D max pooling, average pooling |
-| bias_add | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Adds the bias (row vector of size num_channels) to input with the given num_channels |
-| bias_multiply | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Multiplies the bias (row vector of size num_channels) to input with the given num_channels |
-
+| Function name | Input matrices | Dimension of first input matrix | Dimension of second input matrix (if applicable) | Dimension of (first) output matrix | Input Parameters | Notes |
+|---|---|---|---|---|---|---|
+| conv2d | input, filter | [batch_size X num_channels* height_image* width_image] | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Performs 2D convolution operation |
+| conv2d_backward_filter | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels_out* height_out* width_out] | [num_filters X num_channels* height_filter* width_filter] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Computes the gradients wrt filter of 2D convolution |
+| conv2d_backward_data | filter, dout | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out] | [batch_size X num_channels* height_image* width_image] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Computes the gradients wrt input of 2D convolution |
+| max_pool, avg_pool | input | [batch_size X num_channels* height_image* width_image] | | [batch_size X num_channels* height_out* width_out] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool] | Performs max/average pooling operation |
+| max_pool_backward, avg_pool_backward | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_out* width_out] | [batch_size X num_channels* height_image* width_image] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool] | Computes the gradients wrt input of 2D max pooling, average pooling |
+| bias_add | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Adds the bias (row vector of size num_channels) to input with the given num_channels |
+| bias_multiply | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Multiplies the bias (row vector of size num_channels) to input with the given num_channels |
+| lstm | X, W, bias, out0, c0 | [batch_size X seq_length*num_features] | [num_features+hidden_size X 4*hidden_size] | [batch_size X seq_length*hidden_size] if return_sequences else [batch_size X hidden_size] | return_sequences | Performs computation for single-layer unidirectional LSTM (outputs: out, carryOut, reserveSpace) |
+| batch_norm2d | input | [batch_size X num_channels* height_image* width_image] | | [batch_size X num_channels* height_image* width_image] | scale, shift, exponentialMovingAverage_Mean, exponentialMovingAverage_Variance, mode, epsilon, momentum | Performs batch normalization operation (outputs: updated exponential moving average mean and variance, cache of the batch mean and variance) |
+| batch_norm2d_backward | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | scale, epsilon, cache_mean (from forward), cache_inv_var (from forward) | Computes the backpropagation error for batch normalization operation |
 
 Examples:


http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/scripts/nn/layers/batch_norm2d.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/batch_norm2d.dml b/scripts/nn/layers/batch_norm2d.dml
index 8a8555f..2a98857 100644
--- a/scripts/nn/layers/batch_norm2d.dml
+++ b/scripts/nn/layers/batch_norm2d.dml
@@ -83,41 +83,8 @@ forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
  * - cache_inv_var: Cache of the inverse variance, of shape (C, 1).
  *   Note: This is used for performance during training.
  */
-  N = nrow(X)
-
-  if (mode == 'train') {
-    # Compute channel-wise mean and variance
-    # Since we don't have tensors, we will compute the means and variances in a piece-wise fashion.
-    #  - mean of total group is mean of subgroup means
-    #  - variance is the mean of the subgroup variances + the variance of the subgroup means
-    subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
-    subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)  # uncorrected variances
-    mean = rowMeans(subgrp_means)  # shape (C, 1)
-    var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))  # shape (C, 1)
-    # Update moving averages
-    ema_mean_upd = mu*ema_mean + (1-mu)*mean
-    ema_var_upd = mu*ema_var + (1-mu)*var
-  }
-  else {
-    # Use moving averages of mean and variance during testing
-    mean = ema_mean
-    var = ema_var
-    ema_mean_upd = ema_mean
-    ema_var_upd = ema_var
-  }
-
-  # Save variable for backward pass
-  cache_mean = mean
-  cache_inv_var = 1/sqrt(var+epsilon)
-
-  # Normalize, shift, and scale
-  # norm = (X-mean)*(var+epsilon)^(-1/2)
-  #      = (X-mean) / sqrt(var+epsilon)
-  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
-  norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
-  # out = norm*gamma + beta
-  scaled = bias_multiply(norm, gamma)  # shape (N, C*Hin*Win)
-  out = bias_add(scaled, beta)  # shape (N, C*Hin*Win)
+  out = X; ema_mean_upd = ema_mean; ema_var_upd = ema_var; cache_mean = ema_mean; cache_inv_var = ema_var
+  [out, ema_mean_upd, ema_var_upd, cache_mean, cache_inv_var] = batch_norm2d(X, gamma, beta, ema_mean, ema_var, mode, epsilon, mu)
 }
 
 backward = function(matrix[double] dout,
@@ -152,24 +119,9 @@ backward = function(matrix[double] dout,
  * - dbeta: Gradient wrt `b`, of shape (C, 1).
* */ - N = nrow(X) - mean = cache_mean - centered = bias_add(X, -mean) # shape (N, C*Hin*Win) - norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win) # Compute gradients during training - dgamma = util::channel_sums(dout*norm, C, Hin, Win) # shape (C, 1) - dbeta = util::channel_sums(dout, C, Hin, Win) # shape (C, 1) - dnorm = bias_multiply(dout, gamma) # shape (N, C*Hin*Win) - dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm, - C, Hin, Win) # shape (C, 1) - dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win) - dmean_var_branch = util::channel_sums((-2/(N*Hin*Win)) * centered, C, Hin, Win) - dmean_var_branch = dmean_var_branch * dvar # we can't use a function within an expression yet - dmean = dmean_norm_branch + dmean_var_branch # shape (C, 1) - dX_norm_branch = bias_multiply(dnorm, cache_inv_var) - dX_mean_branch = (1/(N*Hin*Win)) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean) - dX_var_branch = (2/(N*Hin*Win)) * bias_multiply(centered, dvar) - dX = dX_norm_branch + dX_mean_branch + dX_var_branch # shape (N, C*Hin*Win) + dX = X; dgamma = gamma; dbeta = gamma; + [dX, dgamma, dbeta] = batch_norm2d_backward(X, dout, gamma, epsilon, cache_mean, cache_inv_var) } init = function(int C) http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/scripts/nn/layers/batch_norm2d_old.dml ---------------------------------------------------------------------- diff --git a/scripts/nn/layers/batch_norm2d_old.dml b/scripts/nn/layers/batch_norm2d_old.dml new file mode 100644 index 0000000..2aba2e6 --- /dev/null +++ b/scripts/nn/layers/batch_norm2d_old.dml @@ -0,0 +1,200 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +/* + * 2D (Spatial) Batch Normalization layer. + */ +source("nn/util.dml") as util + +forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta, + int C, int Hin, int Win, string mode, + matrix[double] ema_mean, matrix[double] ema_var, + double mu, double epsilon) + return (matrix[double] out, matrix[double] ema_mean_upd, matrix[double] ema_var_upd, + matrix[double] cache_mean, matrix[double] cache_inv_var) { + /* + * Computes the forward pass for a 2D (spatial) batch normalization + * layer. The input data has N examples, each represented as a 3D + * volume unrolled into a single vector. + * + * A spatial batch normalization layer uses the per-channel sample + * mean and per-channel uncorrected sample variance during training + * to normalize each channel of the input data. Additionally, it + * introduces learnable parameters (gamma, beta) to control the + * amount of normalization. 
+ * + * `y = ((x-mean) / sqrt(var+eps)) * gamma + beta` + * + * This implementation maintains exponential moving averages of the + * mean and variance during training for use during testing. + * + * Reference: + * - Batch Normalization: Accelerating Deep Network Training by + * Reducing Internal Covariate Shift, S. Ioffe & C. Szegedy, 2015 + * - https://arxiv.org/abs/1502.03167 + * + * Inputs: + * - X: Inputs, of shape (N, C*Hin*Win). + * - gamma: Scale parameters, of shape (C, 1). + * - beta: Shift parameters, of shape (C, 1). + * - C: Number of input channels (dimensionality of input depth). + * - Hin: Input height. + * - Win: Input width. + * - mode: 'train' or 'test' to indicate if the model is currently + * being trained or tested. During training, the current batch + * mean and variance will be used to normalize the inputs, while + * during testing, the exponential average of the mean and + * variance over all previous batches will be used. + * - ema_mean: Exponential moving average of the mean, of + * shape (C, 1). + * - ema_var: Exponential moving average of the variance, of + * shape (C, 1). + * - mu: Momentum value for moving averages. + * Typical values are in the range of [0.9, 0.999]. + * - epsilon: Smoothing term to avoid divide by zero errors. + * Typical values are in the range of [1e-5, 1e-3]. + * + * Outputs: + * - out: Outputs, of shape (N, C*Hin*Win). + * - ema_mean_upd: Updated exponential moving average of the mean, + * of shape (C, 1). + * - ema_var_upd: Updated exponential moving average of the variance, + * of shape (C, 1). + * - cache_mean: Cache of the batch mean, of shape (C, 1). + * Note: This is used for performance during training. + * - cache_inv_var: Cache of the inverse variance, of shape (C, 1). + * Note: This is used for performance during training. + */ + N = nrow(X) + + if (mode == 'train') { + # Compute channel-wise mean and variance + # Since we don't have tensors, we will compute the means and variances in a piece-wise fashion. + # - mean of total group is mean of subgroup means + # - variance is the mean of the subgroup variances + the variance of the subgroup means + subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win) + subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win) # uncorrected variances + mean = rowMeans(subgrp_means) # shape (C, 1) + var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) # shape (C, 1) + # Update moving averages + ema_mean_upd = mu*ema_mean + (1-mu)*mean + ema_var_upd = mu*ema_var + (1-mu)*var + } + else { + # Use moving averages of mean and variance during testing + mean = ema_mean + var = ema_var + ema_mean_upd = ema_mean + ema_var_upd = ema_var + } + + # Save variable for backward pass + cache_mean = mean + cache_inv_var = 1/sqrt(var+epsilon) + + # Normalize, shift, and scale + # norm = (X-mean)*(var+epsilon)^(-1/2) + # = (X-mean) / sqrt(var+epsilon) + centered = bias_add(X, -mean) # shape (N, C*Hin*Win) + norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win) + # out = norm*gamma + beta + scaled = bias_multiply(norm, gamma) # shape (N, C*Hin*Win) + out = bias_add(scaled, beta) # shape (N, C*Hin*Win) +} + +backward = function(matrix[double] dout, + matrix[double] cache_mean, matrix[double] cache_inv_var, + matrix[double] X, matrix[double] gamma, + int C, int Hin, int Win, double epsilon) + return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) { + /* + * Computes the backward pass for a 2D (spatial) batch normalization + * layer. 
+ * + * Inputs: + * - dout: Gradient wrt `out` from upstream, of shape (N, C*Hin*Win). + * - cache_mean: Cache of the batch mean from the forward pass, of + * shape (C, 1). Note: This is used for performance during + * training. + * - cache_inv_var: Cache of the inverse variance from the forward pass, + * of shape (C, 1). Note: This is used for performance during + * training. + * - X: Input data matrix to the forward pass, of + * shape (N, C*Hin*Win). + * - gamma: Scale parameters, of shape (C, 1). + * - C: Number of input channels (dimensionality of input depth). + * - Hin: Input height. + * - Win: Input width. + * - epsilon: Smoothing term to avoid divide by zero errors. + * Typical values are in the range of [1e-5, 1e-3]. + * + * Outputs: + * - dX: Gradient wrt `X`, of shape (N, C*Hin*Win). + * - dgamma: Gradient wrt `W`, of shape (C, 1). + * - dbeta: Gradient wrt `b`, of shape (C, 1). + * + */ + N = nrow(X) + mean = cache_mean + centered = bias_add(X, -mean) # shape (N, C*Hin*Win) + norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win) + # Compute gradients during training + dgamma = util::channel_sums(dout*norm, C, Hin, Win) # shape (C, 1) + dbeta = util::channel_sums(dout, C, Hin, Win) # shape (C, 1) + dnorm = bias_multiply(dout, gamma) # shape (N, C*Hin*Win) + dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm, + C, Hin, Win) # shape (C, 1) + dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win) + dmean_var_branch = util::channel_sums((-2/(N*Hin*Win)) * centered, C, Hin, Win) + dmean_var_branch = dmean_var_branch * dvar # we can't use a function within an expression yet + dmean = dmean_norm_branch + dmean_var_branch # shape (C, 1) + dX_norm_branch = bias_multiply(dnorm, cache_inv_var) + dX_mean_branch = (1/(N*Hin*Win)) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean) + dX_var_branch = (2/(N*Hin*Win)) * bias_multiply(centered, dvar) + dX = dX_norm_branch + dX_mean_branch + dX_var_branch # shape (N, C*Hin*Win) +} + +init = function(int C) + return (matrix[double] gamma, matrix[double] beta, + matrix[double] ema_mean, matrix[double] ema_var) { + /* + * Initialize the parameters of this layer. + * + * Note: This is just a convenience function, and parameters + * may be initialized manually if needed. + * + * Inputs: + * - C: Number of input channels (dimensionality of input depth). + * + * Outputs: + * - gamma: Scale parameters, of shape (C, 1). + * - beta: Shift parameters, of shape (C, 1). + * - ema_mean: Exponential moving average of the mean, of + * shape (C, 1). + * - ema_var: Exponential moving average of the variance, of + * shape (C, 1). + */ + gamma = matrix(1, rows=C, cols=1) + beta = matrix(0, rows=C, cols=1) + ema_mean = matrix(0, rows=C, cols=1) + ema_var = matrix(1, rows=C, cols=1) +} + http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/scripts/nn/layers/lstm_staging.dml ---------------------------------------------------------------------- diff --git a/scripts/nn/layers/lstm_staging.dml b/scripts/nn/layers/lstm_staging.dml new file mode 100644 index 0000000..d0949d9 --- /dev/null +++ b/scripts/nn/layers/lstm_staging.dml @@ -0,0 +1,216 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +/* + * LSTM layer. + */ +source("nn/layers/sigmoid.dml") as sigmoid +source("nn/layers/tanh.dml") as tanh + +forward = function(matrix[double] X, matrix[double] W, matrix[double] b, + boolean return_sequences, matrix[double] out0, matrix[double] c0) + return (matrix[double] out, matrix[double] c) { + /* + * Computes the forward pass for an LSTM layer with M neurons. + * The input data has N sequences of T examples, each with D features. + * + * In an LSTM, an internal cell state is maintained, additive + * interactions operate over the cell state at each timestep, and + * some amount of this cell state is exposed as output at each + * timestep. Additionally, the output of the previous timestep is fed + * back in as an additional input at the current timestep. + * + * Reference: + * - Long Short-Term Memory, Hochreiter, 1997 + * - http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf + * + * Inputs: + * - X: Inputs, of shape (N, T*D). + * - W: Weights, of shape (D+M, 4M). + * - b: Biases, of shape (1, 4M). + * - return_sequences: Whether to return `out` at all timesteps, + * or just for the final timestep. + * - out0: Outputs from previous timestep, of shape (N, M). + * Note: This is *optional* and could just be an empty matrix. + * - c0: Initial cell state, of shape (N, M). + * Note: This is *optional* and could just be an empty matrix. + * + * Outputs: + * - out: If `return_sequences` is True, outputs for all timesteps, + * of shape (N, T*M). Else, outputs for the final timestep, of + * shape (N, M). + * - c: Cell state for final timestep, of shape (N, M). + * - reserveSpace: reserveSpace to be passed to output (row-vector whose size is determined at runtime). + */ + [out, c, reserveSpace] = lstm(X, W, b, out0, c0, return_sequences) +} + +# TODO: +backward = function(matrix[double] dout, matrix[double] dc, + matrix[double] X, matrix[double] W, matrix[double] b, int T, int D, + boolean given_sequences, matrix[double] out0, matrix[double] c0, + matrix[double] cache_out, matrix[double] cache_c, matrix[double] cache_ifog) + return (matrix[double] dX, matrix[double] dW, matrix[double] db, + matrix[double] dout0, matrix[double] dc0) { + /* + * Computes the backward pass for an LSTM layer with M neurons. + * + * Inputs: + * - dout: Gradient wrt `out`. If `given_sequences` is `True`, + * contains gradients on outputs for all timesteps, of + * shape (N, T*M). Else, contains the gradient on the output + * for the final timestep, of shape (N, M). + * - dc: Gradient wrt `c` (from later in time), of shape (N, M). + * This would come from later in time if the cell state was used + * downstream as the initial cell state for another LSTM layer. + * Typically, this would be used when a sequence was cut at + * timestep `T` and then continued in the next batch. If `c` + * was not used downstream, then `dc` would be an empty matrix. + * - X: Inputs, of shape (N, T*D). 
+ * - W: Weights, of shape (D+M, 4M). + * - b: Biases, of shape (1, 4M). + * - T: Length of example sequences (number of timesteps). + * - D: Dimensionality of the input features. + * - given_sequences: Whether `dout` is for all timesteps, + * or just for the final timestep. This is based on whether + * `return_sequences` was true in the forward pass. + * - out0: Outputs from previous timestep, of shape (N, M). + * Note: This is *optional* and could just be an empty matrix. + * - c0: Initial cell state, of shape (N, M). + * Note: This is *optional* and could just be an empty matrix. + * - cache_out: Cache of outputs, of shape (T, N*M). + * Note: This is used for performance during training. + * - cache_c: Cache of cell state, of shape (T, N*M). + * Note: This is used for performance during training. + * - cache_ifog: Cache of intermediate values, of shape (T, N*4*M). + * Note: This is used for performance during training. + * + * Outputs: + * - dX: Gradient wrt `X`, of shape (N, T*D). + * - dW: Gradient wrt `W`, of shape (D+M, 4M). + * - db: Gradient wrt `b`, of shape (1, 4M). + * - dout0: Gradient wrt `out0`, of shape (N, M). + * - dc0: Gradient wrt `c0`, of shape (N, M). + */ + N = nrow(X) + M = as.integer(ncol(W)/4) + N1 = nrow(out0) + if(N != N1) { + # Allow for smaller out0 for last batch + # out0 = out0[1:N,] + # c0 = c0[1:N,] + stop("Unsupported operation: The batch size of previous iteration " + N1 + " is different than the batch size of current iteration " + N) + } + dX = matrix(0, rows=N, cols=T*D) + dW = matrix(0, rows=D+M, cols=4*M) + db = matrix(0, rows=1, cols=4*M) + dout0 = matrix(0, rows=N, cols=M) + dc0 = matrix(0, rows=N, cols=M) + dct = dc + if (!given_sequences) { + # only given dout for output at final timestep, so prepend empty douts for all other timesteps + dout = cbind(matrix(0, rows=N, cols=(T-1)*M), dout) # shape (N, T*M) + } + + t = T + for (iter in 1:T) { # each timestep in reverse order + X_t = X[,(t-1)*D+1:t*D] # shape (N, D) + dout_t = dout[,(t-1)*M+1:t*M] # shape (N, M) + out_t = matrix(cache_out[t,], rows=N, cols=M) # shape (N, M) + ct = matrix(cache_c[t,], rows=N, cols=M) # shape (N, M) + if (t == 1) { + out_prev = out0 # shape (N, M) + c_prev = c0 # shape (N, M) + } + else { + out_prev = matrix(cache_out[t-1,], rows=N, cols=M) # shape (N, M) + c_prev = matrix(cache_c[t-1,], rows=N, cols=M) # shape (N, M) + } + input = cbind(X_t, out_prev) # shape (N, D+M) + ifog = matrix(cache_ifog[t,], rows=N, cols=4*M) + i = ifog[,1:M] # input gate, shape (N, M) + f = ifog[,M+1:2*M] # forget gate, shape (N, M) + o = ifog[,2*M+1:3*M] # output gate, shape (N, M) + g = ifog[,3*M+1:4*M] # g gate, shape (N, M) + + dct = dct + o*tanh::backward(dout_t, ct) # shape (N, M) + do = tanh::forward(ct) * dout_t # output gate, shape (N, M) + df = c_prev * dct # forget gate, shape (N, M) + dc_prev = f * dct # shape (N, M) + di = g * dct # input gate, shape (N, M) + dg = i * dct # g gate, shape (N, M) + + di_raw = i * (1-i) * di + df_raw = f * (1-f) * df + do_raw = o * (1-o) * do + dg_raw = (1-g^2) * dg + difog_raw = cbind(di_raw, df_raw, do_raw, dg_raw) # shape (N, 4M) + + dW = dW + t(input) %*% difog_raw # shape (D+M, 4M) + db = db + colSums(difog_raw) # shape (1, 4M) + dinput = difog_raw %*% t(W) # shape (N, D+M) + dX[,(t-1)*D+1:t*D] = dinput[,1:D] + dout_prev = dinput[,D+1:D+M] # shape (N, M) + if (t == 1) { + dout0 = dout_prev # shape (N, M) + dc0 = dc_prev # shape (N, M) + } + else { + dout[,(t-2)*M+1:(t-1)*M] = dout[,(t-2)*M+1:(t-1)*M] + dout_prev # shape (N, M) + dct = dc_prev 
# shape (N, M) + } + t = t - 1 + } +} + +init = function(int N, int D, int M) + return (matrix[double] W, matrix[double] b, matrix[double] out0, matrix[double] c0) { + /* + * Initialize the parameters of this layer. + * + * Note: This is just a convenience function, and parameters + * may be initialized manually if needed. + * + * We use the Glorot uniform heuristic which limits the magnification + * of inputs/gradients during forward/backward passes by scaling + * uniform weights by a factor of sqrt(6/(fan_in + fan_out)). + * - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf + * + * Inputs: + * - N: Number of examples in batch. + * - D: Dimensionality of the input features (number of features). + * - M: Number of neurons in this layer. + * + * Outputs: + * - W: Weights, of shape (D+M, 4M). + * - b: Biases, of shape (1, 4M). + * - out0: Empty previous timestep output matrix, of shape (N, M). + * - c0: Empty initial cell state matrix, of shape (N, M). + */ + fan_in = D+M + fan_out = 4*M + scale = sqrt(6/(fan_in+fan_out)) + W = rand(rows=D+M, cols=4*M, min=-scale, max=scale, pdf="uniform") + b = matrix(0, rows=1, cols=4*M) + out0 = matrix(0, rows=N, cols=M) + c0 = matrix(0, rows=N, cols=M) +} + http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/cpp/kernels/SystemML.cu ---------------------------------------------------------------------- diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu index 55ebeaf..cc2d531 100644 --- a/src/main/cpp/kernels/SystemML.cu +++ b/src/main/cpp/kernels/SystemML.cu @@ -1962,6 +1962,87 @@ extern "C" __global__ void matrix_sigmoid_f(float *A, float *C, matrix_sigmoid(A, C, size); } +template <typename T> +__device__ void prepare_lstm_input(T* smlInput, T* cudnnInput, int N, int D, int TD, int size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if(index < size) { + int n = index / TD; + int td = index % TD; + int t = td / D; + int d = td % D; + cudnnInput[t*N*D + n*D + d] = smlInput[index]; + } +} + + +extern "C" __global__ void prepare_lstm_input_d(double* smlInput, double* cudnnInput, int N, int D, int TD, int size) { + prepare_lstm_input(smlInput, cudnnInput, N, D, TD, size); +} + +extern "C" __global__ void prepare_lstm_input_f(float* smlInput, float* cudnnInput, int N, int D, int TD, int size) { + prepare_lstm_input(smlInput, cudnnInput, N, D, TD, size); +} + +__device__ int swap_co(int offset) { + return (offset < 2) ? offset : (offset == 2 ? 3 : 2); +} + +template <typename T> +__device__ void prepare_lstm_weight(T* smlWeight, T* smlBias, T* cudnnWeight, int D, int M) { + int DM = D*M; int MM = M*M; int DM4 = DM*4; + int M4 = M*4; + int index = blockIdx.x * blockDim.x + threadIdx.x; + // input: cbind(X_t, out_prev) => [N, D+M], weight: [D+M, 4M] + // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnGetRNNLinLayerMatrixParams states that + // Elements in each weight matrix are arranged in the row-major order, but the column-major format works !! 
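+  // The swap_co helper above maps a gate block index in CuDNN's gate order to the
+  // corresponding block index in SystemML's weight layout: 0 (i) -> 0, 1 (f) -> 1,
+  // 2 (c) -> 3, 3 (o) -> 2; only the last two gates differ between the two orders: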
+ // CuDNN gate order: i, f, c, o + // CuDNN weight order: w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o + // SystemML weight order: i, f, o, c; TF weight order: i, c, f, o + // SystemML performs (X_t %*% W + out_prev %*% R) => [N, 4*M] + + // bias layout: bi bf bc bo 0 0 0 0 + // where W: [DxM], R: [MxM] and b: [1x1] + + // Maximum (D+M+2)*M4 threads + int srcIndex = -1; int destIndex; + if(index < DM4) { + // Fill w_i, w_f, w_c and w_o + int localIndex = index%DM; + int smlRowIndex = localIndex/M; + int smlColIndex = swap_co(index/(DM))*M + localIndex%M; + // Convert index to column-major where index = (index/(DM))*DM + (localIndex/M)*M + localIndex%M + destIndex = (index/(DM))*DM + (localIndex%M)*D + localIndex/M; + srcIndex = smlRowIndex*M4+smlColIndex; + } + else if(index < (D+M)*M4) { + // Fill r_i, r_f, r_c and r_o + int tmpIndex = index-DM4; + int localIndex = tmpIndex % MM; + int smlRowIndex = D + (localIndex / M); + int smlColIndex = swap_co(tmpIndex/(MM))*M + localIndex%M; + // Convert index to column-major where index = DM4 + (tmpIndex/(MM))*MM + (localIndex/M)*M + localIndex%M + destIndex = DM4 + (tmpIndex/(MM))*MM + (localIndex%M)*M + localIndex/M; + srcIndex = smlRowIndex*M4+smlColIndex; + } + else if(index < (D+M+1)*M4) { + // Fill bias + int tmpIndex = index - (D+M)*M4; + int smlColIndex = swap_co(tmpIndex/(M))*M + tmpIndex%M; + cudnnWeight[index] = smlBias[smlColIndex]; + } + // __syncthreads(); + if(srcIndex != -1) + cudnnWeight[destIndex] = smlWeight[srcIndex]; +} + +extern "C" __global__ void prepare_lstm_weight_d(double* smlWeight, double* smlBias, double* cudnnWeight, int D, int M) { + prepare_lstm_weight(smlWeight, smlBias, cudnnWeight, D, M); +} + +extern "C" __global__ void prepare_lstm_weight_f(float* smlWeight, float* smlBias, float* cudnnWeight, int D, int M) { + prepare_lstm_weight(smlWeight, smlBias, cudnnWeight, D, M); +} + // We can later fold it in our reduce method template <typename T> __device__ void compute_nnz( @@ -2058,3 +2139,26 @@ extern "C" __global__ void compute_nnz_d(double *g_idata, double *g_odata, unsig extern "C" __global__ void compute_nnz_f(float *g_idata, float *g_odata, unsigned int n) { compute_nnz(g_idata, g_odata, n); } + +template <typename T> +__device__ void prepare_lstm_output(T* smlInput, T* cudnnInput, int N, int T1, int M, int size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if(index < size) { + int TM = T1*M; + int NT = T1*N; + int n = index / TM; + int tm = index % TM; + int t = tm / M; + int m = tm % M; + smlInput[index] = cudnnInput[t*N*M + n*M + m]; + } +} + + +extern "C" __global__ void prepare_lstm_output_d(double* smlInput, double* cudnnInput, int N, int T, int M, int size) { + prepare_lstm_output(smlInput, cudnnInput, N, T, M, size); +} + +extern "C" __global__ void prepare_lstm_output_f(float* smlInput, float* cudnnInput, int N, int T, int M, int size) { + prepare_lstm_output(smlInput, cudnnInput, N, T, M, size); +} http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/cpp/kernels/SystemML.ptx ---------------------------------------------------------------------- diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx index 1865e18..ed1a100 100644 --- a/src/main/cpp/kernels/SystemML.ptx +++ b/src/main/cpp/kernels/SystemML.ptx @@ -11689,6 +11689,328 @@ BB97_5: ret; } + // .globl prepare_lstm_input_d +.visible .entry prepare_lstm_input_d( + .param .u64 prepare_lstm_input_d_param_0, + .param .u64 prepare_lstm_input_d_param_1, + .param .u32 prepare_lstm_input_d_param_2, 
+ .param .u32 prepare_lstm_input_d_param_3, + .param .u32 prepare_lstm_input_d_param_4, + .param .u32 prepare_lstm_input_d_param_5 +) +{ + .reg .pred %p<2>; + .reg .b32 %r<15>; + .reg .f64 %fd<2>; + .reg .b64 %rd<9>; + + + ld.param.u64 %rd1, [prepare_lstm_input_d_param_0]; + ld.param.u64 %rd2, [prepare_lstm_input_d_param_1]; + ld.param.u32 %r2, [prepare_lstm_input_d_param_2]; + ld.param.u32 %r3, [prepare_lstm_input_d_param_3]; + ld.param.u32 %r4, [prepare_lstm_input_d_param_4]; + ld.param.u32 %r5, [prepare_lstm_input_d_param_5]; + mov.u32 %r6, %ctaid.x; + mov.u32 %r7, %ntid.x; + mov.u32 %r8, %tid.x; + mad.lo.s32 %r1, %r7, %r6, %r8; + setp.ge.s32 %p1, %r1, %r5; + @%p1 bra BB98_2; + + cvta.to.global.u64 %rd3, %rd1; + rem.s32 %r9, %r1, %r4; + div.s32 %r10, %r9, %r3; + rem.s32 %r11, %r9, %r3; + mul.wide.s32 %rd4, %r1, 8; + add.s64 %rd5, %rd3, %rd4; + ld.global.f64 %fd1, [%rd5]; + div.s32 %r12, %r1, %r4; + mad.lo.s32 %r13, %r10, %r2, %r12; + mad.lo.s32 %r14, %r13, %r3, %r11; + cvta.to.global.u64 %rd6, %rd2; + mul.wide.s32 %rd7, %r14, 8; + add.s64 %rd8, %rd6, %rd7; + st.global.f64 [%rd8], %fd1; + +BB98_2: + ret; +} + + // .globl prepare_lstm_input_f +.visible .entry prepare_lstm_input_f( + .param .u64 prepare_lstm_input_f_param_0, + .param .u64 prepare_lstm_input_f_param_1, + .param .u32 prepare_lstm_input_f_param_2, + .param .u32 prepare_lstm_input_f_param_3, + .param .u32 prepare_lstm_input_f_param_4, + .param .u32 prepare_lstm_input_f_param_5 +) +{ + .reg .pred %p<2>; + .reg .f32 %f<2>; + .reg .b32 %r<15>; + .reg .b64 %rd<9>; + + + ld.param.u64 %rd1, [prepare_lstm_input_f_param_0]; + ld.param.u64 %rd2, [prepare_lstm_input_f_param_1]; + ld.param.u32 %r2, [prepare_lstm_input_f_param_2]; + ld.param.u32 %r3, [prepare_lstm_input_f_param_3]; + ld.param.u32 %r4, [prepare_lstm_input_f_param_4]; + ld.param.u32 %r5, [prepare_lstm_input_f_param_5]; + mov.u32 %r6, %ctaid.x; + mov.u32 %r7, %ntid.x; + mov.u32 %r8, %tid.x; + mad.lo.s32 %r1, %r7, %r6, %r8; + setp.ge.s32 %p1, %r1, %r5; + @%p1 bra BB99_2; + + cvta.to.global.u64 %rd3, %rd1; + rem.s32 %r9, %r1, %r4; + div.s32 %r10, %r9, %r3; + rem.s32 %r11, %r9, %r3; + mul.wide.s32 %rd4, %r1, 4; + add.s64 %rd5, %rd3, %rd4; + ld.global.f32 %f1, [%rd5]; + div.s32 %r12, %r1, %r4; + mad.lo.s32 %r13, %r10, %r2, %r12; + mad.lo.s32 %r14, %r13, %r3, %r11; + cvta.to.global.u64 %rd6, %rd2; + mul.wide.s32 %rd7, %r14, 4; + add.s64 %rd8, %rd6, %rd7; + st.global.f32 [%rd8], %f1; + +BB99_2: + ret; +} + + // .globl prepare_lstm_weight_d +.visible .entry prepare_lstm_weight_d( + .param .u64 prepare_lstm_weight_d_param_0, + .param .u64 prepare_lstm_weight_d_param_1, + .param .u64 prepare_lstm_weight_d_param_2, + .param .u32 prepare_lstm_weight_d_param_3, + .param .u32 prepare_lstm_weight_d_param_4 +) +{ + .reg .pred %p<11>; + .reg .b32 %r<53>; + .reg .f64 %fd<3>; + .reg .b64 %rd<15>; + + + ld.param.u64 %rd2, [prepare_lstm_weight_d_param_0]; + ld.param.u64 %rd3, [prepare_lstm_weight_d_param_1]; + ld.param.u64 %rd4, [prepare_lstm_weight_d_param_2]; + ld.param.u32 %r13, [prepare_lstm_weight_d_param_3]; + ld.param.u32 %r14, [prepare_lstm_weight_d_param_4]; + cvta.to.global.u64 %rd1, %rd4; + mul.lo.s32 %r1, %r14, %r13; + shl.b32 %r2, %r1, 2; + shl.b32 %r3, %r14, 2; + mov.u32 %r15, %ntid.x; + mov.u32 %r16, %ctaid.x; + mov.u32 %r17, %tid.x; + mad.lo.s32 %r4, %r15, %r16, %r17; + setp.lt.s32 %p1, %r4, %r2; + @%p1 bra BB100_5; + bra.uni BB100_1; + +BB100_5: + rem.s32 %r42, %r4, %r1; + div.s32 %r43, %r42, %r14; + div.s32 %r44, %r4, %r1; + setp.lt.s32 %p8, %r44, 2; + setp.eq.s32 %p9, %r44, 
2; + selp.b32 %r45, 3, 2, %p9; + selp.b32 %r46, %r44, %r45, %p8; + rem.s32 %r47, %r42, %r14; + sub.s32 %r48, %r4, %r42; + add.s32 %r49, %r48, %r43; + mad.lo.s32 %r52, %r47, %r13, %r49; + mad.lo.s32 %r50, %r43, %r3, %r47; + mad.lo.s32 %r51, %r46, %r14, %r50; + bra.uni BB100_6; + +BB100_1: + add.s32 %r5, %r14, %r13; + mul.lo.s32 %r6, %r5, %r3; + setp.lt.s32 %p2, %r4, %r6; + @%p2 bra BB100_4; + bra.uni BB100_2; + +BB100_4: + mul.lo.s32 %r30, %r14, %r14; + sub.s32 %r31, %r4, %r2; + rem.s32 %r32, %r31, %r30; + div.s32 %r33, %r32, %r14; + add.s32 %r34, %r33, %r13; + div.s32 %r35, %r31, %r30; + setp.lt.s32 %p6, %r35, 2; + setp.eq.s32 %p7, %r35, 2; + selp.b32 %r36, 3, 2, %p7; + selp.b32 %r37, %r35, %r36, %p6; + rem.s32 %r38, %r32, %r14; + sub.s32 %r39, %r4, %r32; + add.s32 %r40, %r39, %r33; + mad.lo.s32 %r52, %r38, %r14, %r40; + mad.lo.s32 %r41, %r34, %r3, %r38; + mad.lo.s32 %r51, %r37, %r14, %r41; + bra.uni BB100_6; + +BB100_2: + add.s32 %r20, %r5, 1; + mul.lo.s32 %r21, %r20, %r3; + mov.u32 %r51, -1; + setp.ge.s32 %p3, %r4, %r21; + @%p3 bra BB100_6; + + cvta.to.global.u64 %rd5, %rd3; + sub.s32 %r24, %r4, %r6; + div.s32 %r25, %r24, %r14; + setp.lt.s32 %p4, %r25, 2; + setp.eq.s32 %p5, %r25, 2; + selp.b32 %r26, 3, 2, %p5; + selp.b32 %r27, %r25, %r26, %p4; + rem.s32 %r28, %r24, %r14; + mad.lo.s32 %r29, %r27, %r14, %r28; + mul.wide.s32 %rd6, %r29, 8; + add.s64 %rd7, %rd5, %rd6; + ld.global.f64 %fd1, [%rd7]; + mul.wide.s32 %rd8, %r4, 8; + add.s64 %rd9, %rd1, %rd8; + st.global.f64 [%rd9], %fd1; + +BB100_6: + setp.eq.s32 %p10, %r51, -1; + @%p10 bra BB100_8; + + cvta.to.global.u64 %rd10, %rd2; + mul.wide.s32 %rd11, %r51, 8; + add.s64 %rd12, %rd10, %rd11; + ld.global.f64 %fd2, [%rd12]; + mul.wide.s32 %rd13, %r52, 8; + add.s64 %rd14, %rd1, %rd13; + st.global.f64 [%rd14], %fd2; + +BB100_8: + ret; +} + + // .globl prepare_lstm_weight_f +.visible .entry prepare_lstm_weight_f( + .param .u64 prepare_lstm_weight_f_param_0, + .param .u64 prepare_lstm_weight_f_param_1, + .param .u64 prepare_lstm_weight_f_param_2, + .param .u32 prepare_lstm_weight_f_param_3, + .param .u32 prepare_lstm_weight_f_param_4 +) +{ + .reg .pred %p<11>; + .reg .f32 %f<3>; + .reg .b32 %r<53>; + .reg .b64 %rd<15>; + + + ld.param.u64 %rd2, [prepare_lstm_weight_f_param_0]; + ld.param.u64 %rd3, [prepare_lstm_weight_f_param_1]; + ld.param.u64 %rd4, [prepare_lstm_weight_f_param_2]; + ld.param.u32 %r13, [prepare_lstm_weight_f_param_3]; + ld.param.u32 %r14, [prepare_lstm_weight_f_param_4]; + cvta.to.global.u64 %rd1, %rd4; + mul.lo.s32 %r1, %r14, %r13; + shl.b32 %r2, %r1, 2; + shl.b32 %r3, %r14, 2; + mov.u32 %r15, %ntid.x; + mov.u32 %r16, %ctaid.x; + mov.u32 %r17, %tid.x; + mad.lo.s32 %r4, %r15, %r16, %r17; + setp.lt.s32 %p1, %r4, %r2; + @%p1 bra BB101_5; + bra.uni BB101_1; + +BB101_5: + rem.s32 %r42, %r4, %r1; + div.s32 %r43, %r42, %r14; + div.s32 %r44, %r4, %r1; + setp.lt.s32 %p8, %r44, 2; + setp.eq.s32 %p9, %r44, 2; + selp.b32 %r45, 3, 2, %p9; + selp.b32 %r46, %r44, %r45, %p8; + rem.s32 %r47, %r42, %r14; + sub.s32 %r48, %r4, %r42; + add.s32 %r49, %r48, %r43; + mad.lo.s32 %r52, %r47, %r13, %r49; + mad.lo.s32 %r50, %r43, %r3, %r47; + mad.lo.s32 %r51, %r46, %r14, %r50; + bra.uni BB101_6; + +BB101_1: + add.s32 %r5, %r14, %r13; + mul.lo.s32 %r6, %r5, %r3; + setp.lt.s32 %p2, %r4, %r6; + @%p2 bra BB101_4; + bra.uni BB101_2; + +BB101_4: + mul.lo.s32 %r30, %r14, %r14; + sub.s32 %r31, %r4, %r2; + rem.s32 %r32, %r31, %r30; + div.s32 %r33, %r32, %r14; + add.s32 %r34, %r33, %r13; + div.s32 %r35, %r31, %r30; + setp.lt.s32 %p6, %r35, 2; + setp.eq.s32 %p7, %r35, 
2; + selp.b32 %r36, 3, 2, %p7; + selp.b32 %r37, %r35, %r36, %p6; + rem.s32 %r38, %r32, %r14; + sub.s32 %r39, %r4, %r32; + add.s32 %r40, %r39, %r33; + mad.lo.s32 %r52, %r38, %r14, %r40; + mad.lo.s32 %r41, %r34, %r3, %r38; + mad.lo.s32 %r51, %r37, %r14, %r41; + bra.uni BB101_6; + +BB101_2: + add.s32 %r20, %r5, 1; + mul.lo.s32 %r21, %r20, %r3; + mov.u32 %r51, -1; + setp.ge.s32 %p3, %r4, %r21; + @%p3 bra BB101_6; + + cvta.to.global.u64 %rd5, %rd3; + sub.s32 %r24, %r4, %r6; + div.s32 %r25, %r24, %r14; + setp.lt.s32 %p4, %r25, 2; + setp.eq.s32 %p5, %r25, 2; + selp.b32 %r26, 3, 2, %p5; + selp.b32 %r27, %r25, %r26, %p4; + rem.s32 %r28, %r24, %r14; + mad.lo.s32 %r29, %r27, %r14, %r28; + mul.wide.s32 %rd6, %r29, 4; + add.s64 %rd7, %rd5, %rd6; + ld.global.f32 %f1, [%rd7]; + mul.wide.s32 %rd8, %r4, 4; + add.s64 %rd9, %rd1, %rd8; + st.global.f32 [%rd9], %f1; + +BB101_6: + setp.eq.s32 %p10, %r51, -1; + @%p10 bra BB101_8; + + cvta.to.global.u64 %rd10, %rd2; + mul.wide.s32 %rd11, %r51, 4; + add.s64 %rd12, %rd10, %rd11; + ld.global.f32 %f2, [%rd12]; + mul.wide.s32 %rd13, %r52, 4; + add.s64 %rd14, %rd1, %rd13; + st.global.f32 [%rd14], %f2; + +BB101_8: + ret; +} + // .globl compute_nnz_d .visible .entry compute_nnz_d( .param .u64 compute_nnz_d_param_0, @@ -11712,9 +12034,9 @@ BB97_5: mad.lo.s32 %r35, %r9, %r10, %r7; mov.f64 %fd46, 0d0000000000000000; setp.ge.u32 %p1, %r35, %r6; - @%p1 bra BB98_4; + @%p1 bra BB102_4; -BB98_1: +BB102_1: cvta.to.global.u64 %rd3, %rd1; mul.wide.u32 %rd4, %r35, 8; add.s64 %rd5, %rd3, %rd4; @@ -11724,7 +12046,7 @@ BB98_1: add.f64 %fd46, %fd46, %fd31; add.s32 %r3, %r35, %r10; setp.ge.u32 %p3, %r3, %r6; - @%p3 bra BB98_3; + @%p3 bra BB102_3; mul.wide.u32 %rd7, %r3, 8; add.s64 %rd8, %rd3, %rd7; @@ -11733,128 +12055,128 @@ BB98_1: selp.f64 %fd33, 0d3FF0000000000000, 0d0000000000000000, %p4; add.f64 %fd46, %fd46, %fd33; -BB98_3: +BB102_3: shl.b32 %r13, %r10, 1; mov.u32 %r14, %nctaid.x; mad.lo.s32 %r35, %r13, %r14, %r35; setp.lt.u32 %p5, %r35, %r6; - @%p5 bra BB98_1; + @%p5 bra BB102_1; -BB98_4: +BB102_4: shl.b32 %r16, %r7, 3; mov.u32 %r17, my_sdata; add.s32 %r5, %r17, %r16; st.shared.f64 [%r5], %fd46; bar.sync 0; setp.lt.u32 %p6, %r10, 1024; - @%p6 bra BB98_8; + @%p6 bra BB102_8; setp.gt.u32 %p7, %r7, 511; - @%p7 bra BB98_7; + @%p7 bra BB102_7; ld.shared.f64 %fd34, [%r5+4096]; add.f64 %fd46, %fd46, %fd34; st.shared.f64 [%r5], %fd46; -BB98_7: +BB102_7: bar.sync 0; -BB98_8: +BB102_8: setp.lt.u32 %p8, %r10, 512; - @%p8 bra BB98_12; + @%p8 bra BB102_12; setp.gt.u32 %p9, %r7, 255; - @%p9 bra BB98_11; + @%p9 bra BB102_11; ld.shared.f64 %fd35, [%r5+2048]; add.f64 %fd46, %fd46, %fd35; st.shared.f64 [%r5], %fd46; -BB98_11: +BB102_11: bar.sync 0; -BB98_12: +BB102_12: setp.lt.u32 %p10, %r10, 256; - @%p10 bra BB98_16; + @%p10 bra BB102_16; setp.gt.u32 %p11, %r7, 127; - @%p11 bra BB98_15; + @%p11 bra BB102_15; ld.shared.f64 %fd36, [%r5+1024]; add.f64 %fd46, %fd46, %fd36; st.shared.f64 [%r5], %fd46; -BB98_15: +BB102_15: bar.sync 0; -BB98_16: +BB102_16: setp.lt.u32 %p12, %r10, 128; - @%p12 bra BB98_20; + @%p12 bra BB102_20; setp.gt.u32 %p13, %r7, 63; - @%p13 bra BB98_19; + @%p13 bra BB102_19; ld.shared.f64 %fd37, [%r5+512]; add.f64 %fd46, %fd46, %fd37; st.shared.f64 [%r5], %fd46; -BB98_19: +BB102_19: bar.sync 0; -BB98_20: +BB102_20: setp.gt.u32 %p14, %r7, 31; - @%p14 bra BB98_33; + @%p14 bra BB102_33; setp.lt.u32 %p15, %r10, 64; - @%p15 bra BB98_23; + @%p15 bra BB102_23; ld.volatile.shared.f64 %fd38, [%r5+256]; add.f64 %fd46, %fd46, %fd38; st.volatile.shared.f64 [%r5], %fd46; -BB98_23: +BB102_23: 
setp.lt.u32 %p16, %r10, 32; - @%p16 bra BB98_25; + @%p16 bra BB102_25; ld.volatile.shared.f64 %fd39, [%r5+128]; add.f64 %fd46, %fd46, %fd39; st.volatile.shared.f64 [%r5], %fd46; -BB98_25: +BB102_25: setp.lt.u32 %p17, %r10, 16; - @%p17 bra BB98_27; + @%p17 bra BB102_27; ld.volatile.shared.f64 %fd40, [%r5+64]; add.f64 %fd46, %fd46, %fd40; st.volatile.shared.f64 [%r5], %fd46; -BB98_27: +BB102_27: setp.lt.u32 %p18, %r10, 8; - @%p18 bra BB98_29; + @%p18 bra BB102_29; ld.volatile.shared.f64 %fd41, [%r5+32]; add.f64 %fd46, %fd46, %fd41; st.volatile.shared.f64 [%r5], %fd46; -BB98_29: +BB102_29: setp.lt.u32 %p19, %r10, 4; - @%p19 bra BB98_31; + @%p19 bra BB102_31; ld.volatile.shared.f64 %fd42, [%r5+16]; add.f64 %fd46, %fd46, %fd42; st.volatile.shared.f64 [%r5], %fd46; -BB98_31: +BB102_31: setp.lt.u32 %p20, %r10, 2; - @%p20 bra BB98_33; + @%p20 bra BB102_33; ld.volatile.shared.f64 %fd43, [%r5+8]; add.f64 %fd44, %fd46, %fd43; st.volatile.shared.f64 [%r5], %fd44; -BB98_33: +BB102_33: setp.ne.s32 %p21, %r7, 0; - @%p21 bra BB98_35; + @%p21 bra BB102_35; ld.shared.f64 %fd45, [my_sdata]; cvta.to.global.u64 %rd9, %rd2; @@ -11862,7 +12184,7 @@ BB98_33: add.s64 %rd11, %rd9, %rd10; st.global.f64 [%rd11], %fd45; -BB98_35: +BB102_35: ret; } @@ -11889,9 +12211,9 @@ BB98_35: mad.lo.s32 %r35, %r9, %r10, %r7; mov.f32 %f46, 0f00000000; setp.ge.u32 %p1, %r35, %r6; - @%p1 bra BB99_4; + @%p1 bra BB103_4; -BB99_1: +BB103_1: cvta.to.global.u64 %rd3, %rd1; mul.wide.u32 %rd4, %r35, 4; add.s64 %rd5, %rd3, %rd4; @@ -11901,7 +12223,7 @@ BB99_1: add.f32 %f46, %f46, %f31; add.s32 %r3, %r35, %r10; setp.ge.u32 %p3, %r3, %r6; - @%p3 bra BB99_3; + @%p3 bra BB103_3; mul.wide.u32 %rd7, %r3, 4; add.s64 %rd8, %rd3, %rd7; @@ -11910,128 +12232,128 @@ BB99_1: selp.f32 %f33, 0f3F800000, 0f00000000, %p4; add.f32 %f46, %f46, %f33; -BB99_3: +BB103_3: shl.b32 %r13, %r10, 1; mov.u32 %r14, %nctaid.x; mad.lo.s32 %r35, %r13, %r14, %r35; setp.lt.u32 %p5, %r35, %r6; - @%p5 bra BB99_1; + @%p5 bra BB103_1; -BB99_4: +BB103_4: shl.b32 %r16, %r7, 2; mov.u32 %r17, my_sdata; add.s32 %r5, %r17, %r16; st.shared.f32 [%r5], %f46; bar.sync 0; setp.lt.u32 %p6, %r10, 1024; - @%p6 bra BB99_8; + @%p6 bra BB103_8; setp.gt.u32 %p7, %r7, 511; - @%p7 bra BB99_7; + @%p7 bra BB103_7; ld.shared.f32 %f34, [%r5+2048]; add.f32 %f46, %f46, %f34; st.shared.f32 [%r5], %f46; -BB99_7: +BB103_7: bar.sync 0; -BB99_8: +BB103_8: setp.lt.u32 %p8, %r10, 512; - @%p8 bra BB99_12; + @%p8 bra BB103_12; setp.gt.u32 %p9, %r7, 255; - @%p9 bra BB99_11; + @%p9 bra BB103_11; ld.shared.f32 %f35, [%r5+1024]; add.f32 %f46, %f46, %f35; st.shared.f32 [%r5], %f46; -BB99_11: +BB103_11: bar.sync 0; -BB99_12: +BB103_12: setp.lt.u32 %p10, %r10, 256; - @%p10 bra BB99_16; + @%p10 bra BB103_16; setp.gt.u32 %p11, %r7, 127; - @%p11 bra BB99_15; + @%p11 bra BB103_15; ld.shared.f32 %f36, [%r5+512]; add.f32 %f46, %f46, %f36; st.shared.f32 [%r5], %f46; -BB99_15: +BB103_15: bar.sync 0; -BB99_16: +BB103_16: setp.lt.u32 %p12, %r10, 128; - @%p12 bra BB99_20; + @%p12 bra BB103_20; setp.gt.u32 %p13, %r7, 63; - @%p13 bra BB99_19; + @%p13 bra BB103_19; ld.shared.f32 %f37, [%r5+256]; add.f32 %f46, %f46, %f37; st.shared.f32 [%r5], %f46; -BB99_19: +BB103_19: bar.sync 0; -BB99_20: +BB103_20: setp.gt.u32 %p14, %r7, 31; - @%p14 bra BB99_33; + @%p14 bra BB103_33; setp.lt.u32 %p15, %r10, 64; - @%p15 bra BB99_23; + @%p15 bra BB103_23; ld.volatile.shared.f32 %f38, [%r5+128]; add.f32 %f46, %f46, %f38; st.volatile.shared.f32 [%r5], %f46; -BB99_23: +BB103_23: setp.lt.u32 %p16, %r10, 32; - @%p16 bra BB99_25; + @%p16 bra BB103_25; 
ld.volatile.shared.f32 %f39, [%r5+64]; add.f32 %f46, %f46, %f39; st.volatile.shared.f32 [%r5], %f46; -BB99_25: +BB103_25: setp.lt.u32 %p17, %r10, 16; - @%p17 bra BB99_27; + @%p17 bra BB103_27; ld.volatile.shared.f32 %f40, [%r5+32]; add.f32 %f46, %f46, %f40; st.volatile.shared.f32 [%r5], %f46; -BB99_27: +BB103_27: setp.lt.u32 %p18, %r10, 8; - @%p18 bra BB99_29; + @%p18 bra BB103_29; ld.volatile.shared.f32 %f41, [%r5+16]; add.f32 %f46, %f46, %f41; st.volatile.shared.f32 [%r5], %f46; -BB99_29: +BB103_29: setp.lt.u32 %p19, %r10, 4; - @%p19 bra BB99_31; + @%p19 bra BB103_31; ld.volatile.shared.f32 %f42, [%r5+8]; add.f32 %f46, %f46, %f42; st.volatile.shared.f32 [%r5], %f46; -BB99_31: +BB103_31: setp.lt.u32 %p20, %r10, 2; - @%p20 bra BB99_33; + @%p20 bra BB103_33; ld.volatile.shared.f32 %f43, [%r5+4]; add.f32 %f44, %f46, %f43; st.volatile.shared.f32 [%r5], %f44; -BB99_33: +BB103_33: setp.ne.s32 %p21, %r7, 0; - @%p21 bra BB99_35; + @%p21 bra BB103_35; ld.shared.f32 %f45, [my_sdata]; cvta.to.global.u64 %rd9, %rd2; @@ -12039,7 +12361,105 @@ BB99_33: add.s64 %rd11, %rd9, %rd10; st.global.f32 [%rd11], %f45; -BB99_35: +BB103_35: + ret; +} + + // .globl prepare_lstm_output_d +.visible .entry prepare_lstm_output_d( + .param .u64 prepare_lstm_output_d_param_0, + .param .u64 prepare_lstm_output_d_param_1, + .param .u32 prepare_lstm_output_d_param_2, + .param .u32 prepare_lstm_output_d_param_3, + .param .u32 prepare_lstm_output_d_param_4, + .param .u32 prepare_lstm_output_d_param_5 +) +{ + .reg .pred %p<2>; + .reg .b32 %r<16>; + .reg .f64 %fd<2>; + .reg .b64 %rd<9>; + + + ld.param.u64 %rd1, [prepare_lstm_output_d_param_0]; + ld.param.u64 %rd2, [prepare_lstm_output_d_param_1]; + ld.param.u32 %r2, [prepare_lstm_output_d_param_2]; + ld.param.u32 %r3, [prepare_lstm_output_d_param_3]; + ld.param.u32 %r4, [prepare_lstm_output_d_param_4]; + ld.param.u32 %r5, [prepare_lstm_output_d_param_5]; + mov.u32 %r6, %ctaid.x; + mov.u32 %r7, %ntid.x; + mov.u32 %r8, %tid.x; + mad.lo.s32 %r1, %r7, %r6, %r8; + setp.ge.s32 %p1, %r1, %r5; + @%p1 bra BB104_2; + + cvta.to.global.u64 %rd3, %rd2; + mul.lo.s32 %r9, %r4, %r3; + div.s32 %r10, %r1, %r9; + rem.s32 %r11, %r1, %r9; + div.s32 %r12, %r11, %r4; + rem.s32 %r13, %r11, %r4; + mad.lo.s32 %r14, %r12, %r2, %r10; + mad.lo.s32 %r15, %r14, %r4, %r13; + mul.wide.s32 %rd4, %r15, 8; + add.s64 %rd5, %rd3, %rd4; + ld.global.f64 %fd1, [%rd5]; + cvta.to.global.u64 %rd6, %rd1; + mul.wide.s32 %rd7, %r1, 8; + add.s64 %rd8, %rd6, %rd7; + st.global.f64 [%rd8], %fd1; + +BB104_2: + ret; +} + + // .globl prepare_lstm_output_f +.visible .entry prepare_lstm_output_f( + .param .u64 prepare_lstm_output_f_param_0, + .param .u64 prepare_lstm_output_f_param_1, + .param .u32 prepare_lstm_output_f_param_2, + .param .u32 prepare_lstm_output_f_param_3, + .param .u32 prepare_lstm_output_f_param_4, + .param .u32 prepare_lstm_output_f_param_5 +) +{ + .reg .pred %p<2>; + .reg .f32 %f<2>; + .reg .b32 %r<16>; + .reg .b64 %rd<9>; + + + ld.param.u64 %rd1, [prepare_lstm_output_f_param_0]; + ld.param.u64 %rd2, [prepare_lstm_output_f_param_1]; + ld.param.u32 %r2, [prepare_lstm_output_f_param_2]; + ld.param.u32 %r3, [prepare_lstm_output_f_param_3]; + ld.param.u32 %r4, [prepare_lstm_output_f_param_4]; + ld.param.u32 %r5, [prepare_lstm_output_f_param_5]; + mov.u32 %r6, %ctaid.x; + mov.u32 %r7, %ntid.x; + mov.u32 %r8, %tid.x; + mad.lo.s32 %r1, %r7, %r6, %r8; + setp.ge.s32 %p1, %r1, %r5; + @%p1 bra BB105_2; + + cvta.to.global.u64 %rd3, %rd2; + mul.lo.s32 %r9, %r4, %r3; + div.s32 %r10, %r1, %r9; + rem.s32 %r11, %r1, %r9; + 
div.s32 %r12, %r11, %r4; + rem.s32 %r13, %r11, %r4; + mad.lo.s32 %r14, %r12, %r2, %r10; + mad.lo.s32 %r15, %r14, %r4, %r13; + mul.wide.s32 %rd4, %r15, 4; + add.s64 %rd5, %rd3, %rd4; + ld.global.f32 %f1, [%rd5]; + cvta.to.global.u64 %rd6, %rd1; + mul.wide.s32 %rd7, %r1, 4; + add.s64 %rd8, %rd6, %rd7; + st.global.f32 [%rd8], %f1; + +BB105_2: ret; } @@ -12048,7 +12468,7 @@ BB99_35: .param .b64 __internal_trig_reduction_slowpathd_param_1 ) { - .local .align 8 .b8 __local_depot100[40]; + .local .align 8 .b8 __local_depot106[40]; .reg .b64 %SP; .reg .b64 %SPL; .reg .pred %p<9>; @@ -12057,7 +12477,7 @@ BB99_35: .reg .b64 %rd<102>; - mov.u64 %rd101, __local_depot100; + mov.u64 %rd101, __local_depot106; cvta.local.u64 %SP, %rd101; ld.param.f64 %fd4, [__internal_trig_reduction_slowpathd_param_0]; ld.param.u64 %rd37, [__internal_trig_reduction_slowpathd_param_1]; @@ -12071,7 +12491,7 @@ BB99_35: shr.u32 %r3, %r1, 20; bfe.u32 %r4, %r1, 20, 11; setp.eq.s32 %p1, %r4, 2047; - @%p1 bra BB100_13; + @%p1 bra BB106_13; add.s32 %r15, %r4, -1024; shr.u32 %r16, %r15, 6; @@ -12084,7 +12504,7 @@ BB99_35: mov.u64 %rd94, 0; setp.ge.s32 %p2, %r5, %r6; mov.u64 %rd93, %rd1; - @%p2 bra BB100_4; + @%p2 bra BB106_4; mov.b64 %rd41, %fd4; shl.b64 %rd42, %rd41, 11; @@ -12101,7 +12521,7 @@ BB99_35: mov.u64 %rd91, %rd1; mov.u32 %r39, %r5; -BB100_3: +BB106_3: .pragma "nounroll"; ld.const.u64 %rd47, [%rd89]; // inline asm @@ -12131,15 +12551,15 @@ BB100_3: add.s64 %rd93, %rd93, 8; add.s64 %rd89, %rd89, 8; setp.lt.s32 %p3, %r39, %r6; - @%p3 bra BB100_3; + @%p3 bra BB106_3; -BB100_4: +BB106_4: st.local.u64 [%rd93], %rd94; ld.local.u64 %rd95, [%rd1+16]; ld.local.u64 %rd96, [%rd1+24]; and.b32 %r9, %r3, 63; setp.eq.s32 %p4, %r9, 0; - @%p4 bra BB100_6; + @%p4 bra BB106_6; mov.u32 %r27, 64; sub.s32 %r28, %r27, %r9; @@ -12151,7 +12571,7 @@ BB100_4: shr.u64 %rd55, %rd54, %r28; or.b64 %rd95, %rd55, %rd53; -BB100_6: +BB106_6: cvta.to.local.u64 %rd56, %rd37; shr.u64 %rd57, %rd96, 62; cvt.u32.u64 %r29, %rd57; @@ -12168,7 +12588,7 @@ BB100_6: selp.b32 %r34, %r32, %r33, %p5; st.local.u32 [%rd56], %r34; setp.eq.s32 %p6, %r31, 0; - @%p6 bra BB100_8; + @%p6 bra BB106_8; mov.u64 %rd64, 0; // inline asm @@ -12188,10 +12608,10 @@ BB100_6: // inline asm xor.b32 %r40, %r40, -2147483648; -BB100_8: +BB106_8: clz.b64 %r41, %rd98; setp.eq.s32 %p7, %r41, 0; - @%p7 bra BB100_10; + @%p7 bra BB106_10; shl.b64 %rd67, %rd98, %r41; mov.u32 %r35, 64; @@ -12199,7 +12619,7 @@ BB100_8: shr.u64 %rd68, %rd97, %r36; or.b64 %rd98, %rd68, %rd67; -BB100_10: +BB106_10: mov.u64 %rd72, -3958705157555305931; // inline asm { @@ -12220,7 +12640,7 @@ BB100_10: } // inline asm setp.lt.s64 %p8, %rd100, 1; - @%p8 bra BB100_12; + @%p8 bra BB106_12; // inline asm { @@ -12239,7 +12659,7 @@ BB100_10: // inline asm add.s32 %r41, %r41, 1; -BB100_12: +BB106_12: cvt.u64.u32 %rd79, %r40; shl.b64 %rd80, %rd79, 32; mov.u32 %r37, 1022; @@ -12254,7 +12674,7 @@ BB100_12: or.b64 %rd88, %rd87, %rd80; mov.b64 %fd4, %rd88; -BB100_13: +BB106_13: st.param.f64 [func_retval0+0], %fd4; ret; } @@ -12282,7 +12702,7 @@ BB100_13: } shr.u32 %r51, %r50, 20; setp.ne.s32 %p1, %r51, 0; - @%p1 bra BB101_2; + @%p1 bra BB107_2; mul.f64 %fd14, %fd12, 0d4350000000000000; { @@ -12296,13 +12716,13 @@ BB100_13: shr.u32 %r16, %r50, 20; add.s32 %r51, %r16, -54; -BB101_2: +BB107_2: add.s32 %r52, %r51, -1023; and.b32 %r17, %r50, -2146435073; or.b32 %r18, %r17, 1072693248; mov.b64 %fd135, {%r49, %r18}; setp.lt.u32 %p2, %r18, 1073127583; - @%p2 bra BB101_4; + @%p2 bra BB107_4; { .reg .b32 %temp; @@ -12316,7 +12736,7 @@ 
BB101_2: mov.b64 %fd135, {%r19, %r21}; add.s32 %r52, %r51, -1022; -BB101_4: +BB107_4: add.f64 %fd15, %fd135, 0d3FF0000000000000; rcp.approx.ftz.f64 %fd16, %fd15; neg.f64 %fd17, %fd15; @@ -12479,13 +12899,13 @@ BB101_4: mov.b32 %f2, %r35; abs.f32 %f1, %f2; setp.lt.f32 %p4, %f1, 0f4086232B; - @%p4 bra BB101_7; + @%p4 bra BB107_7; setp.lt.f64 %p5, %fd4, 0d0000000000000000; add.f64 %fd129, %fd4, 0d7FF0000000000000; selp.f64 %fd136, 0d0000000000000000, %fd129, %p5; setp.geu.f32 %p6, %f1, 0f40874800; - @%p6 bra BB101_7; + @%p6 bra BB107_7; mov.f64 %fd134, 0d4338000000000000; mov.f64 %fd133, 0d3FF71547652B82FE; @@ -12507,26 +12927,26 @@ BB101_4: mov.b64 %fd131, {%r44, %r43}; mul.f64 %fd136, %fd130, %fd131; -BB101_7: +BB107_7: { .reg .b32 %temp; mov.b64 {%temp, %r45}, %fd136; } and.b32 %r46, %r45, 2147483647; setp.ne.s32 %p7, %r46, 2146435072; - @%p7 bra BB101_9; + @%p7 bra BB107_9; { .reg .b32 %temp; mov.b64 {%r47, %temp}, %fd136; } setp.eq.s32 %p8, %r47, 0; - @%p8 bra BB101_10; + @%p8 bra BB107_10; -BB101_9: +BB107_9: fma.rn.f64 %fd136, %fd136, %fd5, %fd136; -BB101_10: +BB107_10: st.param.f64 [func_retval0+0], %fd136; ret; } http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java index a405fd9..f3303de 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java @@ -248,7 +248,7 @@ public class ScriptExecutor { // set the GPUs to use for this process (a range, all GPUs, comma separated list or a specific GPU) GPUContextPool.AVAILABLE_GPUS = config.getTextValue(DMLConfig.AVAILABLE_GPUS); - + String evictionPolicy = config.getTextValue(DMLConfig.GPU_EVICTION_POLICY).toUpperCase(); try { DMLScript.GPU_EVICTION_POLICY = EvictionPolicy.valueOf(evictionPolicy); http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/java/org/apache/sysml/conf/DMLConfig.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/conf/DMLConfig.java b/src/main/java/org/apache/sysml/conf/DMLConfig.java index 0a896a3..9f08c3c 100644 --- a/src/main/java/org/apache/sysml/conf/DMLConfig.java +++ b/src/main/java/org/apache/sysml/conf/DMLConfig.java @@ -82,19 +82,19 @@ public class DMLConfig public static final String CODEGEN_PLANCACHE = "sysml.codegen.plancache"; //boolean public static final String CODEGEN_LITERALS = "sysml.codegen.literals"; //1..heuristic, 2..always public static final String CACHING_BUFFER_SIZE = "sysml.caching.bufferSize"; //double: default:0.15 - public static final String EXTRA_FINEGRAINED_STATS = "sysml.stats.finegrained"; //boolean public static final String STATS_MAX_WRAP_LEN = "sysml.stats.maxWrapLength"; //int public static final String AVAILABLE_GPUS = "sysml.gpu.availableGPUs"; // String to specify which GPUs to use (a range, all GPUs, comma separated list or a specific GPU) public static final String SYNCHRONIZE_GPU = "sysml.gpu.sync.postProcess"; // boolean: whether to synchronize GPUs after every instruction public static final String EAGER_CUDA_FREE = "sysml.gpu.eager.cudaFree"; // boolean: whether to perform eager CUDA free on rmvar public static final String GPU_EVICTION_POLICY = "sysml.gpu.eviction.policy"; // string: can be lru, lfu, min_evict + 
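	// Editor's note: a minimal, hypothetical SystemML-config.xml sketch that
	// exercises the GPU options declared in this class (the key names are taken
	// verbatim from the constants here; the values and the single-GPU choice are
	// illustrative, not recommendations):
	//
	//   <root>
	//     <sysml.gpu.availableGPUs>0</sysml.gpu.availableGPUs>
	//     <sysml.gpu.eviction.policy>align_memory</sysml.gpu.eviction.policy>
	//     <sysml.gpu.sync.postProcess>false</sysml.gpu.sync.postProcess>
	//     <sysml.gpu.eager.cudaFree>false</sysml.gpu.eager.cudaFree>
	//     <sysml.floating.point.precision>single</sysml.floating.point.precision>
	//     <sysml.gpu.print.memoryInfo>true</sysml.gpu.print.memoryInfo>
	//   </root>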
// Fraction of available memory to use. The available memory is computed when the GPUContext is created // to handle the tradeoff of calling cudaMemGetInfo too often. public static final String GPU_MEMORY_UTILIZATION_FACTOR = "sysml.gpu.memory.util.factor"; public static final String FLOATING_POINT_PRECISION = "sysml.floating.point.precision"; // String to specify the datatype to use internally: supported values are double, single public static final String PRINT_GPU_MEMORY_INFO = "sysml.gpu.print.memoryInfo"; - + // supported prefixes for custom map/reduce configurations public static final String PREFIX_MAPRED = "mapred"; public static final String PREFIX_MAPREDUCE = "mapreduce"; @@ -135,13 +135,13 @@ public class DMLConfig _defaultVals.put(NATIVE_BLAS, "none" ); _defaultVals.put(NATIVE_BLAS_DIR, "none" ); _defaultVals.put(EXTRA_FINEGRAINED_STATS,"false" ); + _defaultVals.put(PRINT_GPU_MEMORY_INFO, "false" ); _defaultVals.put(STATS_MAX_WRAP_LEN, "30" ); _defaultVals.put(GPU_MEMORY_UTILIZATION_FACTOR, "0.9" ); _defaultVals.put(AVAILABLE_GPUS, "-1"); _defaultVals.put(GPU_EVICTION_POLICY, "align_memory"); _defaultVals.put(SYNCHRONIZE_GPU, "false" ); _defaultVals.put(CACHING_BUFFER_SIZE, "0.15" ); - _defaultVals.put(SYNCHRONIZE_GPU, "true" ); _defaultVals.put(EAGER_CUDA_FREE, "false" ); _defaultVals.put(FLOATING_POINT_PRECISION, "double" ); _defaultVals.put(PRINT_GPU_MEMORY_INFO, "false"); http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/java/org/apache/sysml/hops/FunctionOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java index 3631f00..ec2fda8 100644 --- a/src/main/java/org/apache/sysml/hops/FunctionOp.java +++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java @@ -168,6 +168,10 @@ public class FunctionOp extends Hop long outputValues = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 1, 1.0); return outputVectors+outputValues; } + else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ) { + // TODO: allow the initial version to always run on the GPU + return 0; + } else if ( getFunctionName().equalsIgnoreCase("svd") ) { long outputU = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0); long outputSigma = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0); @@ -198,6 +202,10 @@ public class FunctionOp extends Hop return OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), getInput().get(0).getDim2(), 1.0) + 3*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 1, 1.0); } + else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) { + // TODO: allow the initial version to always run on the GPU + return 0; + } else if ( getFunctionName().equalsIgnoreCase("svd")) { double interOutput = OptimizerUtils.estimateSizeExactSparsity(1, getInput().get(0).getDim2(), 1.0); return interOutput; @@ -215,7 +223,10 @@ public class FunctionOp extends Hop @Override public boolean isGPUEnabled() { - return false; + if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) + return true; + else + return false; } @Override @@ -253,6 +264,7 @@ public class FunctionOp extends Hop protected ExecType optFindExecType() { checkAndSetForcedPlatform(); + ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR; if ( getFunctionType() == FunctionType.MULTIRETURN_BUILTIN ) { @@ -261,7 +273,17 @@ public class FunctionOp extends Hop _etype = ((_etypeForced==ExecType.SPARK || (getMemEstimate() >= OptimizerUtils.getLocalMemBudget() && OptimizerUtils.isSparkExecutionMode())) ? ExecType.SPARK : ExecType.CP); - } + } + else if( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) { +// if ( OptimizerUtils.isMemoryBasedOptLevel() ) { +// _etype = findExecTypeByMemEstimate(); +// } +// else { +// _etype = ExecType.CP; +// } +// _etype = _etype == REMOTE ? ExecType.CP : _etype; // lstm not supported on Spark + _etype = ExecType.GPU; + } else { // Since the memory estimate is only conservative, do not throw // exception if the estimated memory is larger than the budget http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/java/org/apache/sysml/hops/ReorgOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ReorgOp.java b/src/main/java/org/apache/sysml/hops/ReorgOp.java index 70fb797..d01ed09 100644 --- a/src/main/java/org/apache/sysml/hops/ReorgOp.java +++ b/src/main/java/org/apache/sysml/hops/ReorgOp.java @@ -27,7 +27,6 @@ import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.lops.Aggregate; import org.apache.sysml.lops.Group; import org.apache.sysml.lops.Lop; -import org.apache.sysml.lops.LopsException; import org.apache.sysml.lops.SortKeys; import org.apache.sysml.lops.Transform; import org.apache.sysml.lops.LopProperties.ExecType; @@ -132,18 +131,18 @@ public class ReorgOp extends Hop implements MultiThreadedHop return false; switch( op ) { case TRANS: { - Lop lin; - try { - lin = getInput().get(0).constructLops(); - } catch (HopsException | LopsException e) { - throw new RuntimeException("Unable to create child lop", e); + if( getDim1()==1 && getDim2()==1 ) { + return false; //if input of size 1x1, avoid unnecessary transpose } - if( lin instanceof Transform && ((Transform)lin).getOperationType()==OperationTypes.Transpose ) + else if( getInput().get(0) instanceof ReorgOp && ((ReorgOp) getInput().get(0)).getOp() == ReOrgOp.TRANS) { + // Following checks causes stackoverflow: + // lin = getInput().get(0).constructLops(); + // lin instanceof Transform && ((Transform)lin).getOperationType()==OperationTypes.Transpose return false; //if input is already a transpose, avoid redundant transpose ops - else if( getDim1()==1 && getDim2()==1 ) - return false; //if input of size 1x1, avoid unnecessary transpose - else + } + else { return true; + } } case DIAG: case REV: http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java index ea51bd1..8a0a6d8 100644 --- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java +++ 
b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java @@ -89,6 +89,18 @@ public class BuiltinFunctionExpression extends DataIdentifier public Expression getThirdExpr() { return (_args.length >= 3 ? _args[2] : null); } + + public Expression getFourthExpr() { + return (_args.length >= 4 ? _args[3] : null); + } + + public Expression getFifthExpr() { + return (_args.length >= 5 ? _args[4] : null); + } + + public Expression getSixthExpr() { + return (_args.length >= 6 ? _args[5] : null); + } public Expression[] getAllExpr(){ return _args; @@ -106,17 +118,14 @@ public class BuiltinFunctionExpression extends DataIdentifier } this.getFirstExpr().validateExpression(ids, constVars, conditional); - if (getSecondExpr() != null){ - if (this.getSecondExpr() instanceof FunctionCallIdentifier){ - raiseValidateError("UDF function call not supported as parameter to built-in function call", false); - } - getSecondExpr().validateExpression(ids, constVars, conditional); - } - if (getThirdExpr() != null) { - if (this.getThirdExpr() instanceof FunctionCallIdentifier){ - raiseValidateError("UDF function call not supported as parameter to built-in function call", false); + Expression [] expr = getAllExpr(); + if(expr != null && expr.length > 1) { + for(int i = 1; i < expr.length; i++) { + if (expr[i] instanceof FunctionCallIdentifier){ + raiseValidateError("UDF function call not supported as parameter to built-in function call", false); + } + expr[i].validateExpression(ids, constVars, conditional); } - getThirdExpr().validateExpression(ids, constVars, conditional); } _outputs = new Identifier[stmt.getTargetList().size()]; int count = 0; @@ -189,6 +198,97 @@ public class BuiltinFunctionExpression extends DataIdentifier break; + case LSTM: + { + // X, W, bias, out0, c0, return_sequences + checkNumParameters(6); + checkMatrixParam(getFirstExpr()); + checkMatrixParam(getSecondExpr()); + checkMatrixParam(getThirdExpr()); + checkMatrixParam(getFourthExpr()); + checkMatrixParam(getFifthExpr()); + + // setup output properties + if(getOutputs() == null || getOutputs().length != 3) { + int numOutputs = getOutputs() == null ? 0 : getOutputs().length; + raiseValidateError("The builtin function lstm has three outputs, but instead found: " + numOutputs, conditional); + } + DataIdentifier out = (DataIdentifier) getOutputs()[0]; + DataIdentifier cy = (DataIdentifier) getOutputs()[1]; + DataIdentifier reserveSpace = (DataIdentifier) getOutputs()[2]; + + // Output1 - out: If `return_sequences` is True, outputs for all timesteps, else outputs for the final timestep. + out.setDataType(DataType.MATRIX); + out.setValueType(ValueType.DOUBLE); + out.setDimensions(-1, -1); + out.setBlockDimensions(getFirstExpr().getOutput().getRowsInBlock(), getFirstExpr().getOutput().getColumnsInBlock()); + + // Output2 - Cell state for final timestep. + cy.setDataType(DataType.MATRIX); + cy.setValueType(ValueType.DOUBLE); + cy.setDimensions(getExpr(4).getOutput().getDim1(), getExpr(4).getOutput().getDim2()); + cy.setBlockDimensions(getExpr(4).getOutput().getRowsInBlock(), getExpr(4).getOutput().getColumnsInBlock()); + + // Output3 - reserve space. 
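+			// Editor's note: reserveSpace is the opaque cuDNN workspace that the
+			// lstm backward builtins (not yet in this commit) will consume; its size
+			// is only known at runtime, hence the (-1, -1) dimensions set below.
+			// A plausible DML call shape, matching the six inputs and three outputs
+			// validated here (variable names illustrative):
+			//   [out, cy, reserve] = lstm(X, W, b, out0, c0, TRUE)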
+ reserveSpace.setDataType(DataType.MATRIX); + reserveSpace.setValueType(ValueType.DOUBLE); + reserveSpace.setDimensions(-1, -1); + reserveSpace.setBlockDimensions(getFirstExpr().getOutput().getRowsInBlock(), getFirstExpr().getOutput().getColumnsInBlock()); + + break; + } + case BATCH_NORM2D: + { + // Input: image, scale, bias, runningMean, runningVar, mode, epsilon, exponentialAverageFactor + checkNumParameters(8); + checkMatrixParam(getFirstExpr()); + checkMatrixParam(getSecondExpr()); + checkMatrixParam(getThirdExpr()); + checkMatrixParam(getFourthExpr()); + checkMatrixParam(getFifthExpr()); + + // Output: ret, retRunningMean, retRunningVar, resultSaveMean, resultSaveInvVariance + // setup output properties + if(getOutputs().length != 5) + raiseValidateError("batch_norm2d has 5 outputs", false); + + DataIdentifier ret = (DataIdentifier) getOutputs()[0]; + DataIdentifier retRunningMean = (DataIdentifier) getOutputs()[1]; + DataIdentifier retRunningVar = (DataIdentifier) getOutputs()[2]; + DataIdentifier resultSaveMean = (DataIdentifier) getOutputs()[3]; + DataIdentifier resultSaveInvVariance = (DataIdentifier) getOutputs()[4]; + + setDimensions(ret, getFirstExpr()); + setDimensions(retRunningMean, getFourthExpr()); + setDimensions(retRunningVar, getFourthExpr()); + setDimensions(resultSaveMean, getFourthExpr()); + setDimensions(resultSaveInvVariance, getFourthExpr()); + break; + } + case BATCH_NORM2D_BACKWARD: + { + // Input: image, dout, scale, epsilon, savedMean, savedInvVariance + checkNumParameters(6); + checkMatrixParam(getFirstExpr()); + checkMatrixParam(getSecondExpr()); + checkMatrixParam(getThirdExpr()); + checkMatrixParam(getFifthExpr()); + checkMatrixParam(getSixthExpr()); + + // Output: dX, dScale, dBias + // setup output properties + if(getOutputs().length != 3) + raiseValidateError("batch_norm2d_backward has 3 outputs", false); + + DataIdentifier dX = (DataIdentifier) getOutputs()[0]; + DataIdentifier dScale = (DataIdentifier) getOutputs()[1]; + DataIdentifier dBias = (DataIdentifier) getOutputs()[2]; + + setDimensions(dX, getFirstExpr()); + setDimensions(dScale, getThirdExpr()); + setDimensions(dBias, getThirdExpr()); + break; + } case EIGEN: checkNumParameters(1); checkMatrixParam(getFirstExpr()); @@ -251,6 +351,13 @@ public class BuiltinFunctionExpression extends DataIdentifier } } + private static void setDimensions(DataIdentifier out, Expression exp) { + out.setDataType(DataType.MATRIX); + out.setValueType(ValueType.DOUBLE); + out.setDimensions(exp.getOutput().getDim1(), exp.getOutput().getDim2()); + out.setBlockDimensions(exp.getOutput().getRowsInBlock(), exp.getOutput().getColumnsInBlock()); + } + private static ArrayList<ParameterExpression> orderConvolutionParams(ArrayList<ParameterExpression> paramExpression, int skip) { ArrayList<ParameterExpression> newParams = new ArrayList<>(); @@ -1300,7 +1407,8 @@ public class BuiltinFunctionExpression extends DataIdentifier else { // always unconditional (because unsupported operation) BuiltinFunctionOp op = getOpCode(); - if( op==BuiltinFunctionOp.EIGEN || op==BuiltinFunctionOp.LU || op==BuiltinFunctionOp.QR || op==BuiltinFunctionOp.SVD) + if( op==BuiltinFunctionOp.EIGEN || op==BuiltinFunctionOp.LU || op==BuiltinFunctionOp.QR || op==BuiltinFunctionOp.SVD + || op==BuiltinFunctionOp.LSTM || op==BuiltinFunctionOp.BATCH_NORM2D || op==BuiltinFunctionOp.BATCH_NORM2D_BACKWARD) raiseValidateError("Function "+op+" needs to be called with multi-return assignment.", false, LanguageErrorCodes.INVALID_PARAMETERS); else 
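			// Editor's note: the branch above fires when one of the new multi-output
			// builtins is used outside a multi-return assignment; in DML they must be
			// called in the bracketed form, e.g. (argument names illustrative):
			//   [dX, dScale, dBias] = batch_norm2d_backward(X, dout, scale, eps, savedMean, savedInvVar)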
raiseValidateError("Unsupported function "+op, false, LanguageErrorCodes.INVALID_PARAMETERS); @@ -1362,6 +1470,9 @@ public class BuiltinFunctionExpression extends DataIdentifier case QR: case LU: case EIGEN: + case LSTM: + case BATCH_NORM2D: + case BATCH_NORM2D_BACKWARD: case SVD: return true; default: @@ -1503,7 +1614,7 @@ public class BuiltinFunctionExpression extends DataIdentifier protected void checkMatrixParam(Expression e) { if (e.getOutput().getDataType() != DataType.MATRIX) { raiseValidateError( - "Expecting matrix argument for function " + this.getOpCode().toString().toLowerCase() + "().", + "Expected " + e.getText() + " to be a matrix argument for function " + this.getOpCode().toString().toLowerCase() + "().", false); } } @@ -1771,6 +1882,12 @@ public class BuiltinFunctionExpression extends DataIdentifier bifop = Expression.BuiltinFunctionOp.LU; else if (functionName.equals("eigen")) bifop = Expression.BuiltinFunctionOp.EIGEN; + else if (functionName.equals("lstm")) + bifop = Expression.BuiltinFunctionOp.LSTM; + else if (functionName.equals("batch_norm2d")) + bifop = Expression.BuiltinFunctionOp.BATCH_NORM2D; + else if (functionName.equals("batch_norm2d_backward")) + bifop = Expression.BuiltinFunctionOp.BATCH_NORM2D_BACKWARD; else if (functionName.equals("conv2d")) bifop = Expression.BuiltinFunctionOp.CONV2D; else if (functionName.equals("bias_add")) http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index ff02012..fd0b47b 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -2212,10 +2212,12 @@ public class DMLTranslator // Construct Hops for all inputs ArrayList<Hop> inputs = new ArrayList<>(); inputs.add( processExpression(source.getFirstExpr(), null, hops) ); - if ( source.getSecondExpr() != null ) - inputs.add( processExpression(source.getSecondExpr(), null, hops) ); - if ( source.getThirdExpr() != null ) - inputs.add( processExpression(source.getThirdExpr(), null, hops) ); + Expression[] expr = source.getAllExpr(); + if(expr != null && expr.length > 1) { + for(int i = 1; i < expr.length; i++) { + inputs.add( processExpression(expr[i], null, hops) ); + } + } FunctionType ftype = FunctionType.MULTIRETURN_BUILTIN; String nameSpace = DMLProgram.INTERNAL_NAMESPACE; @@ -2230,6 +2232,9 @@ public class DMLTranslator case QR: case LU: case EIGEN: + case LSTM: + case BATCH_NORM2D: + case BATCH_NORM2D_BACKWARD: case SVD: // Number of outputs = size of targetList = #of identifiers in source.getOutputs http://git-wip-us.apache.org/repos/asf/systemml/blob/276065f9/src/main/java/org/apache/sysml/parser/Expression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java index 9a6ea64..4355bd5 100644 --- a/src/main/java/org/apache/sysml/parser/Expression.java +++ b/src/main/java/org/apache/sysml/parser/Expression.java @@ -92,6 +92,7 @@ public abstract class Expression implements ParseInfo EXISTS, CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, BIAS_ADD, BIAS_MULTIPLY, MAX_POOL, AVG_POOL, MAX_POOL_BACKWARD, AVG_POOL_BACKWARD, + LSTM, BATCH_NORM2D, BATCH_NORM2D_BACKWARD, EXP, FLOOR, IFELSE,