GitHub user kaknikhil commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/243#discussion_r175893520
  
    --- Diff: src/ports/postgres/modules/convex/mlp_igd.py_in ---
    @@ -1457,3 +1660,85 @@ def mlp_predict_help(schema_madlib, message):
         return """
        No such option. Use "SELECT {schema_madlib}.mlp_predict()" for help.
         """.format(**args)
    +
    +
    +def check_if_minibatch_enabled(source_table, independent_varname):
    +    """
    +        Validate whether source_table has been converted to a format
    +        that can be used for mini-batching, by checking the
    +        dimensionality of the independent variable.
    +    """
    +    query = """
    +        SELECT array_upper({0}, 1) AS n_x,
    +               array_upper({0}, 2) AS n_y,
    +               array_upper({0}, 3) AS n_z
    +        FROM {1}
    +        LIMIT 1
    +    """.format(independent_varname, source_table)
    +    result = plpy.execute(query)
    +
    +    if not result:
    +        plpy.error("MLP: Input table could be empty.")
    +
    +    has_x_dim, has_y_dim, has_z_dim = [bool(result[0][i])
    +                                       for i in ('n_x', 'n_y', 'n_z')]
    +    if not has_x_dim:
    +        plpy.error("MLP: {0} is empty.".format(independent_varname))
    +
    +    # error out if >2d matrix
    +    if has_z_dim:
    +        plpy.error("MLP: Input table is not in the right format.")
    +    return has_y_dim
    +
    +
    +class MLPPreProcessor:
    +    """
    +    This class consumes and validates the pre-processed source table used
    +    for MLP mini-batch. It also populates values from the pre-processed
    +    summary table, which is used by MLP mini-batch.
    +
    +    """
    +    # summary table column names
    +    DEPENDENT_VARNAME = "dependent_varname"
    +    INDEPENDENT_VARNAME = "independent_varname"
    +    GROUPING_COL = "grouping_cols"
    +    CLASS_VALUES = "class_values"
    +    MODEL_TYPE_CLASSIFICATION = "classification"
    +    MODEL_TYPE_REGRESSION = "regression"
    +
    +    def __init__(self, source_table):
    +        self.source_table = source_table
    +        self.preprocessed_summary_dict = None
    +        self.summary_table = add_postfix(self.source_table, "_summary")
    +        self.std_table = add_postfix(self.source_table, "_standardization")
    +
    +        self._validate_and_set_preprocessed_summary()
    +
    +    def _validate_and_set_preprocessed_summary(self):
    +        input_tbl_valid(self.source_table, 'MLP')
    +
    +        if not table_exists(self.summary_table) or not table_exists(self.std_table):
    +            plpy.error("Tables {0} and/or {1} do not exist. These tables are"
    +                       " needed for using minibatch during training.".format(
    +                           self.summary_table, self.std_table))
    +
    +        query = "SELECT * FROM {0}".format(self.summary_table)
    +        summary_table_columns = plpy.execute(query)
    +        if not summary_table_columns or len(summary_table_columns) == 0:
    +            plpy.error("No columns in table {0}.".format(self.summary_table))
    +        else:
    +            summary_table_columns = summary_table_columns[0]
    +
    +        required_columns = (self.DEPENDENT_VARNAME, self.INDEPENDENT_VARNAME,
    --- End diff --
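
    For readers outside the PR: the check above relies on the mini-batch
    preprocessor packing many rows into a single 2-D array, so a non-NULL
    second dimension marks the mini-batched format. A minimal Python sketch
    of that rule, with `array_upper` as a stand-in for the PostgreSQL
    function of the same name (the stand-in and the sample values are
    illustrative, not from the PR):

    ```python
    # Sketch of the rule check_if_minibatch_enabled encodes. array_upper()
    # is a Python stand-in for PostgreSQL's array_upper(anyarray, int),
    # which returns NULL when the requested dimension does not exist.
    def array_upper(arr, dim):
        for _ in range(dim - 1):
            if not (arr and isinstance(arr[0], list)):
                return None        # requested dimension is absent
            arr = arr[0]
        return len(arr) if arr else None

    row_format = [1.0, 2.0, 3.0]                  # one row per tuple
    minibatch_format = [[1.0, 2.0], [3.0, 4.0]]   # one buffer of packed rows

    assert array_upper(row_format, 2) is None     # no y-dim: not mini-batched
    assert array_upper(minibatch_format, 2) == 2  # y-dim present: mini-batched
    ```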
    
    We also use the `buffer_size` and `source_table` columns from the summary
    table. Do we need to validate those as well, or are these three enough?
    If we decide to assert all the columns that we consume, we will have to
    keep this assert in sync with how we use the summary dict, which is easy
    to forget. I don't have a better solution, but I just wanted to mention it.
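
    One hypothetical way to avoid that drift (a sketch only; the names are
    illustrative and not from the PR): declare the consumed columns once and
    read the summary row through a wrapper that rejects undeclared keys, so
    the validation and the consumers share a single source of truth:

    ```python
    # All columns the code consumes, declared exactly once.
    REQUIRED_SUMMARY_COLS = (
        "dependent_varname", "independent_varname", "class_values",
        "buffer_size", "source_table",
    )

    class SummaryRow(object):
        """Read-only view over one summary-table row; only declared keys pass."""
        def __init__(self, row):
            missing = [c for c in REQUIRED_SUMMARY_COLS if c not in row]
            if missing:
                raise ValueError("summary table missing columns: %s" % missing)
            self._row = row

        def __getitem__(self, key):
            if key not in REQUIRED_SUMMARY_COLS:
                raise KeyError("column '%s' was never validated" % key)
            return self._row[key]

    # summary = SummaryRow(plpy.execute(query)[0])
    # summary["buffer_size"] works; summary["typo"] fails loudly at the use site.
    ```

    That keeps the assert and the usage from diverging, at the cost of one
    level of indirection.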

