kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r535759055
##########
File path:
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
local_loss = compile_dict['loss'].lower() if 'loss' in
compile_dict else None
local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics'
in compile_dict else None
if local_loss and (local_loss not in [a.lower() for a in
builtin_losses]):
- custom_fn_names.append(local_loss)
- custom_fn_mst_idx.append(mst_idx)
+ custom_fn_names.add(local_loss)
+ custom_msts.append(mst)
if local_metric and (local_metric not in [a.lower() for a in
builtin_metrics]):
- custom_fn_names.append(local_metric)
- custom_fn_mst_idx.append(mst_idx)
-
- if len(custom_fn_names) > 0:
- # Pass only unique custom_fn_names to query from object table
- custom_fn_object_map =
query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
- for mst_idx in custom_fn_mst_idx:
- self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
- def create_mst_schedule_table(self, mst_row):
- mst_temp_query = """
- CREATE {self.unlogged_table} TABLE
{self.mst_current_schedule_tbl}
- ({self.model_id_col} INTEGER,
- {self.compile_params_col} VARCHAR,
- {self.fit_params_col} VARCHAR,
- {dist_key_col} INTEGER,
- {self.mst_key_col} INTEGER,
- {self.object_map_col} BYTEA)
- """.format(dist_key_col=dist_key_col, **locals())
- plpy.execute(mst_temp_query)
- for mst, dist_key in zip(mst_row, self.dist_keys):
- if mst:
- model_id = mst[self.model_id_col]
- compile_params = mst[self.compile_params_col]
- fit_params = mst[self.fit_params_col]
- mst_key = mst[self.mst_key_col]
- object_map = mst[self.object_map_col]
- else:
- model_id = "NULL"
- compile_params = "NULL"
- fit_params = "NULL"
- mst_key = "NULL"
- object_map = None
- mst_insert_query = plpy.prepare(
- """
- INSERT INTO {self.mst_current_schedule_tbl}
- VALUES ({model_id},
- $madlib${compile_params}$madlib$,
- $madlib${fit_params}$madlib$,
- {dist_key},
- {mst_key},
- $1)
- """.format(**locals()), ["BYTEA"])
- plpy.execute(mst_insert_query, [object_map])
-
- def create_model_output_table(self):
- output_table_create_query = """
- CREATE TABLE {self.model_output_table}
- ({self.mst_key_col} INTEGER PRIMARY KEY,
- {self.model_weights_col} BYTEA,
- {self.model_arch_col} JSON)
- """.format(self=self)
- plpy.execute(output_table_create_query)
- self.initialize_model_output_and_info()
+ custom_fn_names.add(local_metric)
+ custom_msts.append(mst)
+
+ self.custom_fn_object_map =
query_custom_functions_map(self.object_table, custom_fn_names)
+
+ for mst in custom_msts:
+ mst[self.object_map_col] = self.custom_fn_object_map
+
+ self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+ def init_schedule_tbl(self):
+ self.prev_dist_key_col = '__prev_dist_key__'
+ mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+ create_sched_query = """
+ CREATE TABLE {self.schedule_tbl} AS
+ WITH map AS
+ (SELECT
+ unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+ unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+ )
+ SELECT
+ map.{self.mst_key_col},
+ {self.model_id_col},
+ map.{self.dist_key_col} AS {self.prev_dist_key_col},
+ map.{self.dist_key_col}
+ FROM map LEFT JOIN {self.model_selection_table}
+ USING ({self.mst_key_col})
+ DISTRIBUTED BY ({self.dist_key_col})
+ """.format(self=self, mst_key_list=mst_key_list)
+ DEBUG.plpy.execute(create_sched_query)
+
+ def rotate_schedule_tbl(self):
+ if not hasattr(self, 'rotate_schedule_plan'):
+ self.next_schedule_tbl = unique_string('next_schedule')
+ rotate_schedule_tbl_query = """
+ CREATE TABLE {self.next_schedule_tbl} AS
+ SELECT
+ {self.mst_key_col},
+ {self.model_id_col},
+ {self.dist_key_col} AS {self.prev_dist_key_col},
+ COALESCE(
+ LEAD({self.dist_key_col})
+ OVER(ORDER BY {self.dist_key_col}),
+ FIRST_VALUE({self.dist_key_col})
+ OVER(ORDER BY {self.dist_key_col})
+ ) AS {self.dist_key_col}
+ FROM {self.schedule_tbl};
+ """.format(self=self)
+ self.rotate_schedule_tbl_plan =
plpy.prepare(rotate_schedule_tbl_query)
+
+ DEBUG.plpy.execute(self.rotate_schedule_tbl_plan)
+
+ self.truncate_and_drop(self.schedule_tbl)
+ plpy.execute("""
+ ALTER TABLE {self.next_schedule_tbl}
+ RENAME TO {self.schedule_tbl}
+ """.format(self=self))
- def create_model_output_table_warm_start(self):
+ def load_warm_start_weights(self):
"""
- For warm start, we need to copy the model output table to a temp table
- because we call truncate on the model output table while training.
- If the query gets aborted, we need to make sure that the user passed
- model output table can be recovered.
+ For warm start, we need to copy any rows of the model output
+ table provided by the user whose mst keys appear in the
+ supplied model selection table. We also copy over the
+ compile & fit params from the model_selection_table, and
+ the dist_key's from the schedule table.
"""
- plpy.execute("""
- CREATE TABLE {self.model_output_table} (
- LIKE {self.original_model_output_table} INCLUDING indexes);
- """.format(self=self))
+ load_warm_start_weights_query = """
+ INSERT INTO {self.model_output_tbl}
+ SELECT s.{self.mst_key_col},
+ o.{self.model_weights_col},
+ o.{self.model_arch_col},
+ m.{self.compile_params_col},
+ m.{self.fit_params_col},
+ NULL AS {self.object_map_col}, -- Fill in later
+ s.{self.dist_key_col}
+ FROM {self.schedule_tbl} s
+ JOIN {self.model_selection_table} m
+ USING ({self.mst_key_col})
+ JOIN {self.original_model_output_tbl} o
+ USING ({self.mst_key_col})
+ """.format(self=self)
+ DEBUG.plpy.execute(load_warm_start_weights_query)
- plpy.execute("""INSERT INTO {self.model_output_table}
- SELECT * FROM {self.original_model_output_table};
- """.format(self=self))
+ plpy.execute("DROP TABLE IF EXISTS {0}".format(self.model_info_tbl))
- plpy.execute(""" DELETE FROM {self.model_output_table}
- WHERE {self.mst_key_col} NOT IN (
- SELECT {self.mst_key_col} FROM
{self.model_selection_table})
- """.format(self=self))
- self.warm_start_msts = plpy.execute(
- """ SELECT array_agg({0}) AS a FROM {1}
- """.format(self.mst_key_col, self.model_output_table))[0]['a']
- plpy.execute("DROP TABLE {0}".format(self.model_info_table))
- self.initialize_model_output_and_info()
-
- def initialize_model_output_and_info(self):
+ def load_xfer_learning_weights(self, warm_start=False):
+ """
+ Copy transfer learning weights from
+ model_arch table. Ignore models with
+ no xfer learning weights, these will
+ be generated by keras and added one at a
+ time later.
+ """
+ load_xfer_learning_weights_query = """
+ INSERT INTO {self.model_output_tbl}
+ SELECT s.{self.mst_key_col},
+ a.{self.model_weights_col},
+ a.{self.model_arch_col},
+ m.{self.compile_params_col},
+ m.{self.fit_params_col},
+ NULL AS {self.object_map_col}, -- Fill in later
+ s.{self.dist_key_col}
+ FROM {self.schedule_tbl} s
+ JOIN {self.model_selection_table} m
+ USING ({self.mst_key_col})
+ JOIN {self.model_arch_table} a
+ ON m.{self.model_id_col} = a.{self.model_id_col}
+ WHERE a.{self.model_weights_col} IS NOT NULL;
+ """.format(self=self)
+ DEBUG.plpy.execute(load_xfer_learning_weights_query)
+
+ def init_model_output_tbl(self):
+ DEBUG.start_timing('init_model_output_and_info')
+
+ output_table_create_query = """
+ CREATE TABLE {self.model_output_tbl}
+ ({self.mst_key_col} INTEGER,
+ {self.model_weights_col} BYTEA,
+ {self.model_arch_col} JSON,
+ {self.compile_params_col} TEXT,
+ {self.fit_params_col} TEXT,
+ {self.object_map_col} BYTEA,
+ {self.dist_key_col} INTEGER,
+ PRIMARY KEY ({self.dist_key_col},
{self.mst_key_col})
+ )
+ DISTRIBUTED BY ({self.dist_key_col})
+ """.format(self=self)
+ plpy.execute(output_table_create_query)
+
+ if self.warm_start:
+ self.load_warm_start_weights()
+ else: # Note: We only support xfer learning when warm_start=False
+ self.load_xfer_learning_weights()
+
+ res = DEBUG.plpy.execute("""
+ SELECT {self.mst_key_col} AS mst_keys FROM {self.model_output_tbl}
+ """.format(self=self))
+
+ if res:
+ initialized_msts = set([ row['mst_keys'] for row in res ])
+ else:
+ initialized_msts = set()
+
+ DEBUG.plpy.info("Pre-initialized mst keys:
{}".format(initialized_msts))
Review comment:
We should try to reduce the number of `DEBUG.plpy.info` lines, since they
tend to be feature-specific. Having many of them makes the code harder to
read.
Any developer working on a feature can add their own `DEBUG.plpy.info` calls
as needed.
I think we can remove this debug `plpy.info` statement and add it back for
debugging if and when it is needed.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org