This is an automated email from the ASF dual-hosted git repository. okislal pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
The following commit(s) were added to refs/heads/master by this push: new 5f10bc8 DL: Modify multi-fit warm start to accept non-matching mst&model tables 5f10bc8 is described below commit 5f10bc8e72e88986cd109745dddec672fdaa1d84 Author: Orhan Kislal <okis...@apache.org> AuthorDate: Tue Jan 7 19:36:34 2020 -0500 DL: Modify multi-fit warm start to accept non-matching mst&model tables JIRA: MADLIB-1400 #resolve The warm start enforced that the model table had to have a tuple for each mst_key in the mst table for warm start. This commit relaxes this requirement so that users can add as well as subtract mst keys throughout their AutoML progress. Closes #466 --- .../madlib_keras_fit_multiple_model.py_in | 70 ++++++++++++++-------- .../deep_learning/madlib_keras_validator.py_in | 7 --- .../test/madlib_keras_transfer_learning.sql_in | 24 +++++--- 3 files changed, 60 insertions(+), 41 deletions(-) diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in index 5ce555a..273321e 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in @@ -162,8 +162,8 @@ class FitMultipleModel(): random.shuffle(self.msts_for_schedule) self.grand_schedule = self.generate_schedule(self.msts_for_schedule) self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME - if not self.warm_start: - self.create_model_output_table() + + self.create_model_output_table() self.weights_to_update_tbl = unique_string(desp='weights_to_update') self.fit_multiple_model() reset_cuda_env(original_cuda_env) @@ -274,12 +274,26 @@ class FitMultipleModel(): plpy.execute(mst_insert_query) def create_model_output_table(self): - output_table_create_query = """ - CREATE TABLE {self.model_output_table} - ({self.mst_key_col} INTEGER PRIMARY KEY, - {self.model_weights_col} 
BYTEA, - {self.model_arch_col} JSON) - """.format(self=self) + warm_start_msts = [] + if self.warm_start: + plpy.execute(""" DELETE FROM {self.model_output_table} + WHERE {self.mst_key_col} NOT IN ( + SELECT {self.mst_key_col} FROM {self.model_selection_table}) + """.format(self=self)) + warm_start_msts = plpy.execute( + """ SELECT array_agg({0}) AS a FROM {1} + """.format(self.mst_key_col, self.model_output_table))[0]['a'] + plpy.execute("DROP TABLE {0}".format(self.model_info_table)) + + else: + output_table_create_query = """ + CREATE TABLE {self.model_output_table} + ({self.mst_key_col} INTEGER PRIMARY KEY, + {self.model_weights_col} BYTEA, + {self.model_arch_col} JSON) + """.format(self=self) + plpy.execute(output_table_create_query) + info_table_create_query = """ CREATE TABLE {self.model_info_table} ({self.mst_key_col} INTEGER PRIMARY KEY, @@ -300,39 +314,32 @@ class FitMultipleModel(): validation_loss DOUBLE PRECISION[]) """.format(self=self) - plpy.execute(output_table_create_query) plpy.execute(info_table_create_query) for mst in self.msts: model_arch, model_weights = get_model_arch_weights(self.model_arch_table, mst[self.model_id_col]) + + + # If warm start is enabled, weights from transfer learning cannot be + # used, even if a particular model doesn't have warm start weigths. 
+ if self.warm_start: + model_weights = None + serialized_weights = get_initial_weights(self.model_output_table, model_arch, model_weights, - False, + mst['mst_key'] in warm_start_msts, self.use_gpus, self.accessible_gpus_for_seg ) - model = model_from_json(model_arch) - serialized_state = model_weights if model_weights else \ - madlib_keras_serializer.serialize_nd_weights(model.get_weights()) - model_size = sys.getsizeof(serialized_weights) / 1024.0 + metrics_list = get_metrics_from_compile_param( mst[self.compile_params_col]) is_metrics_specified = True if metrics_list else False metrics_type = 'ARRAY{0}'.format( metrics_list) if is_metrics_specified else 'NULL' - output_table_insert_query = """ - INSERT INTO {self.model_output_table}( - {self.mst_key_col}, {self.model_weights_col}, - {self.model_arch_col}) - VALUES ({mst_key}, $1, $2) - """.format(self=self, - mst_key=mst[self.mst_key_col]) - output_table_insert_query_prepared = plpy.prepare( - output_table_insert_query, ["bytea", "json"]) - plpy.execute(output_table_insert_query_prepared, [ - serialized_state, model_arch]) + info_table_insert_query = """ INSERT INTO {self.model_info_table}({self.mst_key_col}, {self.model_id_col}, {self.compile_params_col}, @@ -352,6 +359,19 @@ class FitMultipleModel(): metrics_type=metrics_type) plpy.execute(info_table_insert_query) + if not mst['mst_key'] in warm_start_msts: + output_table_insert_query = """ + INSERT INTO {self.model_output_table}( + {self.mst_key_col}, {self.model_weights_col}, + {self.model_arch_col}) + VALUES ({mst_key}, $1, $2) + """.format(self=self, + mst_key=mst[self.mst_key_col]) + output_table_insert_query_prepared = plpy.prepare( + output_table_insert_query, ["bytea", "json"]) + plpy.execute(output_table_insert_query_prepared, [ + serialized_weights, model_arch]) + def create_model_summary_table(self): if self.warm_start: plpy.execute("DROP TABLE {0}".format(self.model_summary_table)) diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in index ad14087..37a2e25 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in @@ -437,13 +437,6 @@ class FitMultipleInputValidator(FitCommonValidator): accessible_gpus_for_seg, self.module_name) - if warm_start: - mst_count = plpy.execute("SELECT count(*) FROM {0}".format(model_selection_table))[0]['count'] - warm_count = plpy.execute("SELECT count(*) FROM {0}".format(output_model_table))[0]['count'] - - _assert(mst_count <= warm_count, - "{self.module_name} error: Model table and mst table do not match".format(self=self)) - class MstLoaderInputValidator(): def __init__(self, model_arch_table, diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in index 3c970a5..d17ea20 100644 --- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in +++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in @@ -224,20 +224,18 @@ SELECT madlib_keras_fit_multiple_model( FALSE, NULL, 1, TRUE); - SELECT assert( - abs(first.training_loss_final-second.training_loss_final) < 1e-6, - 'The loss should not change for mst_key 4 since it has been removed from mst_table') -FROM iris_model_first_run AS first, iris_multiple_model_info AS second -WHERE first.mst_key = second.mst_key AND second.mst_key = 4; + 4 NOT IN (SELECT mst_key FROM iris_multiple_model), + 'mst_key 4 should not be in the model table since it has been removed from mst_table'); -INSERT INTO mst_table SELECT 4 AS mst_key, model_id, compile_params, - 'batch_size=8, epochs=1' FROM mst_table WHERE mst_key = 1; +SELECT assert( + 4 NOT IN (SELECT mst_key FROM iris_multiple_model_info), + 'mst_key 4 
should not be in the info table since it has been removed from mst_table'); INSERT INTO mst_table SELECT 5 AS mst_key, model_id, compile_params, 'batch_size=18, epochs=1' FROM mst_table WHERE mst_key = 1; -SELECT assert(trap_error($TRAP$madlib_keras_fit_multiple_model( +SELECT madlib_keras_fit_multiple_model( 'iris_data_packed', 'iris_multiple_model', 'mst_table', @@ -245,7 +243,15 @@ SELECT assert(trap_error($TRAP$madlib_keras_fit_multiple_model( FALSE, NULL, 1, TRUE -- warm_start -);$TRAP$) = 1, 'Warm start with extra mst keys should fail.'); +); + +SELECT assert( + 5 IN (SELECT mst_key FROM iris_multiple_model), + 'mst_key 5 should be in the model table since it has been added to mst_table'); + +SELECT assert( + 5 IN (SELECT mst_key FROM iris_multiple_model_info), + 'mst_key 5 should be in the info table since it has been added to mst_table'); -- Transfer learning tests