reductionista commented on a change in pull request #525: URL: https://github.com/apache/madlib/pull/525#discussion_r538651733
########## File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in ########## @@ -337,183 +376,308 @@ class FitMultipleModel(): local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None if local_loss and (local_loss not in [a.lower() for a in builtin_losses]): - custom_fn_names.append(local_loss) - custom_fn_mst_idx.append(mst_idx) + custom_fn_names.add(local_loss) + custom_msts.append(mst) if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]): - custom_fn_names.append(local_metric) - custom_fn_mst_idx.append(mst_idx) - - if len(custom_fn_names) > 0: - # Pass only unique custom_fn_names to query from object table - custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names))) - for mst_idx in custom_fn_mst_idx: - self.msts[mst_idx][self.object_map_col] = custom_fn_object_map - - def create_mst_schedule_table(self, mst_row): - mst_temp_query = """ - CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl} - ({self.model_id_col} INTEGER, - {self.compile_params_col} VARCHAR, - {self.fit_params_col} VARCHAR, - {dist_key_col} INTEGER, - {self.mst_key_col} INTEGER, - {self.object_map_col} BYTEA) - """.format(dist_key_col=dist_key_col, **locals()) - plpy.execute(mst_temp_query) - for mst, dist_key in zip(mst_row, self.dist_keys): - if mst: - model_id = mst[self.model_id_col] - compile_params = mst[self.compile_params_col] - fit_params = mst[self.fit_params_col] - mst_key = mst[self.mst_key_col] - object_map = mst[self.object_map_col] - else: - model_id = "NULL" - compile_params = "NULL" - fit_params = "NULL" - mst_key = "NULL" - object_map = None - mst_insert_query = plpy.prepare( - """ - INSERT INTO {self.mst_current_schedule_tbl} - VALUES ({model_id}, - $madlib${compile_params}$madlib$, - $madlib${fit_params}$madlib$, - {dist_key}, - {mst_key}, - $1) - 
""".format(**locals()), ["BYTEA"]) - plpy.execute(mst_insert_query, [object_map]) - - def create_model_output_table(self): - output_table_create_query = """ - CREATE TABLE {self.model_output_table} - ({self.mst_key_col} INTEGER PRIMARY KEY, - {self.model_weights_col} BYTEA, - {self.model_arch_col} JSON) - """.format(self=self) - plpy.execute(output_table_create_query) - self.initialize_model_output_and_info() + custom_fn_names.add(local_metric) + custom_msts.append(mst) + + self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names) + + for mst in custom_msts: + mst[self.object_map_col] = self.custom_fn_object_map + + self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts } + + def init_schedule_tbl(self): + self.prev_dist_key_col = '__prev_dist_key__' + mst_key_list = '[' + ','.join(self.all_mst_keys) + ']' + + create_sched_query = """ + CREATE TABLE {self.schedule_tbl} AS + WITH map AS + (SELECT + unnest(ARRAY{mst_key_list}) {self.mst_key_col}, + unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col} + ) + SELECT + map.{self.mst_key_col}, + {self.model_id_col}, + map.{self.dist_key_col} AS {self.prev_dist_key_col}, + map.{self.dist_key_col} + FROM map LEFT JOIN {self.model_selection_table} + USING ({self.mst_key_col}) + DISTRIBUTED BY ({self.dist_key_col}) Review comment: We would, if we wanted the weights to stay where they are. But the purpose of this query is to redistribute the weights from the previous distribution to the next one. The output has to be distributed by `__dist_key__`, because that is the column used as the JOIN key with the source table's `__dist_key__` in the next query. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org