kaknikhil commented on a change in pull request #564:
URL: https://github.com/apache/madlib/pull/564#discussion_r614402381



##########
File path: src/modules/convex/type/state.hpp
##########
@@ -819,11 +819,13 @@ class MLPMiniBatchState {
      * - N + 4: is_classification (do classification)
      * - N + 5: activation (activation function)
      * - N + 6: coeff (coefficients, design doc: u)
+     * - N + 7: momentum

Review comment:
       Why do we need to add momentum and nesterov to MLPMiniBatchState as part 
of this PR?

##########
File path: src/ports/postgres/modules/convex/mlp_igd.py_in
##########
@@ -63,7 +63,7 @@ from utilities.utilities import create_cols_from_array_sql_string
 
 
 
-@MinWarning("error")
+# @MinWarning("warning")

Review comment:
       Why did we change this to warning?

##########
File path: src/ports/postgres/modules/convex/mlp.sql_in
##########
@@ -383,6 +384,15 @@ the parameter is ignored.
 \b Optimizer \b Parameters
 <DL class="arglist">
 
+<DT>solver</dt>
+<DD>Default: sgd.
One of 'sgd', 'rmsprop' or 'adam' or any prefix of these (e.g., 'rmsp' means 'rmsprop').

Review comment:
       Does it make sense to validate that the arguments passed in actually apply to the chosen solver? For example, should we throw an error or a warning in cases like these:
   1. beta1 or beta2 is passed along with sgd or rmsprop
   2. momentum is passed for rmsprop or adam
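A minimal sketch of such a check, assuming the parameter-to-solver mapping implied by the examples above (momentum/nesterov for sgd, beta for rmsprop, beta1/beta2 for adam). `SOLVER_SPECIFIC_PARAMS` and `check_solver_params` are hypothetical names, and whether to warn or raise is left as a flag:

```python
import warnings

# Hypothetical map from each solver to the optimizer parameters it uses.
# Names follow the optimizer_params keys in this PR; adjust if renamed.
SOLVER_SPECIFIC_PARAMS = {
    "sgd": {"momentum", "nesterov"},
    "rmsprop": {"beta"},
    "adam": {"beta1", "beta2"},
}


def check_solver_params(solver, user_params, strict=False):
    """Warn (or raise, if strict) when a solver-specific parameter is
    passed for a solver that ignores it.

    solver      -- full solver name, e.g. 'rmsprop'
    user_params -- names of the parameters the user explicitly passed
    Returns the sorted list of ignored parameter names.
    """
    all_specific = set().union(*SOLVER_SPECIFIC_PARAMS.values())
    ignored = sorted((set(user_params) & all_specific)
                     - SOLVER_SPECIFIC_PARAMS[solver])
    if ignored:
        msg = "Parameter(s) {0} not used by solver '{1}'".format(
            ", ".join(ignored), solver)
        if strict:
            raise ValueError(msg)
        warnings.warn(msg)
    return ignored
```

Only parameters the user explicitly passed should be checked, so defaults filled in later do not trigger spurious warnings.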
   
   

##########
File path: src/modules/convex/task/mlp.hpp
##########
@@ -218,6 +237,86 @@ MLP<Model, Tuple>::getLossAndUpdateModel(
     return total_loss;
 }
 
+template <class Model, class Tuple>
+double
+MLP<Model, Tuple>::getLossAndUpdateModelALR(
+        model_type              &model,
+        const Matrix            &x_batch,
+        const Matrix            &y_true_batch,
+        const double            &stepsize,
+        const int               &opt_code,
+        const double            &gamma,
+        std::vector<Matrix>     &sqrs,
+        const double            &beta1,
+        const double            &beta2,
+        std::vector<Matrix>     &vs,
+        const int               &t) {
+
+    double total_loss = 0.;
+    const double EPS_STABLE = 1.e-18;
+
+    // initialize gradient vector
+    std::vector<Matrix> total_gradient_per_layer(model.num_layers);
+    Matrix g, v_bias_corr, sqr_bias_corr;
+    for (Index k=0; k < model.num_layers; ++k) {
+        total_gradient_per_layer[k] = Matrix::Zero(model.u[k].rows(),
+                                                   model.u[k].cols());
+    }
+
+    std::vector<ColumnVector> net, o, delta;
+    Index num_rows_in_batch = x_batch.rows();
+
+    for (Index i=0; i < num_rows_in_batch; i++){
+        // gradient and loss added over the batch
+        ColumnVector x = x_batch.row(i);
+        ColumnVector y_true = y_true_batch.row(i);
+
+        feedForward(model, x, net, o);
+        backPropogate(y_true, o.back(), net, model, delta);
+
+        // compute the gradient
+        for (Index k=0; k < model.num_layers; k++){
+                total_gradient_per_layer[k] += o[k] * delta[k].transpose();
+        }
+
+        // compute the loss
+        total_loss += getLoss(y_true, o.back(), model.is_classification);
+    }
+
+    for (Index k=0; k < model.num_layers; k++){
+        // convert gradient to a gradient update vector
+        //  1. normalize to per row update
+        //  2. discount by stepsize
+        //  3. add regularization
+        Matrix regularization = MLP<Model, Tuple>::lambda * model.u[k];
+        regularization.row(0).setZero(); // Do not update bias
+
+        g = total_gradient_per_layer[k] / static_cast<double>(num_rows_in_batch);
+        if (opt_code == IS_RMSPROP){
+
+            sqrs[k] = gamma * sqrs[k] + (1.0 - gamma) * square(g);
+            total_gradient_per_layer[k] = (-stepsize * g).array() /
+                                          (sqrs[k].array() + EPS_STABLE).sqrt();

Review comment:
       I noticed that some sources, such as https://towardsdatascience.com/a-look-at-gradient-descent-and-rmsprop-optimizers-f77d483ef08b, take the square root first and then add epsilon: `sqrt(sqrs[k]) + EPS_STABLE`. I believe Keras does the same (see https://github.com/tensorflow/tensorflow/blob/85c8b2a817f95a3e979ecd1ed95bff1dc1335cff/tensorflow/python/keras/optimizer_v2/rmsprop.py#L219 and https://github.com/tensorflow/tensorflow/blob/85c8b2a817f95a3e979ecd1ed95bff1dc1335cff/tensorflow/python/keras/optimizer_v2/rmsprop.py#L275).
   
   Other sources, such as https://mxnet.apache.org/versions/master/api/python/docs/api/optimizer/index.html#mxnet.optimizer.RMSProp, add epsilon first and then take the square root of the sum: `sqrt(sqrs[k] + EPS_STABLE)`.
   
   This also ties back to the default value of `EPS_STABLE`. We have a couple of options:
   
   1. Use the same default value for both adam and rmsprop: keep epsilon outside the square root for both and change the default value of epsilon to either 1e-7 or 1e-8.
   2. Keep the calculations as they are and use different default values for adam and rmsprop. For rmsprop (with epsilon inside the square root), something like 1e-18 or 1e-16 might work, and for adam (with epsilon outside the square root), something like 1e-7 or 1e-8 might make sense.
   
   IMO, option 1 might be the better choice, given that we might also want to expose the epsilon argument to the user, and different default values might confuse the user.
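To make the difference concrete, here is a scalar sketch of one RMSprop step in each style (function names are illustrative; `rho` stands for the decay of the squared-gradient average). With epsilon inside the root, the denominator never drops below `sqrt(eps)`, so an inside value of 1e-18 behaves roughly like an outside value of 1e-9, and Keras's outside default of 1e-7 would correspond to about 1e-14 inside:

```python
import math


def rmsprop_step_eps_inside(w, g, s, lr, rho, eps):
    """One scalar RMSprop step with epsilon inside the square root
    (the PR's current form): w - lr * g / sqrt(s + eps)."""
    s = rho * s + (1.0 - rho) * g * g
    return w - lr * g / math.sqrt(s + eps), s


def rmsprop_step_eps_outside(w, g, s, lr, rho, eps):
    """One scalar RMSprop step with epsilon outside the square root
    (the Keras form): w - lr * g / (sqrt(s) + eps)."""
    s = rho * s + (1.0 - rho) * g * g
    return w - lr * g / (math.sqrt(s) + eps), s
```

The two forms only diverge when the running mean square is tiny; for ordinary gradient magnitudes epsilon is negligible either way, which is why picking one placement and one default for both solvers seems low-risk.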

##########
File path: src/ports/postgres/modules/convex/mlp.sql_in
##########
@@ -379,7 +379,10 @@ the parameter is ignored.
    batch_size = &lt;value>,
    n_epochs = &lt;value>,
    momentum = &lt;value>,
-   nesterov = &lt;value>'
+   nesterov = &lt;value>',
+   beta = &lt;value>',

Review comment:
       We might want to consider calling the new argument `rho`. [keras](https://keras.io/api/optimizers/rmsprop/) and the master branch of [mxnet](https://mxnet.apache.org/versions/master/api/python/docs/api/optimizer/index.html#mxnet.optimizer.RMSProp) both name it `rho`, so `rho` may be the better choice. (Note that previous versions of [mxnet](https://mxnet.apache.org/versions/1.8.0/api/python/docs/api/optimizer/index.html#mxnet.optimizer.RMSProp) named it `gamma1`, but it looks like the master branch renamed it to `rho`.)

##########
File path: src/modules/convex/task/mlp.hpp
##########
@@ -218,6 +237,86 @@ MLP<Model, Tuple>::getLossAndUpdateModel(
     return total_loss;
 }
 
+template <class Model, class Tuple>
+double
+MLP<Model, Tuple>::getLossAndUpdateModelALR(
+        model_type              &model,
+        const Matrix            &x_batch,
+        const Matrix            &y_true_batch,
+        const double            &stepsize,
+        const int               &opt_code,
+        const double            &gamma,
+        std::vector<Matrix>     &sqrs,
+        const double            &beta1,
+        const double            &beta2,
+        std::vector<Matrix>     &vs,
+        const int               &t) {
+
+    double total_loss = 0.;
+    const double EPS_STABLE = 1.e-18;
+
+    // initialize gradient vector
+    std::vector<Matrix> total_gradient_per_layer(model.num_layers);
+    Matrix g, v_bias_corr, sqr_bias_corr;
+    for (Index k=0; k < model.num_layers; ++k) {
+        total_gradient_per_layer[k] = Matrix::Zero(model.u[k].rows(),
+                                                   model.u[k].cols());
+    }
+
+    std::vector<ColumnVector> net, o, delta;
+    Index num_rows_in_batch = x_batch.rows();
+
+    for (Index i=0; i < num_rows_in_batch; i++){
+        // gradient and loss added over the batch
+        ColumnVector x = x_batch.row(i);
+        ColumnVector y_true = y_true_batch.row(i);
+
+        feedForward(model, x, net, o);
+        backPropogate(y_true, o.back(), net, model, delta);
+
+        // compute the gradient
+        for (Index k=0; k < model.num_layers; k++){
+                total_gradient_per_layer[k] += o[k] * delta[k].transpose();
+        }
+
+        // compute the loss
+        total_loss += getLoss(y_true, o.back(), model.is_classification);
+    }
+
+    for (Index k=0; k < model.num_layers; k++){
+        // convert gradient to a gradient update vector
+        //  1. normalize to per row update
+        //  2. discount by stepsize
+        //  3. add regularization
+        Matrix regularization = MLP<Model, Tuple>::lambda * model.u[k];
+        regularization.row(0).setZero(); // Do not update bias
+
+        g = total_gradient_per_layer[k] / static_cast<double>(num_rows_in_batch);
+        if (opt_code == IS_RMSPROP){
+
+            sqrs[k] = gamma * sqrs[k] + (1.0 - gamma) * square(g);
+            total_gradient_per_layer[k] = (-stepsize * g).array() /
+                                          (sqrs[k].array() + EPS_STABLE).sqrt();

Review comment:
       Do you think it makes sense to add a reference to either our design doc or the official literature here?

##########
File path: src/ports/postgres/modules/convex/mlp_igd.py_in
##########
@@ -824,6 +861,22 @@ def _get_learning_rate_policy_name(learning_rate_policy):
                        ', '.join(supported_learning_rate_policies)))
     return learning_rate_policy
 
+def _get_opt_code(solver):
+    if not solver:

Review comment:
       Although we allow partial string matching for the solver, should we store the full name in the summary table? Currently we store whatever the user passes in.
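One way to do this, sketched with a hypothetical `resolve_solver_name` helper: resolve the prefix once during argument parsing and store the returned full name in the summary table. Case-insensitive matching is an assumption here, not something the PR specifies:

```python
SUPPORTED_SOLVERS = ("sgd", "rmsprop", "adam")


def resolve_solver_name(solver):
    """Map a (possibly abbreviated) solver name to its full name.

    None or '' falls back to the default 'sgd'. Matching is a
    case-insensitive prefix match, so 'rmsp' -> 'rmsprop'.
    """
    if not solver:
        return "sgd"
    solver = solver.strip().lower()
    matches = [s for s in SUPPORTED_SOLVERS if s.startswith(solver)]
    # Exactly one match is required; a blank string after strip()
    # matches everything and is rejected as ambiguous.
    if len(matches) != 1:
        raise ValueError(
            "Invalid solver '{0}'; expected one of {1} or a prefix of "
            "one of them".format(solver, ", ".join(SUPPORTED_SOLVERS)))
    return matches[0]
```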

##########
File path: src/ports/postgres/modules/convex/mlp_igd.py_in
##########
@@ -633,7 +667,11 @@ def _get_optimizer_params(param_str):
         "n_epochs": (1, int),
         "batch_size": (1, int),
         "momentum": (0.9, float),
-        "nesterov": (True, bool)
+        "nesterov": (True, bool),
+        "is_rmsprop": (False, bool),
+        "is_adam": (False, bool),
+        "beta1": (0.9, float),

Review comment:
       Shouldn't we also validate the range of beta, beta1 and beta2, just like we do for the other parameters in `_validate_args`?
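A possible check, assuming the constraint stated in the Adam paper that the decay rates lie in [0, 1) and applying the same range to the rmsprop decay. `validate_decay_params` is a hypothetical helper that could be called from `_validate_args`:

```python
def validate_decay_params(beta, beta1, beta2):
    """Raise ValueError unless every decay factor lies in [0, 1).

    With a decay of exactly 1 the moving average never moves from its
    initial value, and Adam's bias-correction term 1 - beta**t is 0.
    """
    for name, value in (("beta", beta), ("beta1", beta1), ("beta2", beta2)):
        if not 0.0 <= value < 1.0:
            raise ValueError(
                "{0} must be in the range [0, 1), got {1}".format(name, value))
```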

##########
File path: src/modules/convex/task/mlp.hpp
##########
@@ -218,6 +237,86 @@ MLP<Model, Tuple>::getLossAndUpdateModel(
     return total_loss;
 }
 
+template <class Model, class Tuple>
+double
+MLP<Model, Tuple>::getLossAndUpdateModelALR(
+        model_type              &model,
+        const Matrix            &x_batch,
+        const Matrix            &y_true_batch,
+        const double            &stepsize,
+        const int               &opt_code,
+        const double            &gamma,
+        std::vector<Matrix>     &sqrs,
+        const double            &beta1,
+        const double            &beta2,
+        std::vector<Matrix>     &vs,
+        const int               &t) {
+
+    double total_loss = 0.;
+    const double EPS_STABLE = 1.e-18;
+
+    // initialize gradient vector
+    std::vector<Matrix> total_gradient_per_layer(model.num_layers);
+    Matrix g, v_bias_corr, sqr_bias_corr;
+    for (Index k=0; k < model.num_layers; ++k) {
+        total_gradient_per_layer[k] = Matrix::Zero(model.u[k].rows(),
+                                                   model.u[k].cols());
+    }
+
+    std::vector<ColumnVector> net, o, delta;
+    Index num_rows_in_batch = x_batch.rows();
+
+    for (Index i=0; i < num_rows_in_batch; i++){
+        // gradient and loss added over the batch
+        ColumnVector x = x_batch.row(i);
+        ColumnVector y_true = y_true_batch.row(i);
+
+        feedForward(model, x, net, o);
+        backPropogate(y_true, o.back(), net, model, delta);
+
+        // compute the gradient
+        for (Index k=0; k < model.num_layers; k++){
+                total_gradient_per_layer[k] += o[k] * delta[k].transpose();
+        }
+
+        // compute the loss
+        total_loss += getLoss(y_true, o.back(), model.is_classification);
+    }
+
+    for (Index k=0; k < model.num_layers; k++){
+        // convert gradient to a gradient update vector
+        //  1. normalize to per row update
+        //  2. discount by stepsize
+        //  3. add regularization
+        Matrix regularization = MLP<Model, Tuple>::lambda * model.u[k];
+        regularization.row(0).setZero(); // Do not update bias
+
+        g = total_gradient_per_layer[k] / static_cast<double>(num_rows_in_batch);
+        if (opt_code == IS_RMSPROP){
+
+            sqrs[k] = gamma * sqrs[k] + (1.0 - gamma) * square(g);

Review comment:
       In http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf, the previous mean square is multiplied by 0.9, whereas in our code it will be multiplied by 0.1, since the default value of gamma is 0.1.
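A quick illustration of why this matters (`mean_square` is an illustrative helper): feeding one unit gradient followed by zeros shows that with a decay of 0.9 the average keeps a memory of past gradients, while with 0.1 it is dominated by the most recent one:

```python
def mean_square(grads, decay):
    """Running mean square in the form used in the lecture 6 slides:
    ms = decay * ms + (1 - decay) * g**2, starting from ms = 0.
    Returns the value after each gradient."""
    ms = 0.0
    history = []
    for g in grads:
        ms = decay * ms + (1.0 - decay) * g * g
        history.append(ms)
    return history
```

With `decay=0.9` the sequence `[1.0, 0.0, 0.0]` gives roughly `[0.1, 0.09, 0.081]`; with `decay=0.1` it gives roughly `[0.9, 0.09, 0.009]`.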

##########
File path: src/modules/convex/task/mlp.hpp
##########
@@ -218,6 +237,86 @@ MLP<Model, Tuple>::getLossAndUpdateModel(
     return total_loss;
 }
 
+template <class Model, class Tuple>
+double
+MLP<Model, Tuple>::getLossAndUpdateModelALR(
+        model_type              &model,
+        const Matrix            &x_batch,
+        const Matrix            &y_true_batch,
+        const double            &stepsize,
+        const int               &opt_code,
+        const double            &gamma,
+        std::vector<Matrix>     &sqrs,
+        const double            &beta1,
+        const double            &beta2,
+        std::vector<Matrix>     &vs,
+        const int               &t) {
+
+    double total_loss = 0.;
+    const double EPS_STABLE = 1.e-18;

Review comment:
       How did we decide on the default value of `1.e-18`? keras/tf defaults it to `1e-07` and mxnet defaults it to `1e-8`.

##########
File path: src/ports/postgres/modules/convex/mlp_igd.py_in
##########
@@ -296,7 +305,30 @@ def mlp(schema_madlib, source_table, output_table, independent_varname,
                     step_size = step_size_init * gamma**(
                         math.floor(it.iteration / iterations_per_step))
                 it.kwargs['step_size'] = step_size
-                if is_minibatch_enabled:
+                if opt_code != 0 :

Review comment:
       I think this is a bit confusing to read, since the reader needs to know that 0 stands for sgd, 1 for rmsprop and 2 for adam. One alternative is to use an enum, so that we can write something like `if opt_code != Solver.SGD`. I'll leave it up to you to decide.
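A sketch of what the enum could look like, assuming the 0/1/2 codes above. `enum` is in the Python standard library from 3.4 on; on Python 2 this would need the `enum34` backport or a plain class with integer constants. `Solver` and `uses_adaptive_learning_rate` are illustrative names:

```python
from enum import IntEnum


class Solver(IntEnum):
    """Optimizer codes; the values must match the opt_code values the
    C++ side expects (0 = sgd, 1 = rmsprop, 2 = adam)."""
    SGD = 0
    RMSPROP = 1
    ADAM = 2


def uses_adaptive_learning_rate(opt_code):
    # IntEnum members compare equal to plain ints, so opt_code can
    # still be passed to SQL/C++ as an ordinary integer.
    return opt_code != Solver.SGD
```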

##########
File path: src/modules/convex/algo/igd.hpp
##########
@@ -36,6 +36,7 @@ class IGD {
     static void transition(state_type &state, const tuple_type &tuple);
     static void merge(state_type &state, const_state_type &otherState);
     static void transitionInMiniBatch(state_type &state, const tuple_type &tuple);
+    static void transitionInALR(state_type &state, const tuple_type &tuple);

Review comment:
       Instead of the `ALR` acronym, I would suggest using something more descriptive, like `transitionInMiniBatchWithAdaptiveLearning` (or something similar). What do you think?

##########
File path: src/ports/postgres/modules/convex/mlp.sql_in
##########
@@ -1467,6 +1477,31 @@ RETURNS DOUBLE PRECISION[]
 AS 'MODULE_PATHNAME'
 LANGUAGE C IMMUTABLE;
 
+CREATE FUNCTION MADLIB_SCHEMA.mlp_alr_transition(
+        state              DOUBLE PRECISION[],
+        ind_var            DOUBLE PRECISION[],
+        dep_var            DOUBLE PRECISION[],
+        previous_state     DOUBLE PRECISION[],
+        layer_sizes        DOUBLE PRECISION[],
+        learning_rate_init DOUBLE PRECISION,
+        activation         INTEGER,
+        is_classification  INTEGER,
+        weight             DOUBLE PRECISION,
+        warm_start_coeff   DOUBLE PRECISION[],
+        lambda             DOUBLE PRECISION,
+        batch_size         INTEGER,
+        n_epochs           INTEGER,
+        momentum           DOUBLE PRECISION,
+        is_nesterov        BOOLEAN,
+        opt_code           INTEGER,
+        gamma              DOUBLE PRECISION,

Review comment:
       We need to introduce a new argument for the `decay factor of moving average over past squared gradient` constant used in rmsprop. gamma, which is the `Decay rate for learning rate`, is different from rho/beta, which is the `decay factor of moving average over past squared gradient used in rmsprop`; these two decay different things.
   
   We can call the new argument beta or rho. [keras](https://keras.io/api/optimizers/rmsprop/) and the master branch of [mxnet](https://mxnet.apache.org/versions/master/api/python/docs/api/optimizer/index.html#mxnet.optimizer.RMSProp) name it `rho`, so rho may be the better choice, since it won't get confused with the beta1 or beta2 args of adam. (Note that previous versions of [mxnet](https://mxnet.apache.org/versions/1.8.0/api/python/docs/api/optimizer/index.html#mxnet.optimizer.RMSProp) named rho `gamma1`, but it looks like the master branch changed it to `rho`.)
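A small sketch of the distinction (helper names are illustrative): `gamma` enters the step learning-rate policy, mirroring the formula in `mlp_igd.py_in`, while `rho` only weights the running average of squared gradients, so the two can be set independently:

```python
import math


def step_policy_lr(step_size_init, gamma, iteration, iterations_per_step):
    """Learning rate under the 'step' policy: gamma decays the
    learning rate once every iterations_per_step iterations."""
    return step_size_init * gamma ** math.floor(iteration / iterations_per_step)


def update_mean_square(s, g, rho):
    """RMSprop's moving average: rho decays the past squared gradients."""
    return rho * s + (1.0 - rho) * g * g
```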

##########
File path: src/ports/postgres/modules/convex/mlp_igd.py_in
##########
@@ -106,6 +106,7 @@ def mlp(schema_madlib, source_table, output_table, independent_varname,
     n_epochs = optimizer_params['n_epochs']
     momentum = optimizer_params['momentum']
     is_nesterov = optimizer_params['nesterov']
+    beta = optimizer_params['beta']

Review comment:
       We should also add these three new params to the `mlp_help` function.

##########
File path: doc/design/modules/neural-network.tex
##########
@@ -142,6 +142,31 @@ \subsubsection{Backpropagation}
 \end{aligned}\]
 where $u$ is the coefficient vector, $v$ is the velocity vector, $\mu$ is the momentum value, $\eta$ is the learning rate and $\frac{\partial f}{\partial ua_{k-1}^{sj}}$ is the gradient calculated at the updated position $ua$
 
+\subsection{Adaptive Learning Rates}
+
+
+\paragraph{RMSprop.}
+RMSprop is an unpublished optimization algorithm, first proposed by Geoff Hinton \cite{rmsprop_hinton}. RMSprop works by keeping a moving average of the squared gradient for each weight.
+
+\[\begin{aligned}

Review comment:
       We should also mention the epsilon constant for rmsprop.
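A possible formulation for the design doc, in its `\[\begin{aligned}` style. The symbols are suggestions, not taken from the doc: $g_t$ is the mini-batch gradient, $s_t$ the moving average of squared gradients, $\rho$ its decay factor and $\epsilon$ the stability constant; squares and divisions are element-wise. The first update places $\epsilon$ inside the square root, as the current code does, and the second places it outside, as Keras does:

```latex
\[\begin{aligned}
s_t &= \rho\, s_{t-1} + (1 - \rho)\, g_t^2 \\
u_t &= u_{t-1} - \frac{\eta}{\sqrt{s_t + \epsilon}}\, g_t
  \qquad \text{($\epsilon$ inside the square root)} \\
u_t &= u_{t-1} - \frac{\eta}{\sqrt{s_t} + \epsilon}\, g_t
  \qquad \text{($\epsilon$ outside the square root)}
\end{aligned}\]
```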

##########
File path: src/ports/postgres/modules/convex/test/mlp.sql_in
##########
@@ -1538,3 +1539,67 @@ SELECT mlp_predict(
     'id',
     'mlp_prediction_batch_output',
     'output');
+
+-- Test rmsprop
+DROP TABLE IF EXISTS mlp_class_batch, mlp_class_batch_summary, mlp_class_batch_standardization;
+
+SELECT mlp_classification(
+    'iris_data_batch',    -- Source table
+    'mlp_class_batch',    -- Destination table
+    'independent_varname',   -- Input features
+    'dependent_varname',        -- Label
+    ARRAY[5],   -- Number of units per layer
+    'learning_rate_init=0.1,
+    learning_rate_policy=constant,
+    n_iterations=5,
+    tolerance=0,
+    batch_size=5,
+    n_epochs=15,
+    gamma=0.9,

Review comment:
       Did we mean `beta` instead of `gamma`?

##########
File path: src/modules/convex/task/mlp.hpp
##########
@@ -218,6 +237,86 @@ MLP<Model, Tuple>::getLossAndUpdateModel(
     return total_loss;
 }
 
+template <class Model, class Tuple>
+double
+MLP<Model, Tuple>::getLossAndUpdateModelALR(
+        model_type              &model,
+        const Matrix            &x_batch,
+        const Matrix            &y_true_batch,
+        const double            &stepsize,
+        const int               &opt_code,
+        const double            &beta,
+        std::vector<Matrix>     &sqrs,
+        const double            &beta1,
+        const double            &beta2,
+        std::vector<Matrix>     &vs,
+        const int               &t) {
+
+    double total_loss = 0.;
+    const double EPS_STABLE = 1.e-18;
+
+    // initialize gradient vector
+    std::vector<Matrix> total_gradient_per_layer(model.num_layers);
+    Matrix g, v_bias_corr, sqr_bias_corr;
+    for (Index k=0; k < model.num_layers; ++k) {
+        total_gradient_per_layer[k] = Matrix::Zero(model.u[k].rows(),
+                                                   model.u[k].cols());
+    }
+
+    std::vector<ColumnVector> net, o, delta;
+    Index num_rows_in_batch = x_batch.rows();
+
+    for (Index i=0; i < num_rows_in_batch; i++){
+        // gradient and loss added over the batch
+        ColumnVector x = x_batch.row(i);
+        ColumnVector y_true = y_true_batch.row(i);
+
+        feedForward(model, x, net, o);
+        backPropogate(y_true, o.back(), net, model, delta);
+
+        // compute the gradient
+        for (Index k=0; k < model.num_layers; k++){
+                total_gradient_per_layer[k] += o[k] * delta[k].transpose();
+        }
+
+        // compute the loss
+        total_loss += getLoss(y_true, o.back(), model.is_classification);
+    }
+
+    for (Index k=0; k < model.num_layers; k++){
+        // convert gradient to a gradient update vector
+        //  1. normalize to per row update
+        //  2. discount by stepsize
+        //  3. add regularization
+        Matrix regularization = MLP<Model, Tuple>::lambda * model.u[k];
+        regularization.row(0).setZero(); // Do not update bias
+
+        g = total_gradient_per_layer[k] / static_cast<double>(num_rows_in_batch);
+        if (opt_code == IS_RMSPROP){
+
+            sqrs[k] = beta * sqrs[k] + (1.0 - beta) * square(g);
+            total_gradient_per_layer[k] = (-stepsize * g).array() /
+                                          (sqrs[k].array() + EPS_STABLE).sqrt();
+        }
+        else if (opt_code == IS_ADAM){
+
+            vs[k] = beta1 * vs[k] + (1.0-beta1) * g;

Review comment:
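       Instead of naming the variables `vs` and `sqrs`, should we call them `m` and `v` to match the formulas cited in the design doc and other literature? (`vs` holds the moving average of the gradient, i.e. `m`, and `sqrs` the moving average of the squared gradient, i.e. `v`.)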
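For reference, a scalar sketch of Adam with the `m`/`v` naming used by Kingma and Ba. The mapping to the PR's `vs`/`sqrs` is noted in the docstring; epsilon sits outside the square root as in the paper, and the defaults are the paper's suggested values. `adam_step` is an illustrative name:

```python
import math


def adam_step(w, g, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One scalar Adam step.

    m -- moving average of the gradient (the PR's `vs`)
    v -- moving average of the squared gradient (the PR's `sqrs`)
    t -- 1-based step counter used for bias correction
    Returns the updated (w, m, v).
    """
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g
    m_hat = m / (1.0 - beta1 ** t)  # bias-corrected first moment
    v_hat = v / (1.0 - beta2 ** t)  # bias-corrected second moment
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v
```

On the first step, bias correction makes the update magnitude about `lr` regardless of the gradient's scale.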

##########
File path: src/ports/postgres/modules/convex/test/mlp.sql_in
##########
@@ -1538,3 +1539,67 @@ SELECT mlp_predict(
     'id',
     'mlp_prediction_batch_output',
     'output');
+
+-- Test rmsprop
+DROP TABLE IF EXISTS mlp_class_batch, mlp_class_batch_summary, mlp_class_batch_standardization;
+
+SELECT mlp_classification(
+    'iris_data_batch',    -- Source table
+    'mlp_class_batch',    -- Destination table
+    'independent_varname',   -- Input features
+    'dependent_varname',        -- Label
+    ARRAY[5],   -- Number of units per layer
+    'learning_rate_init=0.1,
+    learning_rate_policy=constant,
+    n_iterations=5,
+    tolerance=0,
+    batch_size=5,
+    n_epochs=15,
+    gamma=0.9,
+    solver=rmsprop',
+    'sigmoid',
+    '',
+    FALSE,           -- Warm start
+    FALSE
+);

Review comment:
       Should we also test the summary table contents?

##########
File path: src/modules/convex/task/mlp.hpp
##########
@@ -218,6 +237,86 @@ MLP<Model, Tuple>::getLossAndUpdateModel(
     return total_loss;
 }
 
+template <class Model, class Tuple>
+double
+MLP<Model, Tuple>::getLossAndUpdateModelALR(
+        model_type              &model,
+        const Matrix            &x_batch,
+        const Matrix            &y_true_batch,
+        const double            &stepsize,
+        const int               &opt_code,
+        const double            &gamma,
+        std::vector<Matrix>     &sqrs,
+        const double            &beta1,
+        const double            &beta2,
+        std::vector<Matrix>     &vs,
+        const int               &t) {
+
+    double total_loss = 0.;
+    const double EPS_STABLE = 1.e-18;

Review comment:
       Also, I think we should allow users to pass in their own epsilon value, as keras/tf and mxnet do. What do you think?
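A sketch of how a user-supplied epsilon could flow through the optimizer parameter string, in the style of the `(default, type)` table in `_get_optimizer_params`. `DEFAULT_PARAMS`, `parse_optimizer_params` and the `rho`/`eps` names are hypothetical, and the defaults shown are placeholders (1e-8 matches mxnet):

```python
# Hypothetical name -> (default value, type) table; 'rho' and 'eps'
# are the names proposed in this review, not existing MADlib params.
DEFAULT_PARAMS = {
    "rho": (0.9, float),
    "beta1": (0.9, float),
    "beta2": (0.999, float),
    "eps": (1e-8, float),
}


def parse_optimizer_params(param_str, defaults=DEFAULT_PARAMS):
    """Parse 'name=value, name=value' into a dict, filling in defaults
    for anything the user did not pass. Unknown names raise."""
    params = {name: default for name, (default, _) in defaults.items()}
    for item in filter(None, (p.strip() for p in param_str.split(","))):
        name, sep, value = item.partition("=")
        name = name.strip()
        if not sep or name not in defaults:
            raise ValueError("Invalid optimizer parameter: '{0}'".format(item))
        # Convert with the type registered for this parameter.
        params[name] = defaults[name][1](value.strip())
    return params
```

Exposing a single `eps` only stays unambiguous if adam and rmsprop place epsilon the same way, which is another argument for the first option in the epsilon-placement discussion.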




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

