reductionista commented on a change in pull request #525: URL: https://github.com/apache/madlib/pull/525#discussion_r538674096
########## File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in ########## @@ -337,183 +376,308 @@ class FitMultipleModel(): local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None if local_loss and (local_loss not in [a.lower() for a in builtin_losses]): - custom_fn_names.append(local_loss) - custom_fn_mst_idx.append(mst_idx) + custom_fn_names.add(local_loss) + custom_msts.append(mst) if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]): - custom_fn_names.append(local_metric) - custom_fn_mst_idx.append(mst_idx) - - if len(custom_fn_names) > 0: - # Pass only unique custom_fn_names to query from object table - custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names))) - for mst_idx in custom_fn_mst_idx: - self.msts[mst_idx][self.object_map_col] = custom_fn_object_map - - def create_mst_schedule_table(self, mst_row): - mst_temp_query = """ - CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl} - ({self.model_id_col} INTEGER, - {self.compile_params_col} VARCHAR, - {self.fit_params_col} VARCHAR, - {dist_key_col} INTEGER, - {self.mst_key_col} INTEGER, - {self.object_map_col} BYTEA) - """.format(dist_key_col=dist_key_col, **locals()) - plpy.execute(mst_temp_query) - for mst, dist_key in zip(mst_row, self.dist_keys): - if mst: - model_id = mst[self.model_id_col] - compile_params = mst[self.compile_params_col] - fit_params = mst[self.fit_params_col] - mst_key = mst[self.mst_key_col] - object_map = mst[self.object_map_col] - else: - model_id = "NULL" - compile_params = "NULL" - fit_params = "NULL" - mst_key = "NULL" - object_map = None - mst_insert_query = plpy.prepare( - """ - INSERT INTO {self.mst_current_schedule_tbl} - VALUES ({model_id}, - $madlib${compile_params}$madlib$, - $madlib${fit_params}$madlib$, - {dist_key}, - {mst_key}, - $1) - """.format(**locals()), ["BYTEA"]) - plpy.execute(mst_insert_query, [object_map]) - - def create_model_output_table(self): - output_table_create_query = """ - CREATE TABLE {self.model_output_table} - ({self.mst_key_col} INTEGER PRIMARY KEY, - {self.model_weights_col} BYTEA, - {self.model_arch_col} JSON) - """.format(self=self) - plpy.execute(output_table_create_query) - self.initialize_model_output_and_info() + custom_fn_names.add(local_metric) + custom_msts.append(mst) + + self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names) + + for mst in custom_msts: + mst[self.object_map_col] = self.custom_fn_object_map + + self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts } + + def init_schedule_tbl(self): + self.prev_dist_key_col = '__prev_dist_key__' + mst_key_list = '[' + ','.join(self.all_mst_keys) + ']' + + create_sched_query = """ + CREATE TABLE {self.schedule_tbl} AS + WITH map AS + (SELECT + unnest(ARRAY{mst_key_list}) {self.mst_key_col}, + unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col} + ) + SELECT + map.{self.mst_key_col}, + {self.model_id_col}, + map.{self.dist_key_col} AS {self.prev_dist_key_col}, + map.{self.dist_key_col} + FROM map LEFT JOIN {self.model_selection_table} + USING ({self.mst_key_col}) + DISTRIBUTED BY ({self.dist_key_col}) Review comment: Oh, I responded at first thinking this was the hop query. The distribution of this table shouldn't matter much, since there are no weights in it... but yes, I think you're right, I'll change it to prev dist key. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org