khannaekta commented on a change in pull request #516:
URL: https://github.com/apache/madlib/pull/516#discussion_r492345605
##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -523,14 +523,101 @@ def fit_transition(state, dependent_var,
independent_var, dependent_var_shape,
images_per_seg)
is_last_row = agg_image_count == total_images
return_state = get_state_to_return(segment_model, is_last_row,
is_multiple_model,
- agg_image_count, total_images)
+ agg_image_count, total_images)
if is_last_row:
if is_final_iteration or is_multiple_model:
SD_STORE.clear_SD(SD)
clear_keras_session(sess)
return return_state
+def fit_multiple_transition(state, dependent_var, independent_var,
dependent_var_shape,
+ independent_var_shape, model_architecture,
+ compile_params, fit_params, dist_key,
dist_key_mapping,
+ current_seg_id, segments_per_host,
images_per_seg, use_gpus,
+ accessible_gpus_for_seg, prev_serialized_weights,
+ is_final_training_call, custom_function_map=None,
**kwargs):
+ """
+ This transition function is called when caching is called for
+ madlib_keras_fit_multiple_model().
+ The input params: dependent_var, independent_var are passed in
+ as None and dependent_var_shape, independent_var_shape as [0]
+ for all hops except the very firt hop
+ Some things to note in this function are:
+ - prev_serialized_weights can be passed in as None for the
+ very first hop and the final training call
+ - x_train, y_train and cache_set is cleared from SD for
+ final_training_call = TRUE
+ """
+ if not state:
+ agg_image_count = 0
+ else:
+ agg_image_count = float(state)
+
+ SD = kwargs['SD']
+ is_cache_set = 'cache_set' in SD
+
+ # Prepare the data
+ if is_cache_set:
+ if 'x_train' not in SD or 'y_train' not in SD:
+ plpy.error("cache not populated properly.")
+ total_images = None
+ is_last_row = True
+ else:
+ if not independent_var or not dependent_var:
+ return state
+ if 'x_train' not in SD:
+ SD['x_train'] = list()
+ SD['y_train'] = list()
+ agg_image_count += independent_var_shape[0]
+ total_images =
get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
+ images_per_seg)
+ is_last_row = agg_image_count == total_images
+ if is_last_row:
+ SD['cache_set'] = True
+ x_train_current = np_array_float32(independent_var,
independent_var_shape)
+ y_train_current = np_array_int16(dependent_var, dependent_var_shape)
+ SD['x_train'].append(x_train_current)
+ SD['y_train'].append(y_train_current)
+
+ # Passed in weights can be None. Irrespective of the weights, we want to
populate the cache for the very first hop.
+ # But if the weights are None, we do not want to set any model. So early
return in that case
+ if prev_serialized_weights is None:
+ if is_final_training_call:
+ del SD['x_train']
+ del SD['y_train']
+ del SD['cache_set']
+ return float(agg_image_count)
Review comment:
This function returns either a float (for all rows except the last row)
or a list (for the very last row).
In `fit_transition`, we return the state produced by the wrapper function
`get_state_to_return()`, which implements the same logic, so it could be reused here.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org