This is an automated email from the ASF dual-hosted git repository.

nkak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
The following commit(s) were added to refs/heads/master by this push: new 374145f DL: Add model_arch column to model data table 374145f is described below commit 374145f2214cb36b5e328f9d98779e91d45c0eaf Author: Ekta Khanna <ekha...@pivotal.io> AuthorDate: Fri May 17 15:35:13 2019 -0700 DL: Add model_arch column to model data table JIRA: MADLIB-1347 This commit adds the 'model_arch' column to the model data table so that predict and evaluate can directly get the model_arch from the model data table instead of the model arch table. We have made changes to the predict code in this commit but the changes for evaluate will be done in a future PR. closes #394 Co-authored-by: Nikhil Kak <n...@pivotal.io> --- .../modules/deep_learning/madlib_keras.py_in | 7 +++-- .../deep_learning/madlib_keras_validator.py_in | 11 ++++++-- .../deep_learning/predict_input_params.py_in | 32 ++++------------------ .../modules/deep_learning/test/madlib_keras.sql_in | 17 +++++++++++- 4 files changed, 34 insertions(+), 33 deletions(-) diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in index 82384d4..087838f 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in @@ -320,9 +320,10 @@ def fit(schema_madlib, source_table, model,model_arch_table, description, aggregate_runtime, class_values]) create_output_table = plpy.prepare(""" - CREATE TABLE {0} AS - SELECT $1 as model_data""".format(model), ["bytea"]) - plpy.execute(create_output_table, [model_state]) + CREATE TABLE {0} AS SELECT + $1 as model_data, + $2 as {1}""".format(model, Format.MODEL_ARCH), ["bytea", "json"]) + plpy.execute(create_output_table, [model_state, model_arch]) if is_platform_pg(): clear_keras_session() diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in index 
9210550..00426dd 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in @@ -18,6 +18,7 @@ # under the License. import plpy +from keras_model_arch_table import Format from madlib_keras_helper import CLASS_VALUES_COLNAME from madlib_keras_helper import COMPILE_PARAMS_COLNAME from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME @@ -107,11 +108,17 @@ class PredictInputValidator: def _validate_model_data_col(self): _assert(is_var_valid(self.model_table, MODEL_DATA_COLNAME), - "{module_name} error: invalid model_data " - "('{model_data}') in model table ({table}).".format( + "{module_name} error: column '{model_data}' " + "does not exist in model table '{table}'.".format( module_name=self.module_name, model_data=MODEL_DATA_COLNAME, table=self.model_table)) + _assert(is_var_valid(self.model_table, Format.MODEL_ARCH), + "{module_name} error: column '{model_arch}' " + "does not exist in model table '{table}'.".format( + module_name=self.module_name, + model_arch=Format.MODEL_ARCH, + table=self.model_table)) def _validate_test_tbl_cols(self): _assert(is_var_valid(self.test_table, self.independent_varname), diff --git a/src/ports/postgres/modules/deep_learning/predict_input_params.py_in b/src/ports/postgres/modules/deep_learning/predict_input_params.py_in index aba6dce..7f88020 100644 --- a/src/ports/postgres/modules/deep_learning/predict_input_params.py_in +++ b/src/ports/postgres/modules/deep_learning/predict_input_params.py_in @@ -25,8 +25,6 @@ from utilities.validate_args import input_tbl_valid from madlib_keras_helper import CLASS_VALUES_COLNAME from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME -from madlib_keras_helper import MODEL_ARCH_ID_COLNAME -from madlib_keras_helper import MODEL_ARCH_TABLE_COLNAME from madlib_keras_helper import MODEL_DATA_COLNAME from madlib_keras_helper import 
NORMALIZING_CONST_COLNAME @@ -36,28 +34,11 @@ class PredictParamsProcessor: self.model_table = model_table self.model_summary_table = add_postfix(self.model_table, '_summary') input_tbl_valid(self.model_summary_table, self.module_name) - self.model_summary_dict = self._get_model_summary_dict() - self.model_arch_dict = self._get_model_arch_dict() + self.model_summary_dict = self._get_dict_for_table(self.model_summary_table) + self.model_arch_dict = self._get_dict_for_table(self.model_table) - def _get_model_summary_dict(self): - return plpy.execute("SELECT * FROM {0}".format( - self.model_summary_table))[0] - - def _get_model_arch_dict(self): - model_arch_table = self.model_summary_dict[MODEL_ARCH_TABLE_COLNAME] - model_arch_id = self.model_summary_dict[MODEL_ARCH_ID_COLNAME] - input_tbl_valid(model_arch_table, self.module_name) - model_arch_query = """ - SELECT {0} - FROM {1} - WHERE {2} = {3} - """.format(Format.MODEL_ARCH, model_arch_table, Format.MODEL_ID, - model_arch_id) - query_result = plpy.execute(model_arch_query) - if not query_result or len(query_result) == 0: - plpy.error("{0}: No model arch found in table {1} with id {2}".format( - self.module_name, model_arch_table, model_arch_id)) - return query_result[0] + def _get_dict_for_table(self, table_name): + return plpy.execute("SELECT * FROM {0}".format(table_name), 1)[0] def get_class_values(self): return self.model_summary_dict[CLASS_VALUES_COLNAME] @@ -72,10 +53,7 @@ class PredictParamsProcessor: return self.model_arch_dict[Format.MODEL_ARCH] def get_model_data(self): - return plpy.execute(""" - SELECT {0} FROM {1} - """.format(MODEL_DATA_COLNAME, self.model_table) - )[0][MODEL_DATA_COLNAME] + return self.model_arch_dict[MODEL_DATA_COLNAME] def get_normalizing_const(self): return self.model_summary_dict[NORMALIZING_CONST_COLNAME] diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in index b5aaa6d..93ca513 100644 
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in +++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in @@ -151,7 +151,10 @@ SELECT assert( 'Keras model output Summary Validation failed. Actual:' || __to_char(summary)) FROM (SELECT * FROM keras_saved_out_summary) summary; -SELECT assert(model_data IS NOT NULL , 'Keras model output validation failed') FROM (SELECT * FROM keras_saved_out) k; +SELECT assert( + model_data IS NOT NULL AND + model_arch IS NOT NULL, 'Keras model output validation failed. Actual:' || __to_char(k)) +FROM (SELECT * FROM keras_saved_out) k; -- Verify number of iterations for which metrics and loss are computed @@ -743,3 +746,15 @@ SELECT assert(trap_error($TRAP$madlib_keras_predict( 'prob', 0);$TRAP$) = 1, 'Input shape is (32, 32, 3) but model was trained with (3, 32, 32). Should have failed.'); + +-- Test model_arch is retrieved from model data table and not model architecture +DROP TABLE IF EXISTS model_arch; +DROP TABLE IF EXISTS cifar10_predict; +SELECT madlib_keras_predict( + 'keras_saved_out', + 'cifar_10_sample_test_shape', + 'id', + 'x', + 'cifar10_predict', + 'prob', + 0);