Github user kaknikhil commented on a diff in the pull request:
https://github.com/apache/madlib/pull/243#discussion_r175893520
--- Diff: src/ports/postgres/modules/convex/mlp_igd.py_in ---
@@ -1457,3 +1660,85 @@ def mlp_predict_help(schema_madlib, message):
return """
No such option. Use "SELECT {schema_madlib}.mlp_predict()" for
help.
""".format(**args)
+
+
def check_if_minibatch_enabled(source_table, independent_varname):
    """
    Detect whether source_table is already in mini-batch format.

    Probes the dimensionality of the independent variable column: a
    2-D array (rows of packed samples) indicates the table has been
    pre-processed for mini-batching, while a 1-D array indicates plain
    per-row samples. Errors out on empty input or arrays of more than
    two dimensions.
    """
    dims_query = """
        SELECT array_upper({0}, 1) AS n_x,
               array_upper({0}, 2) AS n_y,
               array_upper({0}, 3) AS n_z
        FROM {1}
        LIMIT 1
        """.format(independent_varname, source_table)
    rows = plpy.execute(dims_query)

    if not rows:
        plpy.error("MLP: Input table could be empty.")

    first = rows[0]
    has_x_dim = bool(first['n_x'])
    has_y_dim = bool(first['n_y'])
    has_z_dim = bool(first['n_z'])

    if not has_x_dim:
        plpy.error("MLP: {0} is empty.".format(independent_varname))

    # error out if >2d matrix
    if has_z_dim:
        plpy.error("MLP: Input table is not in the right format.")
    return has_y_dim
+
+
+class MLPPreProcessor:
+ """
+ This class consumes and validates the pre-processed source table used
for
+ MLP mini-batch. This also populates values from the pre-processed
summary
+ table which is used by MLP mini-batch
+
+ """
+ # summary table columns names
+ DEPENDENT_VARNAME = "dependent_varname"
+ INDEPENDENT_VARNAME = "independent_varname"
+ GROUPING_COL = "grouping_cols"
+ CLASS_VALUES = "class_values"
+ MODEL_TYPE_CLASSIFICATION = "classification"
+ MODEL_TYPE_REGRESSION = "regression"
+
+ def __init__(self, source_table):
+ self.source_table = source_table
+ self.preprocessed_summary_dict = None
+ self.summary_table = add_postfix(self.source_table, "_summary")
+ self.std_table = add_postfix(self.source_table, "_standardization")
+
+ self._validate_and_set_preprocessed_summary()
+
+ def _validate_and_set_preprocessed_summary(self):
+ input_tbl_valid(self.source_table, 'MLP')
+
+ if not table_exists(self.summary_table) or not
table_exists(self.std_table):
+ plpy.error("Tables {0} and/or {1} do not exist. These tables
are"
+ " needed for using minibatch during
training.".format(
+
self.summary_table,
+
self.std_table))
+
+ query = "SELECT * FROM {0}".format(self.summary_table)
+ summary_table_columns = plpy.execute(query)
+ if not summary_table_columns or len(summary_table_columns) == 0:
+ plpy.error("No columns in table
{0}.".format(self.summary_table))
+ else:
+ summary_table_columns = summary_table_columns[0]
+
+ required_columns = (self.DEPENDENT_VARNAME,
self.INDEPENDENT_VARNAME,
--- End diff --
we also use `buffer_size` and `source_table` columns from the summary
table. Do we need to validate those as well, or are these three enough?
If we decide to assert all columns that we consume, we will have to keep
this assert in sync with how we use the summary dict, which is easy to forget. I
don't have a better solution but just wanted to mention it.
---