GitHub user kaknikhil commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/243#discussion_r175895832
  
    --- Diff: src/ports/postgres/modules/convex/mlp_igd.py_in ---
    @@ -72,107 +73,127 @@ def mlp(schema_madlib, source_table, output_table, independent_varname,
         """
         warm_start = bool(warm_start)
         optimizer_params = _get_optimizer_params(optimizer_param_str or "")
    +
    +    tolerance = optimizer_params["tolerance"]
    +    n_iterations = optimizer_params["n_iterations"]
    +    step_size_init = optimizer_params["learning_rate_init"]
    +    iterations_per_step = optimizer_params["iterations_per_step"]
    +    power = optimizer_params["power"]
    +    gamma = optimizer_params["gamma"]
    +    step_size = step_size_init
    +    n_tries = optimizer_params["n_tries"]
    +    # lambda is a reserved word in python
    +    lmbda = optimizer_params["lambda"]
    +    batch_size = optimizer_params['batch_size']
    +    n_epochs = optimizer_params['n_epochs']
    +
         summary_table = add_postfix(output_table, "_summary")
         standardization_table = add_postfix(output_table, "_standardization")
    -    weights = '1' if not weights or not weights.strip() else weights.strip()
         hidden_layer_sizes = hidden_layer_sizes or []
     
    -    grouping_col = grouping_col or ""
    -    activation = _get_activation_function_name(activation)
    -    learning_rate_policy = _get_learning_rate_policy_name(
    -        optimizer_params["learning_rate_policy"])
    -    activation_index = _get_activation_index(activation)
    -
    +    # Note that we don't support weights with mini batching yet, so validate
    +    # this based on is_minibatch_enabled.
    +    weights = '1' if not weights or not weights.strip() else weights.strip()
         _validate_args(source_table, output_table, summary_table,
                        standardization_table, independent_varname,
                        dependent_varname, hidden_layer_sizes, optimizer_params,
    -                   is_classification, weights, warm_start, activation,
    -                   grouping_col)
    +                   warm_start, activation, grouping_col)
    +    is_minibatch_enabled = check_if_minibatch_enabled(source_table, independent_varname)
    +    _validate_params_based_on_minibatch(source_table, independent_varname,
    +                                        dependent_varname, weights,
    +                                        is_classification,
    +                                        is_minibatch_enabled)
    +    activation = _get_activation_function_name(activation)
    +    learning_rate_policy = _get_learning_rate_policy_name(
    +                                optimizer_params["learning_rate_policy"])
    +    activation_index = _get_activation_index(activation)
     
         reserved_cols = ['coeff', 'loss', 'n_iterations']
    +    grouping_col = grouping_col or ""
         grouping_str, grouping_col = get_grouping_col_str(schema_madlib, 'MLP',
                                                           reserved_cols,
                                                           source_table,
                                                           grouping_col)
    -    current_iteration = 1
    -    prev_state = None
    -    tolerance = optimizer_params["tolerance"]
    -    n_iterations = optimizer_params["n_iterations"]
    -    step_size_init = optimizer_params["learning_rate_init"]
    -    iterations_per_step = optimizer_params["iterations_per_step"]
    -    power = optimizer_params["power"]
    -    gamma = optimizer_params["gamma"]
    -    step_size = step_size_init
    -    n_tries = optimizer_params["n_tries"]
    -    # lambda is a reserved word in python
    -    lmbda = optimizer_params["lambda"]
    -    iterations_per_step = optimizer_params["iterations_per_step"]
    -    num_input_nodes = array_col_dimension(source_table,
    -                                          independent_varname)
    -    num_output_nodes = 0
    +    dependent_varname_backup = dependent_varname
    --- End diff --
    
    Can we add a comment explaining why we need this backup variable?
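    For illustration only, below is a sketch of the kind of comment that could make the
    intent clearer. It assumes (to be confirmed by the authors) that the backup exists
    because dependent_varname is later reassigned when mini-batching is enabled, while
    the original user-supplied name is still needed afterwards, e.g. when populating the
    summary table. The function and column names in the sketch are hypothetical, not the
    actual MADlib code.

        # Hypothetical sketch, not the MADlib implementation.
        def build_summary(dependent_varname, is_minibatch_enabled):
            # Keep the user-supplied name before it is potentially overwritten.
            dependent_varname_backup = dependent_varname
            if is_minibatch_enabled:
                # Assumed: with mini-batching, queries run against a packed column
                # produced by the preprocessor, so the variable is reassigned.
                dependent_varname = "__packed_dependent_varname__"  # placeholder
            # The summary should still report the original, user-facing name.
            return {
                "column_used_in_queries": dependent_varname,
                "dependent_varname": dependent_varname_backup,
            }

        print(build_summary("class_label", is_minibatch_enabled=True))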

