njayaram2 commented on a change in pull request #378: DL: Update
madlib_keras_fit code to weight by images instead of buffers
URL: https://github.com/apache/madlib/pull/378#discussion_r279062306
##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -82,12 +89,25 @@ def fit(schema_madlib, source_table, model, dependent_varname,
     # about making the fit function easier to read and maintain.
     if is_platform_pg():
         set_keras_session(use_gpu)
+        # Compute total images in dataset
+        total_images_per_seg = plpy.execute(
+            """ SELECT SUM(ARRAY_LENGTH({0}, 1)) AS total_images_per_seg
+                FROM {1}
+            """.format(dependent_varname, source_table))
+        seg_ids_train = "[]::integer[]"
         gp_segment_id_col = -1
     else:
+        # Compute total images on each segment
+        total_images_per_seg = plpy.execute(
+            """ SELECT gp_segment_id, SUM(ARRAY_LENGTH({0}, 1)) AS total_images_per_seg
+                FROM {1}
+                GROUP BY gp_segment_id
+            """.format(dependent_varname, source_table))
+        seg_ids_train = [int(each_segment["gp_segment_id"])
+                         for each_segment in total_images_per_seg]
Review comment:
It might be a good idea to refactor this into a separate function, similar to
`get_rows_per_seg_from_db`.
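
A minimal sketch of such a helper might look like the following. The name `get_images_per_seg_from_db`, its parameters, and its return shape are assumptions made for illustration; they are not taken from the PR, and the actual signature of `get_rows_per_seg_from_db` may differ. The `execute` argument stands in for `plpy.execute` so that the sketch can run outside the database, and the PostgreSQL branch returns an empty Python list of segment ids rather than the `"[]::integer[]"` string used in the diff.

```python
def get_images_per_seg_from_db(execute, source_table, dependent_varname,
                               is_pg):
    """Hypothetical helper: return (seg_ids, images_per_seg).

    execute is a callable that runs a SQL string and returns a list of
    dict-like rows (an assumption standing in for plpy.execute).
    """
    if is_pg:
        # PostgreSQL has no segments: one total for the whole table.
        rows = execute(
            """ SELECT SUM(ARRAY_LENGTH({0}, 1)) AS images_per_seg
                FROM {1}
            """.format(dependent_varname, source_table))
        return [], [int(rows[0]["images_per_seg"])]

    # Greenplum: one total per segment, keyed by gp_segment_id.
    rows = execute(
        """ SELECT gp_segment_id,
                   SUM(ARRAY_LENGTH({0}, 1)) AS images_per_seg
            FROM {1}
            GROUP BY gp_segment_id
        """.format(dependent_varname, source_table))
    seg_ids = [int(row["gp_segment_id"]) for row in rows]
    images_per_seg = [int(row["images_per_seg"]) for row in rows]
    return seg_ids, images_per_seg
```

With a helper along these lines, `fit` would keep only the platform check and a single call, for example `get_images_per_seg_from_db(plpy.execute, source_table, dependent_varname, is_platform_pg())`. The sketch does not handle an empty source table, where `SUM` returns NULL.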