khannaekta commented on a change in pull request #466: DL: Modify multi-fit warm start to accept non-matching mst&model tables
URL: https://github.com/apache/madlib/pull/466#discussion_r362944135
 
 

 ##########
 File path: 
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
 ##########
 @@ -274,65 +274,66 @@ class FitMultipleModel():
             plpy.execute(mst_insert_query)
 
     def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        info_table_create_query = """
-                                  CREATE TABLE {self.model_info_table}
-                                  ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                   {self.model_id_col} INTEGER,
-                                   {self.compile_params_col} TEXT,
-                                   {self.fit_params_col} TEXT,
-                                   model_type TEXT,
-                                   model_size DOUBLE PRECISION,
-                                   metrics_elapsed_time DOUBLE PRECISION[],
-                                   metrics_type TEXT[],
-                                   training_metrics_final DOUBLE PRECISION,
-                                   training_loss_final DOUBLE PRECISION,
-                                   training_metrics DOUBLE PRECISION[],
-                                   training_loss DOUBLE PRECISION[],
-                                   validation_metrics_final DOUBLE PRECISION,
-                                   validation_loss_final DOUBLE PRECISION,
-                                   validation_metrics DOUBLE PRECISION[],
-                                   validation_loss DOUBLE PRECISION[])
-                               """.format(self=self)
-
-        plpy.execute(output_table_create_query)
-        plpy.execute(info_table_create_query)
+        warm_start_msts = []
+        if self.warm_start:
+            warm_start_msts = plpy.execute(
+                """ SELECT array_agg({0}) AS a FROM {1}
+                """.format(self.mst_key_col, self.model_output_table))[0]['a']
+            plpy.execute("TRUNCATE TABLE {0}".format(self.model_info_table))
+        else:
+            output_table_create_query = """
+                                        CREATE TABLE {self.model_output_table}
+                                        ({self.mst_key_col} INTEGER PRIMARY 
KEY,
+                                         {self.model_weights_col} BYTEA,
+                                         {self.model_arch_col} JSON)
+                                        """.format(self=self)
+            info_table_create_query = """
+                                      CREATE TABLE {self.model_info_table}
+                                      ({self.mst_key_col} INTEGER PRIMARY KEY,
+                                       {self.model_id_col} INTEGER,
+                                       {self.compile_params_col} TEXT,
+                                       {self.fit_params_col} TEXT,
+                                       model_type TEXT,
+                                       model_size DOUBLE PRECISION,
+                                       metrics_elapsed_time DOUBLE PRECISION[],
+                                       metrics_type TEXT[],
+                                       training_metrics_final DOUBLE PRECISION,
+                                       training_loss_final DOUBLE PRECISION,
+                                       training_metrics DOUBLE PRECISION[],
+                                       training_loss DOUBLE PRECISION[],
+                                       validation_metrics_final DOUBLE 
PRECISION,
+                                       validation_loss_final DOUBLE PRECISION,
+                                       validation_metrics DOUBLE PRECISION[],
+                                       validation_loss DOUBLE PRECISION[])
+                                   """.format(self=self)
+
+            plpy.execute(output_table_create_query)
+            plpy.execute(info_table_create_query)
         for mst in self.msts:
             model_arch, model_weights = 
get_model_arch_weights(self.model_arch_table,
                                                                
mst[self.model_id_col])
+
+
+            # If warm start is enabled, weigths from transfer learning cannot 
be
 
 Review comment:
  Typo: `weigths` -> `weights`.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to