reductionista commented on a change in pull request #516:
URL: https://github.com/apache/madlib/pull/516#discussion_r489805349



##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
##########
@@ -375,6 +375,39 @@ SELECT assert(
         'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
+-- Testing with caching
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+       'iris_data_one_hot_encoded_packed',
+       'iris_multiple_model',
+       'mst_table_4row',
+       3,
+       FALSE, NULL, NULL, NULL, NULL, NULL,
+       TRUE
+);
+
+SELECT assert(
+        model_arch_table = 'iris_model_arch' AND
+        validation_table is NULL AND
+        model_info = 'iris_multiple_model_info' AND
+        source_table = 'iris_data_one_hot_encoded_packed' AND
+        model = 'iris_multiple_model' AND
+        model_selection_table = 'mst_table_4row' AND
+        object_table IS NULL AND
+        dependent_varname = 'class_one_hot_encoded' AND
+        independent_varname = 'attributes' AND
+        madlib_version is NOT NULL AND
+        num_iterations = 3 AND
+        start_training_time < now() AND
+        end_training_time < now() AND
+        dependent_vartype = 'integer[]' AND
+        num_classes = NULL AND
+        class_values = NULL AND
+        normalizing_const = 1 AND
+        metrics_iters = ARRAY[3],
+        'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+

Review comment:
       Many of these tests have a lot in common with the non-caching tests. They would be much easier to maintain if we moved the shared steps into a common function and called it once with caching=True and once with caching=False.
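   For example, the shared part could be one routine parameterized on the trailing `use_caching` argument. The sketch below uses Python only to show the shape of the refactor; in the `.sql_in` test files it would more likely be a PL/pgSQL function. `FIT_TEMPLATE`, `DROP_OUTPUTS` and `shared_fit_statements` are hypothetical names, and the argument list is copied from the fit call in the test above.

```python
# Hypothetical helper: render the fit-multiple-model test once per caching mode.
# The argument positions mirror the madlib_keras_fit_multiple_model() call in
# this PR's test; only the trailing use_caching flag varies.
FIT_TEMPLATE = """SELECT madlib_keras_fit_multiple_model(
       '{source_table}',
       'iris_multiple_model',
       '{mst_table}',
       {num_iterations},
       FALSE, NULL, NULL, NULL, NULL, NULL,
       {use_caching}
);"""

DROP_OUTPUTS = ("DROP TABLE IF EXISTS iris_multiple_model, "
                "iris_multiple_model_summary, iris_multiple_model_info;")


def shared_fit_statements(source_table, mst_table, num_iterations=3):
    """Return the same DROP + fit sequence for use_caching = FALSE and TRUE."""
    statements = []
    for use_caching in (False, True):
        statements.append(DROP_OUTPUTS)
        statements.append(FIT_TEMPLATE.format(
            source_table=source_table,
            mst_table=mst_table,
            num_iterations=num_iterations,
            use_caching='TRUE' if use_caching else 'FALSE'))
        # The shared summary-table assertions would be appended here as well.
    return statements
```

   Because the two runs differ only in the last argument, a change to the expected summary values then has to be made in one place instead of two.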

##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
##########
@@ -580,6 +741,44 @@ FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
 AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text) info;
 
+-- Test with caching when number of configs(4) larger than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+       'iris_data_packed',
+       'iris_multiple_model',
+       'mst_table_4row',
+       3,
+       FALSE, NULL, NULL, NULL, NULL, NULL,
+       TRUE
+);

Review comment:
       Will these tests pass if the cluster has more than 3 segments (or fewer)?
   
   We don't know what platform future MADlib contributors might want to test this on, or what sort of pipelines it will need to pass in.
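   One way to make such a test independent of the cluster size is to derive the scenario from the actual segment count at test time (for example, by counting primary segments in Greenplum's `gp_segment_configuration` catalog) instead of assuming 3. The Python sketch below uses hypothetical helper names; it shows that a fixed 4-row MST table exercises the "more configs than segments" case only on clusters with at most 3 segments, and picks a table size from the segment count instead.

```python
# Hypothetical helpers illustrating why a hard-coded MST count is cluster-dependent.
HARDCODED_CONFIGS = 4  # number of rows in mst_table_4row


def scenario(num_configs, num_segments):
    """Classify which case a model-selection test actually exercises."""
    if num_configs > num_segments:
        return "configs > segments"
    if num_configs == num_segments:
        return "configs == segments"
    return "configs < segments"


def configs_for_more_than_segments(num_segments):
    """Smallest MST-table size that guarantees configs > segments."""
    return num_segments + 1


for num_segments in (1, 3, 4, 8):
    print(num_segments, scenario(HARDCODED_CONFIGS, num_segments))
# 1 configs > segments
# 3 configs > segments
# 4 configs == segments
# 8 configs < segments
```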

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
##########
@@ -1506,13 +1507,17 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition_multiple_model(
     segments_per_host          INTEGER,
     images_per_seg             INTEGER[],
     use_gpus                   BOOLEAN,
-    accessible_gpus_for_seg               INTEGER[],
+    accessible_gpus_for_seg    INTEGER[],
     prev_serialized_weights    BYTEA,
-    is_final_iteration         BOOLEAN,
+    is_final_training_call     BOOLEAN,
+    use_caching                BOOLEAN,
     custom_function_map        BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    return madlib_keras.fit_transition(is_multiple_model = True, **globals())
+    if use_caching:
+        return madlib_keras.fit_multiple_transition(**globals())

Review comment:
       I notice the SQL function `fit_transition_multiple_model` calls the Python function `fit_multiple_transition` if caching is enabled, and `fit_transition` otherwise. I think we could improve the names here to indicate what's going on.
   
   Maybe rename the SQL function to `fit_multiple_transition`, and then the new Python function can be `fit_multiple_transition_caching`? (Then if we ever add a caching version of regular fit, it could be `fit_transition_caching`.)
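   A minimal sketch of the dispatch under the proposed names. The stub functions stand in for the real transition functions in `madlib_keras.py_in`; their bodies and return values are hypothetical, and `fit_multiple_transition_caching` is the name proposed above, not what the PR currently uses.

```python
# Stubs standing in for the real madlib_keras transition functions; they only
# report which path was taken and whether is_multiple_model was set.
def fit_transition(is_multiple_model=False, **kwargs):
    return ('fit_transition', is_multiple_model)


def fit_multiple_transition_caching(**kwargs):
    return ('fit_multiple_transition_caching', True)


def fit_multiple_transition(use_caching, **kwargs):
    """Body of the SQL-level transition function under the proposed name."""
    if use_caching:
        # Caching path (in the real code, data is kept in SD across hops).
        return fit_multiple_transition_caching(**kwargs)
    # Existing non-caching path, unchanged apart from the caller's name.
    return fit_transition(is_multiple_model=True, **kwargs)
```

   With this naming, the `_caching` suffix consistently marks the path that keeps data in `SD`, and the SQL-level name no longer hides which Python function runs.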

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -523,14 +523,101 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
                                                       images_per_seg)
     is_last_row = agg_image_count == total_images
     return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
-                                  agg_image_count, total_images)
+                                       agg_image_count, total_images)
     if is_last_row:
         if is_final_iteration or is_multiple_model:
             SD_STORE.clear_SD(SD)
             clear_keras_session(sess)
 
     return return_state
 
+def fit_multiple_transition(state, dependent_var, independent_var, dependent_var_shape,
+                             independent_var_shape, model_architecture,
+                             compile_params, fit_params, dist_key, dist_key_mapping,
+                             current_seg_id, segments_per_host, images_per_seg, use_gpus,
+                             accessible_gpus_for_seg, prev_serialized_weights,
+                             is_final_training_call, custom_function_map=None, **kwargs):
+    """
+    This transition function is called when caching is called for
+    madlib_keras_fit_multiple_model().
+    The input params: dependent_var, independent_var are passed in
+    as None and dependent_var_shape, independent_var_shape as [0]
+    for all hops except the very firt hop
+    Some things to note in this function are:
+    - prev_serialized_weights can be passed in as None for the
+      very first hop and the final training call
+    - x_train, y_train and cache_set is cleared from SD for
+      final_training_call = TRUE
+    """
+    if not state:
+        agg_image_count = 0
+    else:
+        agg_image_count = float(state)
+
+    SD = kwargs['SD']
+    is_cache_set = 'cache_set' in SD
+
+    # Prepare the data
+    if is_cache_set:
+        if 'x_train' not in SD or 'y_train' not in SD:
+            plpy.error("cache not populated properly.")
+        total_images = None
+        is_last_row = True
+    else:
+        if not independent_var or not dependent_var:
+            return state
+        if 'x_train' not in SD:
+            SD['x_train'] = list()
+            SD['y_train'] = list()
+        agg_image_count += independent_var_shape[0]
+        total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
+                                                          images_per_seg)
+        is_last_row = agg_image_count == total_images
+        if is_last_row:
+            SD['cache_set'] = True
+        x_train_current = np_array_float32(independent_var, independent_var_shape)
+        y_train_current = np_array_int16(dependent_var, dependent_var_shape)
+        SD['x_train'].append(x_train_current)
+        SD['y_train'].append(y_train_current)
+
+    # Passed in weights can be None. Irrespective of the weights, we want to populate the cache for the very first hop.
+    # But if the weights are None, we do not want to set any model. So early return in that case
+    if prev_serialized_weights is None:
+        if is_final_training_call:
+            del SD['x_train']
+            del SD['y_train']
+            del SD['cache_set']
+        return float(agg_image_count)

Review comment:
       I thought this function was supposed to return the state (a list). Why is it returning a float?
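   To illustrate the concern: a PostgreSQL aggregate's transition function should return a value of the aggregate's declared state type on every path. Below is a toy Python sketch of the caching pattern described in the docstring (fill `SD` on the first hop, reuse it later, clear it on the final call) in which every return path goes through a single helper. The `[image_count, weights]` state layout and all function names are hypothetical, not the PR's actual state format.

```python
def make_state(agg_image_count, weights=None):
    """Single constructor for the (hypothetical) state: [image count, weights]."""
    return [float(agg_image_count), weights]


def caching_transition(state, x_chunk, y_chunk, total_images,
                       prev_weights, is_final_call, SD):
    """Toy transition: cache data on the first hop, reuse it on later hops."""
    agg_image_count = state[0] if state else 0.0

    if 'cache_set' not in SD:
        # First hop: accumulate this row's data in the per-session cache.
        if x_chunk is None or y_chunk is None:
            return make_state(agg_image_count)
        SD.setdefault('x_train', []).append(x_chunk)
        SD.setdefault('y_train', []).append(y_chunk)
        agg_image_count += len(x_chunk)
        if agg_image_count == total_images:
            SD['cache_set'] = True

    if prev_weights is None:
        if is_final_call:
            # Final call: drop the cache so the session does not keep it.
            for key in ('x_train', 'y_train', 'cache_set'):
                SD.pop(key, None)
        # Same type as the normal path, rather than a bare float.
        return make_state(agg_image_count)

    # ... training on SD['x_train'] / SD['y_train'] would happen here ...
    return make_state(agg_image_count, prev_weights)
```

   Routing every exit through one constructor makes it hard to return a bare count on one path and a full state on another.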

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -523,14 +523,101 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
                                                       images_per_seg)
     is_last_row = agg_image_count == total_images
     return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
-                                  agg_image_count, total_images)
+                                       agg_image_count, total_images)
     if is_last_row:
         if is_final_iteration or is_multiple_model:
             SD_STORE.clear_SD(SD)
             clear_keras_session(sess)
 
     return return_state
 
+def fit_multiple_transition(state, dependent_var, independent_var, dependent_var_shape,
+                             independent_var_shape, model_architecture,
+                             compile_params, fit_params, dist_key, dist_key_mapping,
+                             current_seg_id, segments_per_host, images_per_seg, use_gpus,
+                             accessible_gpus_for_seg, prev_serialized_weights,
+                             is_final_training_call, custom_function_map=None, **kwargs):
+    """
+    This transition function is called when caching is called for
+    madlib_keras_fit_multiple_model().
+    The input params: dependent_var, independent_var are passed in
+    as None and dependent_var_shape, independent_var_shape as [0]
+    for all hops except the very firt hop

Review comment:
       Minor typo: `firt` should be `first`.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
