reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r537838544



##########
File path: 
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +377,307 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in 
compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' 
in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in 
builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in 
builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = 
query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE 
{self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
 
-    def create_model_output_table_warm_start(self):
+        self.custom_fn_object_map = 
query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):
+            self.next_schedule_tbl = unique_string('next_schedule')
+            rotate_schedule_tbl_query = """
+                CREATE TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl};
+            """.format(self=self)
+            self.rotate_schedule_tbl_plan = 
plpy.prepare(rotate_schedule_tbl_query)
+
+        plpy.execute(self.rotate_schedule_tbl_plan)
+
+        self.truncate_and_drop(self.schedule_tbl)
+        plpy.execute("""
+            ALTER TABLE {self.next_schedule_tbl}
+            RENAME TO {self.schedule_tbl}
+        """.format(self=self))
+
+    def load_warm_start_weights(self):
         """
-        For warm start, we need to copy the model output table to a temp table
-        because we call truncate on the model output table while training.
-        If the query gets aborted, we need to make sure that the user passed
-        model output table can be recovered.
+        For warm start, we need to copy any rows of the model output
+        table provided by the user whose mst keys appear in the
+        supplied model selection table.  We also copy over the 
+        compile & fit params from the model_selection_table, and
+        the dist_key's from the schedule table.
         """
-        plpy.execute("""
-            CREATE TABLE {self.model_output_table} (
-            LIKE {self.original_model_output_table} INCLUDING indexes);
-            """.format(self=self))
+        load_warm_start_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    o.{self.model_weights_col},
+                    o.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.original_model_output_tbl} o
+                        USING ({self.mst_key_col})
+        """.format(self=self)
+        DEBUG.plpy.execute(load_warm_start_weights_query)
 
-        plpy.execute("""INSERT INTO {self.model_output_table}
-            SELECT * FROM {self.original_model_output_table};
-            """.format(self=self))
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(self.model_info_tbl))
 
-        plpy.execute(""" DELETE FROM {self.model_output_table}
-                WHERE {self.mst_key_col} NOT IN (
-                    SELECT {self.mst_key_col} FROM 
{self.model_selection_table})
-                """.format(self=self))
-        self.warm_start_msts = plpy.execute(
-            """ SELECT array_agg({0}) AS a FROM {1}
-            """.format(self.mst_key_col, self.model_output_table))[0]['a']
-        plpy.execute("DROP TABLE {0}".format(self.model_info_table))
-        self.initialize_model_output_and_info()
-
-    def initialize_model_output_and_info(self):
+    def load_xfer_learning_weights(self, warm_start=False):
+        """
+            Copy transfer learning weights from
+            model_arch table.  Ignore models with
+            no xfer learning weights, these will
+            be generated by keras and added one at a
+            time later.
+        """
+        load_xfer_learning_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    a.{self.model_weights_col},
+                    a.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.model_arch_table} a
+                        ON m.{self.model_id_col} = a.{self.model_id_col}
+                WHERE a.{self.model_weights_col} IS NOT NULL;
+        """.format(self=self)
+        DEBUG.plpy.execute(load_xfer_learning_weights_query)
+
+    def init_model_output_tbl(self):
+        DEBUG.start_timing('init_model_output_and_info')
+
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_tbl}
+                                    ({self.mst_key_col} INTEGER,
+                                     {self.model_weights_col} BYTEA,
+                                     {self.model_arch_col} JSON,
+                                     {self.compile_params_col} TEXT,
+                                     {self.fit_params_col} TEXT,
+                                     {self.object_map_col} BYTEA,
+                                     {self.dist_key_col} INTEGER,
+                                     PRIMARY KEY ({self.mst_key_col}, 
{self.dist_key_col}))
+                                     DISTRIBUTED BY ({self.dist_key_col})
+                                    """.format(self=self)
+        plpy.execute(output_table_create_query)
+
+        if self.warm_start:

Review comment:
       Re: 1, I'll think about it.  Any suggestions?
   
   2. Actually, it's only for transfer learning; that's the important 
part of the name.  Notice the IS NOT NULL part of the query, which makes sure 
it only copies the transfer learning weights.
   
   The basic logic for initializing the model output table is this:
   
   1.  Bulk load any warm start weights from previous model output table
   2.  Bulk load any transfer learning weights from model arch table
   3.  Iterate over all remaining mst_keys that haven't already been loaded, 
and initialize them with random weights, writing them one by one with INSERTs 
to the model output table.
   
   We could drop the IS NOT NULL condition in 2, but since those rows have no 
weights, it would just be copying the tiny stuff (compile params, fit params, 
etc.)... so bulk loading doesn't really help there anyway.  And it would make 
it more difficult later, since we'd have to INSERT the weights in a different 
way for rows that already have some of the columns initialized but not others.  
(Or, just write over them, in which case nothing useful was accomplished by 
copying them the first time.)




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to