reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r537797983
##########
File path:
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -196,72 +195,110 @@ class FitMultipleModel():
self.dist_key_mapping_valid, self.images_per_seg_valid = \
get_image_count_per_seg_for_minibatched_data_from_db(
self.validation_table)
- self.mst_weights_tbl = unique_string(desp='mst_weights')
- self.mst_current_schedule_tbl =
unique_string(desp='mst_current_schedule')
+ self.model_input_tbl = unique_string(desp='model_input')
+ self.schedule_tbl = unique_string(desp='schedule')
- self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
- if len(self.msts) < len(self.dist_keys):
+ self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
+ DEBUG.plpy.info("init_dist_keys = {0}".format(self.dist_keys))
+ self.max_dist_key = sorted(self.dist_keys)[-1]
+ DEBUG.plpy.info("sorted_dist_keys =
{0}".format(sorted(self.dist_keys)))
+ DEBUG.plpy.info("max_dist_key = {0}".format(self.max_dist_key))
+ self.extra_dist_keys = []
+
+ num_msts = len(self.msts)
+ num_dist_keys = len(self.dist_keys)
+
+ if num_msts < num_dist_keys:
self.msts_for_schedule = self.msts + [None] * \
- (len(self.dist_keys) - len(self.msts))
+ (num_dist_keys - num_msts)
else:
self.msts_for_schedule = self.msts
+ if num_msts > num_dist_keys:
+ for i in range(num_msts - num_dist_keys):
+ self.extra_dist_keys.append(self.max_dist_key + 1 + i)
+
+ DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys))
+ DEBUG.plpy.info('extra_dist_keys : {}'.format(self.extra_dist_keys))
+
random.shuffle(self.msts_for_schedule)
- self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
- self.gp_segment_id_col = '0' if is_platform_pg() else
GP_SEGMENT_ID_COLNAME
- self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
- if self.warm_start:
- self.create_model_output_table_warm_start()
- else:
- self.create_model_output_table()
+ # Comma-separated list of the mst_keys, including NULL's
+ # This will be used to pass the mst keys to the db as
+ # a sql ARRAY[]
+ self.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+ for mst in self.msts_for_schedule ]
- self.weights_to_update_tbl = unique_string(desp='weights_to_update')
- self.fit_multiple_model()
+ # List of all dist_keys, including any extra dist keys beyond
+ # the # segments we'll be training on--these represent the
+ # segments models will rest on while not training, which
+ # may overlap with the ones that will have training on them.
+ self.all_dist_keys = self.dist_keys + self.extra_dist_keys
- # Update and cleanup metadata tables
- self.insert_info_table()
- self.create_model_summary_table()
- if self.warm_start:
- self.cleanup_for_warm_start()
- reset_cuda_env(original_cuda_env)
+ self.gp_segment_id_col = '0' if is_platform_pg() else
GP_SEGMENT_ID_COLNAME
+ self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
def fit_multiple_model(self):
+ self.init_schedule_tbl()
+ self.init_model_output_tbl()
+ self.init_model_info_tbl()
+
# WARNING: set orca off to prevent unwanted redistribution
with OptimizerControl(False):
self.start_training_time = datetime.datetime.now()
self.metrics_elapsed_start_time = time.time()
self.train_multiple_model()
self.end_training_time = datetime.datetime.now()
- def cleanup_for_warm_start(self):
+ # Update and cleanup metadata tables
+ self.insert_info_table()
+ self.create_model_summary_table()
+ self.write_final_model_output_tbl()
+ reset_cuda_env(self.original_cuda_env)
+
+ def write_final_model_output_tbl(self):
"""
- 1. drop original model table
+ 1. drop original model table if exists
2. rename temp to original
:return:
"""
- drop_query = "DROP TABLE IF EXISTS {}".format(
- self.original_model_output_table)
- plpy.execute(drop_query)
- rename_table(self.schema_madlib, self.model_output_table,
- self.original_model_output_table)
+ final_output_table_create_query = """
+ DROP TABLE IF EXISTS
{self.original_model_output_tbl};
+ CREATE TABLE
{self.original_model_output_tbl} AS
Review comment:
I did try that, but it gets a bit messy.
There are 3 differences between the schemas:
1. The temporary model table has 3 extra columns: compile_params,
fit_params, and object_map
2. The temporary model has a PRIMARY KEY constraint (a combination of
__dist_key__, mst_key), which means it also has a SEQUENCE relation that exists
in the same schema with a similar name.
3. The temporary table is an unlogged table, while presumably the user
expects the output to be an ordinary logged table.
The first we could fix by DROP'ing the 3 extra columns after the ALTER TABLE
RENAME. But even then, there's still a minor issue with it: the user is left
with a table which is probably fragmented on disk in a weird way... with the
extra columns probably still there but marked as deleted. So to clean that up,
we'd probably want to run VACUUM or something on it, to have it consolidate the
table a bit so it's not taking up a larger footprint than it would.
The second one isn't too bad, but also requires extra work: if you do the
table rename, the SEQUENCE doesn't get renamed along with it, and you get
errors when you try to access the table. So you either have to do another
RENAME on the SEQUENCE, or just DROP it.
The third one I see as the most problematic, unless there is an easy way to
change it from UNLOGGED to LOGGED. Seems like an undesirable outcome the user
might not expect.
I didn't get as far as looking up how to fix the 3rd one, since at that
point, it just seemed simpler and cleaner to DROP the temp table and create a
nice fresh table as output that we know isn't going to have any issues.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]