reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r537797983



##########
File path: 
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -196,72 +195,110 @@ class FitMultipleModel():
             self.dist_key_mapping_valid, self.images_per_seg_valid = \
                 get_image_count_per_seg_for_minibatched_data_from_db(
                     self.validation_table)
-        self.mst_weights_tbl = unique_string(desp='mst_weights')
-        self.mst_current_schedule_tbl = 
unique_string(desp='mst_current_schedule')
+        self.model_input_tbl = unique_string(desp='model_input')
+        self.schedule_tbl = unique_string(desp='schedule')
 
-        self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
-        if len(self.msts) < len(self.dist_keys):
+        self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
+        DEBUG.plpy.info("init_dist_keys = {0}".format(self.dist_keys))
+        self.max_dist_key = sorted(self.dist_keys)[-1]
+        DEBUG.plpy.info("sorted_dist_keys = 
{0}".format(sorted(self.dist_keys)))
+        DEBUG.plpy.info("max_dist_key = {0}".format(self.max_dist_key))
+        self.extra_dist_keys = []
+
+        num_msts = len(self.msts)
+        num_dist_keys = len(self.dist_keys)
+
+        if num_msts < num_dist_keys:
             self.msts_for_schedule = self.msts + [None] * \
-                                     (len(self.dist_keys) - len(self.msts))
+                                     (num_dist_keys - num_msts)
         else:
             self.msts_for_schedule = self.msts
+            if num_msts > num_dist_keys:
+                for i in range(num_msts - num_dist_keys):
+                    self.extra_dist_keys.append(self.max_dist_key + 1 + i)
+
+        DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys))
+        DEBUG.plpy.info('extra_dist_keys : {}'.format(self.extra_dist_keys))
+
         random.shuffle(self.msts_for_schedule)
-        self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
-        self.gp_segment_id_col = '0' if is_platform_pg() else 
GP_SEGMENT_ID_COLNAME
-        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
-        if self.warm_start:
-            self.create_model_output_table_warm_start()
-        else:
-            self.create_model_output_table()
+        # Comma-separated list of the mst_keys, including NULL's
+        #  This will be used to pass the mst keys to the db as
+        #  a sql ARRAY[]
+        self.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                for mst in self.msts_for_schedule ]
 
-        self.weights_to_update_tbl = unique_string(desp='weights_to_update')
-        self.fit_multiple_model()
+        # List of all dist_keys, including any extra dist keys beyond
+        #  the # segments we'll be training on--these represent the
+        #  segments models will rest on while not training, which
+        #  may overlap with the ones that will have training on them.
+        self.all_dist_keys = self.dist_keys + self.extra_dist_keys
 
-        # Update and cleanup metadata tables
-        self.insert_info_table()
-        self.create_model_summary_table()
-        if self.warm_start:
-            self.cleanup_for_warm_start()
-        reset_cuda_env(original_cuda_env)
+        self.gp_segment_id_col = '0' if is_platform_pg() else 
GP_SEGMENT_ID_COLNAME
+        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
     def fit_multiple_model(self):
+        self.init_schedule_tbl()
+        self.init_model_output_tbl()
+        self.init_model_info_tbl()
+
         # WARNING: set orca off to prevent unwanted redistribution
         with OptimizerControl(False):
             self.start_training_time = datetime.datetime.now()
             self.metrics_elapsed_start_time = time.time()
             self.train_multiple_model()
             self.end_training_time = datetime.datetime.now()
 
-    def cleanup_for_warm_start(self):
+        # Update and cleanup metadata tables
+        self.insert_info_table()
+        self.create_model_summary_table()
+        self.write_final_model_output_tbl()
+        reset_cuda_env(self.original_cuda_env)
+
+    def write_final_model_output_tbl(self):
         """
-        1. drop original model table
+        1. drop original model table if exists
         2. rename temp to original
         :return:
         """
-        drop_query = "DROP TABLE IF EXISTS {}".format(
-            self.original_model_output_table)
-        plpy.execute(drop_query)
-        rename_table(self.schema_madlib, self.model_output_table,
-                     self.original_model_output_table)
+        final_output_table_create_query = """
+                                    DROP TABLE IF EXISTS 
{self.original_model_output_tbl};
+                                    CREATE TABLE 
{self.original_model_output_tbl} AS

Review comment:
       I did try that, but it gets a bit messy.
   
   There are 3 differences between the schemas:
      1. The temporary model table has 3 extra columns: compile_params, 
fit_params, and object_map.
      2. The temporary model table has a PRIMARY KEY constraint (on the 
combination of __dist_key__ and mst_key), which means it also has a SEQUENCE 
relation that exists in the same schema with a similar name.
      3. The temporary table is an unlogged table, while the user presumably 
expects the output to be an ordinary logged table.
   
   We could fix the first one by DROP'ing the 3 extra columns after the ALTER 
TABLE RENAME. But even then, there's still a minor issue: the user is left with 
a table that is probably fragmented on disk in a weird way, with the extra 
columns probably still physically present but marked as dropped. To clean that 
up, we'd probably want to run VACUUM (or something similar) on it to 
consolidate the table, so that it doesn't take up a larger footprint than it 
otherwise would.
   
   The second one isn't too bad, but it also requires extra work: if you rename 
the table, the SEQUENCE is not renamed along with it, and you get errors when 
you try to access it. So you either have to do another RENAME on the SEQUENCE, 
or just DROP it.
   
   I see the third one as the most problematic, unless there is an easy way to 
change the table from UNLOGGED to LOGGED. It seems like an undesirable outcome 
that the user might not expect.
   
   I didn't get as far as looking up how to fix the third one, since at that 
point it just seemed simpler and cleaner to DROP the temp table and create a 
nice, fresh output table that we know isn't going to have any issues.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to