reductionista commented on a change in pull request #525: URL: https://github.com/apache/madlib/pull/525#discussion_r537797983
########## File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in ########## @@ -196,72 +195,110 @@ class FitMultipleModel(): self.dist_key_mapping_valid, self.images_per_seg_valid = \ get_image_count_per_seg_for_minibatched_data_from_db( self.validation_table) - self.mst_weights_tbl = unique_string(desp='mst_weights') - self.mst_current_schedule_tbl = unique_string(desp='mst_current_schedule') + self.model_input_tbl = unique_string(desp='model_input') + self.schedule_tbl = unique_string(desp='schedule') - self.dist_keys = query_dist_keys(self.source_table, dist_key_col) - if len(self.msts) < len(self.dist_keys): + self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col) + DEBUG.plpy.info("init_dist_keys = {0}".format(self.dist_keys)) + self.max_dist_key = sorted(self.dist_keys)[-1] + DEBUG.plpy.info("sorted_dist_keys = {0}".format(sorted(self.dist_keys))) + DEBUG.plpy.info("max_dist_key = {0}".format(self.max_dist_key)) + self.extra_dist_keys = [] + + num_msts = len(self.msts) + num_dist_keys = len(self.dist_keys) + + if num_msts < num_dist_keys: self.msts_for_schedule = self.msts + [None] * \ - (len(self.dist_keys) - len(self.msts)) + (num_dist_keys - num_msts) else: self.msts_for_schedule = self.msts + if num_msts > num_dist_keys: + for i in range(num_msts - num_dist_keys): + self.extra_dist_keys.append(self.max_dist_key + 1 + i) + + DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys)) + DEBUG.plpy.info('extra_dist_keys : {}'.format(self.extra_dist_keys)) + random.shuffle(self.msts_for_schedule) - self.grand_schedule = self.generate_schedule(self.msts_for_schedule) - self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME - self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else '' - if self.warm_start: - self.create_model_output_table_warm_start() - else: - self.create_model_output_table() + # Comma-separated list of the mst_keys, including NULL's + # This will be used to pass the 
mst keys to the db as + # a sql ARRAY[] + self.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\ + for mst in self.msts_for_schedule ] - self.weights_to_update_tbl = unique_string(desp='weights_to_update') - self.fit_multiple_model() + # List of all dist_keys, including any extra dist keys beyond + # the # segments we'll be training on--these represent the + # segments models will rest on while not training, which + # may overlap with the ones that will have training on them. + self.all_dist_keys = self.dist_keys + self.extra_dist_keys - # Update and cleanup metadata tables - self.insert_info_table() - self.create_model_summary_table() - if self.warm_start: - self.cleanup_for_warm_start() - reset_cuda_env(original_cuda_env) + self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME + self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else '' def fit_multiple_model(self): + self.init_schedule_tbl() + self.init_model_output_tbl() + self.init_model_info_tbl() + # WARNING: set orca off to prevent unwanted redistribution with OptimizerControl(False): self.start_training_time = datetime.datetime.now() self.metrics_elapsed_start_time = time.time() self.train_multiple_model() self.end_training_time = datetime.datetime.now() - def cleanup_for_warm_start(self): + # Update and cleanup metadata tables + self.insert_info_table() + self.create_model_summary_table() + self.write_final_model_output_tbl() + reset_cuda_env(self.original_cuda_env) + + def write_final_model_output_tbl(self): """ - 1. drop original model table + 1. drop original model table if exists 2. 
rename temp to original :return: """ - drop_query = "DROP TABLE IF EXISTS {}".format( - self.original_model_output_table) - plpy.execute(drop_query) - rename_table(self.schema_madlib, self.model_output_table, - self.original_model_output_table) + final_output_table_create_query = """ + DROP TABLE IF EXISTS {self.original_model_output_tbl}; + CREATE TABLE {self.original_model_output_tbl} AS Review comment: I did try that, but it gets a bit messy. There are 3 differences between the schemas: 1. The temporary model table has 3 extra columns: compile_params, fit_params, and object_map 2. The temporary model table has a PRIMARY KEY constraint (a combination of __dist_key__, mst_key), which means it also has a SEQUENCE relation that exists in the same schema with a similar name. 3. The temporary table is an unlogged table, while presumably the user expects the output to be an ordinary logged table. The first we could fix by DROP'ing the 3 extra columns after the ALTER TABLE RENAME. But even then, there's still a minor issue with it: the user is left with a table which is probably fragmented on disk in a weird way... with the extra columns probably still there but marked as deleted. So to clean that up, we'd probably want to run VACUUM or something on it, to have it consolidate the table a bit so it's not taking up a larger footprint than it needs to. The second one isn't too bad, but also requires extra work: if you do the table rename, it doesn't rename the SEQUENCE, and you get errors when you try to access it. So you either have to do another RENAME on the SEQUENCE, or just DROP it. The third one I see as the most problematic, unless there is an easy way to change it from UNLOGGED to LOGGED. Seems like an undesirable outcome the user might not expect. I didn't get as far as looking up how to fix the 3rd one, since at that point, it just seemed simpler and cleaner to DROP the temp table and create a nice fresh table as output that we know isn't going to have any issues. 
---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org