kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r560525168
##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -663,17 +717,20 @@ def get_state_to_return(segment_model, is_last_row, is_multiple_model, agg_image
:param is_last_row: boolean to indicate if last row for that hop
:param is_multiple_model: boolean
:param agg_image_count: aggregated image count per hop
- :param total_images: total images per segment
+ :param total_images: total images per segment (only used for madlib_keras_fit() )
:return:
"""
- if is_last_row:
- updated_model_weights = segment_model.get_weights()
- if is_multiple_model:
+ if is_multiple_model:
+ if is_last_row:
+ updated_model_weights = segment_model.get_weights()
new_state = madlib_keras_serializer.serialize_nd_weights(updated_model_weights)
else:
- updated_model_weights = [total_images * w for w in updated_model_weights]
- new_state = madlib_keras_serializer.serialize_state_with_nd_weights(
- agg_image_count, updated_model_weights)
+ new_state = None
+ elif is_last_row:
Review comment:
Yeah, I think we should get rid of it for fit as well, but you are right:
it is better to do that in a future PR.
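
For reference, here is a minimal sketch of the control flow this hunk appears to move toward. It is not the code in the PR. The two serializer functions and `FakeSegmentModel` are hypothetical stand-ins for `madlib_keras_serializer` and a Keras model. The body of the `elif is_last_row:` branch is assumed to carry over from the removed lines, since the hunk is cut off there, and the final `else` is purely a placeholder because that branch is not visible at all.

```python
# Sketch only: stand-ins for madlib_keras_serializer, not the real module.
def serialize_nd_weights(weights):
    """Hypothetical stand-in: state for the multiple-model path."""
    return ("nd_weights", list(weights))


def serialize_state_with_nd_weights(image_count, weights):
    """Hypothetical stand-in: state for the madlib_keras_fit() path."""
    return ("state", image_count, list(weights))


class FakeSegmentModel(object):
    """Hypothetical model whose weights are plain floats."""

    def __init__(self, weights):
        self._weights = list(weights)

    def get_weights(self):
        return list(self._weights)


def get_state_to_return(segment_model, is_last_row, is_multiple_model,
                        agg_image_count, total_images):
    if is_multiple_model:
        if is_last_row:
            updated_model_weights = segment_model.get_weights()
            new_state = serialize_nd_weights(updated_model_weights)
        else:
            # Intermediate rows of a hop return no state.
            new_state = None
    elif is_last_row:
        # Assumption: same as the removed lines in the hunk, where the
        # fit path scales each weight by total_images before serializing.
        updated_model_weights = segment_model.get_weights()
        updated_model_weights = [total_images * w
                                 for w in updated_model_weights]
        new_state = serialize_state_with_nd_weights(
            agg_image_count, updated_model_weights)
    else:
        # Placeholder: this branch is not shown in the hunk.
        new_state = None
    return new_state
```

If "it" above means the `total_images` scaling, the future PR would only need to touch the `elif is_last_row:` branch in a layout like this one.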