reductionista commented on a change in pull request #525: URL: https://github.com/apache/madlib/pull/525#discussion_r537800172
########## File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in ########## @@ -196,72 +195,110 @@ class FitMultipleModel(): self.dist_key_mapping_valid, self.images_per_seg_valid = \ get_image_count_per_seg_for_minibatched_data_from_db( self.validation_table) - self.mst_weights_tbl = unique_string(desp='mst_weights') - self.mst_current_schedule_tbl = unique_string(desp='mst_current_schedule') + self.model_input_tbl = unique_string(desp='model_input') + self.schedule_tbl = unique_string(desp='schedule') - self.dist_keys = query_dist_keys(self.source_table, dist_key_col) - if len(self.msts) < len(self.dist_keys): + self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col) + DEBUG.plpy.info("init_dist_keys = {0}".format(self.dist_keys)) + self.max_dist_key = sorted(self.dist_keys)[-1] + DEBUG.plpy.info("sorted_dist_keys = {0}".format(sorted(self.dist_keys))) + DEBUG.plpy.info("max_dist_key = {0}".format(self.max_dist_key)) + self.extra_dist_keys = [] + + num_msts = len(self.msts) + num_dist_keys = len(self.dist_keys) + + if num_msts < num_dist_keys: self.msts_for_schedule = self.msts + [None] * \ - (len(self.dist_keys) - len(self.msts)) + (num_dist_keys - num_msts) else: self.msts_for_schedule = self.msts + if num_msts > num_dist_keys: + for i in range(num_msts - num_dist_keys): + self.extra_dist_keys.append(self.max_dist_key + 1 + i) + + DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys)) + DEBUG.plpy.info('extra_dist_keys : {}'.format(self.extra_dist_keys)) + random.shuffle(self.msts_for_schedule) - self.grand_schedule = self.generate_schedule(self.msts_for_schedule) - self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME - self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else '' - if self.warm_start: - self.create_model_output_table_warm_start() - else: - self.create_model_output_table() + # Comma-separated list of the mst_keys, including NULL's + # This will be used to pass the 
mst keys to the db as + # a sql ARRAY[] + self.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\ + for mst in self.msts_for_schedule ] - self.weights_to_update_tbl = unique_string(desp='weights_to_update') - self.fit_multiple_model() + # List of all dist_keys, including any extra dist keys beyond + # the # segments we'll be training on--these represent the + # segments models will rest on while not training, which + # may overlap with the ones that will have training on them. + self.all_dist_keys = self.dist_keys + self.extra_dist_keys - # Update and cleanup metadata tables - self.insert_info_table() - self.create_model_summary_table() - if self.warm_start: - self.cleanup_for_warm_start() - reset_cuda_env(original_cuda_env) + self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME + self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else '' def fit_multiple_model(self): + self.init_schedule_tbl() + self.init_model_output_tbl() + self.init_model_info_tbl() + # WARNING: set orca off to prevent unwanted redistribution with OptimizerControl(False): self.start_training_time = datetime.datetime.now() self.metrics_elapsed_start_time = time.time() self.train_multiple_model() self.end_training_time = datetime.datetime.now() - def cleanup_for_warm_start(self): + # Update and cleanup metadata tables + self.insert_info_table() + self.create_model_summary_table() + self.write_final_model_output_tbl() + reset_cuda_env(self.original_cuda_env) + + def write_final_model_output_tbl(self): """ - 1. drop original model table + 1. drop original model table if exists 2. 
rename temp to original :return: """ - drop_query = "DROP TABLE IF EXISTS {}".format( - self.original_model_output_table) - plpy.execute(drop_query) - rename_table(self.schema_madlib, self.model_output_table, - self.original_model_output_table) + final_output_table_create_query = """ + DROP TABLE IF EXISTS {self.original_model_output_tbl}; + CREATE TABLE {self.original_model_output_tbl} AS
Review comment: I'd consider changing it if there's an argument for the RENAME being faster, but I'm not sure it would be... I imagine the VACUUM might take just as long. Either way, there shouldn't be any motion across segments since the DISTRIBUTED BY is the same on both.
----------------------------------------------------------------
This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org