reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r537800172



##########
File path: 
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -196,72 +195,110 @@ class FitMultipleModel():
             self.dist_key_mapping_valid, self.images_per_seg_valid = \
                 get_image_count_per_seg_for_minibatched_data_from_db(
                     self.validation_table)
-        self.mst_weights_tbl = unique_string(desp='mst_weights')
-        self.mst_current_schedule_tbl = 
unique_string(desp='mst_current_schedule')
+        self.model_input_tbl = unique_string(desp='model_input')
+        self.schedule_tbl = unique_string(desp='schedule')
 
-        self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
-        if len(self.msts) < len(self.dist_keys):
+        self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
+        DEBUG.plpy.info("init_dist_keys = {0}".format(self.dist_keys))
+        self.max_dist_key = sorted(self.dist_keys)[-1]
+        DEBUG.plpy.info("sorted_dist_keys = 
{0}".format(sorted(self.dist_keys)))
+        DEBUG.plpy.info("max_dist_key = {0}".format(self.max_dist_key))
+        self.extra_dist_keys = []
+
+        num_msts = len(self.msts)
+        num_dist_keys = len(self.dist_keys)
+
+        if num_msts < num_dist_keys:
             self.msts_for_schedule = self.msts + [None] * \
-                                     (len(self.dist_keys) - len(self.msts))
+                                     (num_dist_keys - num_msts)
         else:
             self.msts_for_schedule = self.msts
+            if num_msts > num_dist_keys:
+                for i in range(num_msts - num_dist_keys):
+                    self.extra_dist_keys.append(self.max_dist_key + 1 + i)
+
+        DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys))
+        DEBUG.plpy.info('extra_dist_keys : {}'.format(self.extra_dist_keys))
+
         random.shuffle(self.msts_for_schedule)
-        self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
-        self.gp_segment_id_col = '0' if is_platform_pg() else 
GP_SEGMENT_ID_COLNAME
-        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
-        if self.warm_start:
-            self.create_model_output_table_warm_start()
-        else:
-            self.create_model_output_table()
+        # Comma-separated list of the mst_keys, including NULL's
+        #  This will be used to pass the mst keys to the db as
+        #  a sql ARRAY[]
+        self.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                for mst in self.msts_for_schedule ]
 
-        self.weights_to_update_tbl = unique_string(desp='weights_to_update')
-        self.fit_multiple_model()
+        # List of all dist_keys, including any extra dist keys beyond
+        #  the # segments we'll be training on--these represent the
+        #  segments models will rest on while not training, which
+        #  may overlap with the ones that will have training on them.
+        self.all_dist_keys = self.dist_keys + self.extra_dist_keys
 
-        # Update and cleanup metadata tables
-        self.insert_info_table()
-        self.create_model_summary_table()
-        if self.warm_start:
-            self.cleanup_for_warm_start()
-        reset_cuda_env(original_cuda_env)
+        self.gp_segment_id_col = '0' if is_platform_pg() else 
GP_SEGMENT_ID_COLNAME
+        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
     def fit_multiple_model(self):
+        self.init_schedule_tbl()
+        self.init_model_output_tbl()
+        self.init_model_info_tbl()
+
         # WARNING: set orca off to prevent unwanted redistribution
         with OptimizerControl(False):
             self.start_training_time = datetime.datetime.now()
             self.metrics_elapsed_start_time = time.time()
             self.train_multiple_model()
             self.end_training_time = datetime.datetime.now()
 
-    def cleanup_for_warm_start(self):
+        # Update and cleanup metadata tables
+        self.insert_info_table()
+        self.create_model_summary_table()
+        self.write_final_model_output_tbl()
+        reset_cuda_env(self.original_cuda_env)
+
+    def write_final_model_output_tbl(self):
         """
-        1. drop original model table
+        1. drop original model table if exists
         2. rename temp to original
         :return:
         """
-        drop_query = "DROP TABLE IF EXISTS {}".format(
-            self.original_model_output_table)
-        plpy.execute(drop_query)
-        rename_table(self.schema_madlib, self.model_output_table,
-                     self.original_model_output_table)
+        final_output_table_create_query = """
+                                    DROP TABLE IF EXISTS 
{self.original_model_output_tbl};
+                                    CREATE TABLE 
{self.original_model_output_tbl} AS

Review comment:
       I'd consider changing it if there's an argument for the RENAME being 
faster, but I'm not sure it would be... I imagine the VACUUM might take just as 
long.  Either way, there shouldn't be any motion across segments since the 
DISTRIBUTED BY is the same on both.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to