Github user kaknikhil commented on a diff in the pull request:
https://github.com/apache/madlib/pull/243#discussion_r175890376
--- Diff: src/ports/postgres/modules/convex/mlp_igd.py_in ---
@@ -1457,3 +1660,85 @@ def mlp_predict_help(schema_madlib, message):
return """
No such option. Use "SELECT {schema_madlib}.mlp_predict()" for
help.
""".format(**args)
+
+
+def check_if_minibatch_enabled(source_table, independent_varname):
+ """
+ Function to validate if the source_table is converted to a format that
+ can be used for mini-batching. It checks for the dimensionalities of
+ the independent variable to determine the same.
+ """
+ query = """
+ SELECT array_upper({0}, 1) AS n_x,
+ array_upper({0}, 2) AS n_y,
+ array_upper({0}, 3) AS n_z
+ FROM {1}
+ LIMIT 1
+ """.format(independent_varname, source_table)
+ result = plpy.execute(query)
+
+ if not result:
+ plpy.error("MLP: Input table could be empty.")
+
+ has_x_dim, has_y_dim, has_z_dim = [bool(result[0][i])
+ for i in ('n_x', 'n_y', 'n_z')]
+ if not has_x_dim:
+ plpy.error("MLP: {0} is empty.".format(independent_varname))
+
+ # error out if >2d matrix
+ if has_z_dim:
+ plpy.error("MLP: Input table is not in the right format.")
+ return has_y_dim
+
+
+class MLPPreProcessor:
+ """
+ This class consumes and validates the pre-processed source table used for
+ MLP mini-batch. This also populates values from the pre-processed summary
+ table which is used by MLP mini-batch
+
+ """
+ # summary table columns names
+ DEPENDENT_VARNAME = "dependent_varname"
+ INDEPENDENT_VARNAME = "independent_varname"
+ GROUPING_COL = "grouping_cols"
+ CLASS_VALUES = "class_values"
+ MODEL_TYPE_CLASSIFICATION = "classification"
+ MODEL_TYPE_REGRESSION = "regression"
+
+ def __init__(self, source_table):
+ self.source_table = source_table
+ self.preprocessed_summary_dict = None
+ self.summary_table = add_postfix(self.source_table, "_summary")
+ self.std_table = add_postfix(self.source_table, "_standardization")
+
+ self._validate_and_set_preprocessed_summary()
+
+ def _validate_and_set_preprocessed_summary(self):
+ input_tbl_valid(self.source_table, 'MLP')
--- End diff --
We don't really need to validate the source table here, since it would
already be validated by the `_validate_args` function.
---