khannaekta commented on a change in pull request #516:
URL: https://github.com/apache/madlib/pull/516#discussion_r492339498
##########
File path:
src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
##########
@@ -580,6 +741,44 @@ FROM (SELECT count(*) cnt FROM iris_multiple_model_info
WHERE compile_params = $MAD$loss='categorical_crossentropy',
optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text) info;
+-- Test with caching when number of configs(4) larger than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+ 'iris_data_packed',
+ 'iris_multiple_model',
+ 'mst_table_4row',
+ 3,
+ FALSE, NULL, NULL, NULL, NULL, NULL,
+ TRUE
+);
Review comment:
Yes, this will still pass, because it asserts only on the number of models trained.
On a cluster with more than 3 segments it might not exercise the use case this
specific test is intended for, but it would not fail.
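
To make that distinction concrete, here is a minimal Python sketch. The helper names are hypothetical and this is not MADlib code; it only assumes, as the comment says, that the test checks how many models were trained, while the scenario it targets (more configs than segments) depends on the cluster size.

```python
def exercises_intended_case(num_configs, num_segments):
    """True only when there are more configs than segments,
    which is the scenario the test comment describes."""
    return num_configs > num_segments


def assertion_passes(models_trained, num_configs):
    """Stand-in for the test's check: one trained model per config."""
    return models_trained == num_configs


# Hypothetical cluster sizes; the test comment states 4 configs.
for num_segments in (3, 8):
    print(num_segments,
          assertion_passes(models_trained=4, num_configs=4),
          exercises_intended_case(4, num_segments))
```

On the 8-segment cluster the assertion still holds, but the "configs larger than segments" condition is no longer met, which is the gap the comment acknowledges.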
##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -523,14 +523,101 @@ def fit_transition(state, dependent_var,
independent_var, dependent_var_shape,
images_per_seg)
is_last_row = agg_image_count == total_images
return_state = get_state_to_return(segment_model, is_last_row,
is_multiple_model,
- agg_image_count, total_images)
+ agg_image_count, total_images)
if is_last_row:
if is_final_iteration or is_multiple_model:
SD_STORE.clear_SD(SD)
clear_keras_session(sess)
return return_state
+def fit_multiple_transition(state, dependent_var, independent_var, dependent_var_shape,
+ independent_var_shape, model_architecture,
+ compile_params, fit_params, dist_key, dist_key_mapping,
+ current_seg_id, segments_per_host, images_per_seg, use_gpus,
+ accessible_gpus_for_seg, prev_serialized_weights,
+ is_final_training_call, custom_function_map=None, **kwargs):
+ """
+ This transition function is called when caching is called for
+ madlib_keras_fit_multiple_model().
+ The input params: dependent_var, independent_var are passed in
+ as None and dependent_var_shape, independent_var_shape as [0]
+ for all hops except the very first hop
+ Some things to note in this function are:
+ - prev_serialized_weights can be passed in as None for the
+ very first hop and the final training call
+ - x_train, y_train and cache_set is cleared from SD for
+ final_training_call = TRUE
+ """
+ if not state:
+ agg_image_count = 0
+ else:
+ agg_image_count = float(state)
+
+ SD = kwargs['SD']
+ is_cache_set = 'cache_set' in SD
+
+ # Prepare the data
+ if is_cache_set:
+ if 'x_train' not in SD or 'y_train' not in SD:
+ plpy.error("cache not populated properly.")
+ total_images = None
+ is_last_row = True
+ else:
+ if not independent_var or not dependent_var:
+ return state
+ if 'x_train' not in SD:
+ SD['x_train'] = list()
+ SD['y_train'] = list()
+ agg_image_count += independent_var_shape[0]
+ total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
+ images_per_seg)
+ is_last_row = agg_image_count == total_images
+ if is_last_row:
+ SD['cache_set'] = True
+ x_train_current = np_array_float32(independent_var, independent_var_shape)
+ y_train_current = np_array_int16(dependent_var, dependent_var_shape)
+ SD['x_train'].append(x_train_current)
+ SD['y_train'].append(y_train_current)
+
+ # Passed in weights can be None. Irrespective of the weights, we want to populate the cache for the very first hop.
+ # But if the weights are None, we do not want to set any model. So early return in that case
+ if prev_serialized_weights is None:
+ if is_final_training_call:
+ del SD['x_train']
+ del SD['y_train']
+ del SD['cache_set']
+ return float(agg_image_count)
Review comment:
This function returns either a float (for all rows except the last row)
or a list (for the very last row).
For fit transition we return the state produced by the wrapper function
`get_state_to_return()`, which follows the same logic.
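
A rough Python sketch of the return-type pattern described here, for readers following the thread. `state_to_return_sketch` is a hypothetical name, not the real `get_state_to_return()` (whose body is not shown in this thread and may serialize the weights differently); the sketch only illustrates "float for intermediate rows, list for the last row".

```python
def state_to_return_sketch(weights, is_last_row, agg_image_count):
    """Illustrative only: running image count for intermediate rows,
    the model weights as a list for the last row."""
    if is_last_row:
        return list(weights)
    return float(agg_image_count)


# Intermediate row: the state is just the count seen so far.
print(state_to_return_sketch([0.1, 0.2], False, 32))
# Last row: the state carries the weights instead.
print(state_to_return_sketch([0.1, 0.2], True, 64))
```

A caller therefore has to treat the state as a count until the last row has been processed.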
##########
File path:
src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
##########
@@ -375,6 +375,39 @@ SELECT assert(
'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
FROM (SELECT * FROM iris_multiple_model_summary) summary;
+-- Testing with caching
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+ 'iris_data_one_hot_encoded_packed',
+ 'iris_multiple_model',
+ 'mst_table_4row',
+ 3,
+ FALSE, NULL, NULL, NULL, NULL, NULL,
+ TRUE
+);
+
+SELECT assert(
+ model_arch_table = 'iris_model_arch' AND
+ validation_table is NULL AND
+ model_info = 'iris_multiple_model_info' AND
+ source_table = 'iris_data_one_hot_encoded_packed' AND
+ model = 'iris_multiple_model' AND
+ model_selection_table = 'mst_table_4row' AND
+ object_table IS NULL AND
+ dependent_varname = 'class_one_hot_encoded' AND
+ independent_varname = 'attributes' AND
+ madlib_version is NOT NULL AND
+ num_iterations = 3 AND
+ start_training_time < now() AND
+ end_training_time < now() AND
+ dependent_vartype = 'integer[]' AND
+ num_classes = NULL AND
+ class_values = NULL AND
+ normalizing_const = 1 AND
+ metrics_iters = ARRAY[3],
+ 'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+
Review comment:
Updated
##########
File path:
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
##########
@@ -1506,13 +1507,17 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition_multiple_model(
segments_per_host INTEGER,
images_per_seg INTEGER[],
use_gpus BOOLEAN,
- accessible_gpus_for_seg INTEGER[],
+ accessible_gpus_for_seg INTEGER[],
prev_serialized_weights BYTEA,
- is_final_iteration BOOLEAN,
+ is_final_training_call BOOLEAN,
+ use_caching BOOLEAN,
custom_function_map BYTEA
) RETURNS BYTEA AS $$
PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
- return madlib_keras.fit_transition(is_multiple_model = True, **globals())
+ if use_caching:
+ return madlib_keras.fit_multiple_transition(**globals())
Review comment:
Updated to fit_multiple_transition_caching
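
For context, the dispatch after the rename might look roughly like the Python sketch below. Both transition functions are stand-in stubs, and the non-caching branch is an assumption based on the removed line in the hunk above (the quoted diff does not show the rest of the function body).

```python
def fit_transition(is_multiple_model=False, **kwargs):
    # Stand-in stub for the existing transition function.
    return ("fit_transition", is_multiple_model)


def fit_multiple_transition_caching(**kwargs):
    # Stand-in stub for the renamed caching transition function.
    return ("fit_multiple_transition_caching",)


def dispatch_transition(use_caching, **kwargs):
    """Choose the transition function based on the use_caching flag."""
    if use_caching:
        return fit_multiple_transition_caching(**kwargs)
    # Assumed fallback: the call that the hunk shows as removed.
    return fit_transition(is_multiple_model=True, **kwargs)


print(dispatch_transition(True, state=None))
print(dispatch_transition(False, state=None))
```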