khannaekta commented on a change in pull request #466: DL: Modify multi-fit
warm start to accept non-matching mst&model tables
URL: https://github.com/apache/madlib/pull/466#discussion_r362944135
##########
File path:
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -274,65 +274,66 @@ class FitMultipleModel():
plpy.execute(mst_insert_query)
def create_model_output_table(self):
- output_table_create_query = """
- CREATE TABLE {self.model_output_table}
- ({self.mst_key_col} INTEGER PRIMARY KEY,
- {self.model_weights_col} BYTEA,
- {self.model_arch_col} JSON)
- """.format(self=self)
- info_table_create_query = """
- CREATE TABLE {self.model_info_table}
- ({self.mst_key_col} INTEGER PRIMARY KEY,
- {self.model_id_col} INTEGER,
- {self.compile_params_col} TEXT,
- {self.fit_params_col} TEXT,
- model_type TEXT,
- model_size DOUBLE PRECISION,
- metrics_elapsed_time DOUBLE PRECISION[],
- metrics_type TEXT[],
- training_metrics_final DOUBLE PRECISION,
- training_loss_final DOUBLE PRECISION,
- training_metrics DOUBLE PRECISION[],
- training_loss DOUBLE PRECISION[],
- validation_metrics_final DOUBLE PRECISION,
- validation_loss_final DOUBLE PRECISION,
- validation_metrics DOUBLE PRECISION[],
- validation_loss DOUBLE PRECISION[])
- """.format(self=self)
-
- plpy.execute(output_table_create_query)
- plpy.execute(info_table_create_query)
+ warm_start_msts = []
+ if self.warm_start:
+ warm_start_msts = plpy.execute(
+ """ SELECT array_agg({0}) AS a FROM {1}
+ """.format(self.mst_key_col, self.model_output_table))[0]['a']
+ plpy.execute("TRUNCATE TABLE {0}".format(self.model_info_table))
+ else:
+ output_table_create_query = """
+ CREATE TABLE {self.model_output_table}
+ ({self.mst_key_col} INTEGER PRIMARY
KEY,
+ {self.model_weights_col} BYTEA,
+ {self.model_arch_col} JSON)
+ """.format(self=self)
+ info_table_create_query = """
+ CREATE TABLE {self.model_info_table}
+ ({self.mst_key_col} INTEGER PRIMARY KEY,
+ {self.model_id_col} INTEGER,
+ {self.compile_params_col} TEXT,
+ {self.fit_params_col} TEXT,
+ model_type TEXT,
+ model_size DOUBLE PRECISION,
+ metrics_elapsed_time DOUBLE PRECISION[],
+ metrics_type TEXT[],
+ training_metrics_final DOUBLE PRECISION,
+ training_loss_final DOUBLE PRECISION,
+ training_metrics DOUBLE PRECISION[],
+ training_loss DOUBLE PRECISION[],
+ validation_metrics_final DOUBLE
PRECISION,
+ validation_loss_final DOUBLE PRECISION,
+ validation_metrics DOUBLE PRECISION[],
+ validation_loss DOUBLE PRECISION[])
+ """.format(self=self)
+
+ plpy.execute(output_table_create_query)
+ plpy.execute(info_table_create_query)
for mst in self.msts:
model_arch, model_weights =
get_model_arch_weights(self.model_arch_table,
mst[self.model_id_col])
+
+
+ # If warm start is enabled, weigths from transfer learning cannot
be
Review comment:
Typo: `weigths` -> `weights`
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services