Github user kaknikhil commented on a diff in the pull request: https://github.com/apache/madlib/pull/243#discussion_r175893520 --- Diff: src/ports/postgres/modules/convex/mlp_igd.py_in --- @@ -1457,3 +1660,85 @@ def mlp_predict_help(schema_madlib, message): return """ No such option. Use "SELECT {schema_madlib}.mlp_predict()" for help. """.format(**args) + + +def check_if_minibatch_enabled(source_table, independent_varname): + """ + Function to validate if the source_table is converted to a format that + can be used for mini-batching. It checks for the dimensionalities of + the independent variable to determine the same. + """ + query = """ + SELECT array_upper({0}, 1) AS n_x, + array_upper({0}, 2) AS n_y, + array_upper({0}, 3) AS n_z + FROM {1} + LIMIT 1 + """.format(independent_varname, source_table) + result = plpy.execute(query) + + if not result: + plpy.error("MLP: Input table could be empty.") + + has_x_dim, has_y_dim, has_z_dim = [bool(result[0][i]) + for i in ('n_x', 'n_y', 'n_z')] + if not has_x_dim: + plpy.error("MLP: {0} is empty.".format(independent_varname)) + + # error out if >2d matrix + if has_z_dim: + plpy.error("MLP: Input table is not in the right format.") + return has_y_dim + + +class MLPPreProcessor: + """ + This class consumes and validates the pre-processed source table used for + MLP mini-batch. 
This also populates values from the pre-processed summary + table which is used by MLP mini-batch + + """ + # summary table columns names + DEPENDENT_VARNAME = "dependent_varname" + INDEPENDENT_VARNAME = "independent_varname" + GROUPING_COL = "grouping_cols" + CLASS_VALUES = "class_values" + MODEL_TYPE_CLASSIFICATION = "classification" + MODEL_TYPE_REGRESSION = "regression" + + def __init__(self, source_table): + self.source_table = source_table + self.preprocessed_summary_dict = None + self.summary_table = add_postfix(self.source_table, "_summary") + self.std_table = add_postfix(self.source_table, "_standardization") + + self._validate_and_set_preprocessed_summary() + + def _validate_and_set_preprocessed_summary(self): + input_tbl_valid(self.source_table, 'MLP') + + if not table_exists(self.summary_table) or not table_exists(self.std_table): + plpy.error("Tables {0} and/or {1} do not exist. These tables are" + " needed for using minibatch during training.".format( + self.summary_table, + self.std_table)) + + query = "SELECT * FROM {0}".format(self.summary_table) + summary_table_columns = plpy.execute(query) + if not summary_table_columns or len(summary_table_columns) == 0: + plpy.error("No columns in table {0}.".format(self.summary_table)) + else: + summary_table_columns = summary_table_columns[0] + + required_columns = (self.DEPENDENT_VARNAME, self.INDEPENDENT_VARNAME, --- End diff -- we also use `buffer_size` and `source_table` columns from the summary table. Do we need to validate those as well, or are these three enough? If we decide to assert all columns that we consume, we will have to keep this assert in sync with how we use the summary dict, which is easy to forget. I don't have a better solution but just wanted to mention it.
---