njayaram2 commented on a change in pull request #388: DL: Add new param
metrics_compute_frequency to madlib_keras_fit()
URL: https://github.com/apache/madlib/pull/388#discussion_r283483312
 
 

 ##########
 File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
 ##########
 @@ -173,134 +179,131 @@ def fit(schema_madlib, source_table, 
model,model_arch_table,
     # Run distributed training for specified number of iterations
     for i in range(num_iterations):
         start_iteration = time.time()
-        iteration_result = plpy.execute(run_training_iteration, 
[model_state])[0]['iteration_result']
+        iteration_result = plpy.execute(run_training_iteration,
+                                        [model_state])[0]['iteration_result']
         end_iteration = time.time()
         plpy.info("Time for iteration {0}: {1} sec".
                   format(i + 1, end_iteration - start_iteration))
         aggregate_runtime.append(datetime.datetime.now())
-        avg_loss, avg_accuracy, model_state = 
madlib_keras_serializer.deserialize_iteration_state(iteration_result)
+        avg_loss, avg_metric, model_state = madlib_keras_serializer.\
+            deserialize_iteration_state(iteration_result)
         plpy.info("Average loss after training iteration {0}: {1}".format(
             i + 1, avg_loss))
         plpy.info("Average accuracy after training iteration {0}: {1}".format(
-            i + 1, avg_accuracy))
-        if validation_set_provided:
-            _, _, _, updated_weights = 
madlib_keras_serializer.deserialize_weights(
-                model_state, model_shapes)
-            master_model.set_weights(updated_weights)
-            start_val = time.time()
-            evaluate_result = get_loss_acc_from_keras_eval(schema_madlib,
-                                                           validation_table,
-                                                           dependent_varname,
-                                                           independent_varname,
-                                                           
compile_params_to_pass,
-                                                           model_arch, 
model_state,
-                                                           gpus_per_host,
-                                                           segments_per_host,
-                                                           seg_ids_val,
-                                                           rows_per_seg_val,
-                                                           gp_segment_id_col)
-            end_val = time.time()
-            plpy.info("Time for validation in iteration {0}: {1} sec". 
format(i + 1, end_val - start_val))
-            if len(evaluate_result) < 2:
-                plpy.error('Calling evaluate on validation data returned < 2 '
-                           'metrics. Expected metrics are loss and accuracy')
-            validation_loss = evaluate_result[0]
-            validation_accuracy = evaluate_result[1]
-            plpy.info("Validation set accuracy after iteration {0}: {1}".
-                      format(i + 1, validation_accuracy))
-            validation_aggregate_accuracy.append(validation_accuracy)
-            validation_aggregate_loss.append(validation_loss)
-        aggregate_loss.append(avg_loss)
-        aggregate_accuracy.append(avg_accuracy)
+            i + 1, avg_metric))
+
+        if should_compute_metrics_this_iter(i, metrics_compute_frequency,
+                                            num_iterations):
+            # TODO: Do we need this code?
+            # _, _, _, updated_weights = 
madlib_keras_serializer.deserialize_weights(
+            #     model_state, model_shapes)
+            # master_model.set_weights(updated_weights)
+            # Compute loss/accuracy for training data.
+            # TODO: Uncomment this once JIRA MADLIB-1332 is merged to master
+            # compute_loss_and_metrics(
+            #     schema_madlib, source_table, dependent_varname,
+            #     independent_varname, compile_params_to_pass, model_arch,
+            #     model_state, gpus_per_host, segments_per_host, seg_ids_val,
+            #     rows_per_seg_val, gp_segment_id_col,
+            #     training_metrics, training_loss,
+            #     i, "Training")
+            metrics_iters.append(i)
+            if validation_set_provided:
+                # Compute loss/accuracy for validation data.
+                compute_loss_and_metrics(
+                    schema_madlib, validation_table, dependent_varname,
+                    independent_varname, compile_params_to_pass, model_arch,
+                    model_state, gpus_per_host, segments_per_host, seg_ids_val,
+                    rows_per_seg_val, gp_segment_id_col,
+                    validation_metrics, validation_loss,
+                    i, "Validation")
+        training_loss.append(avg_loss)
+        training_metrics.append(avg_metric)
 
     end_training_time = datetime.datetime.now()
 
-    final_validation_acc = None
-    if validation_aggregate_accuracy and len(validation_aggregate_accuracy) > 
0:
-        final_validation_acc = validation_aggregate_accuracy[-1]
-
-    final_validation_loss = None
-    if validation_aggregate_loss and len(validation_aggregate_loss) > 0:
-        final_validation_loss = validation_aggregate_loss[-1]
     version = madlib_version(schema_madlib)
     class_values, class_values_type = get_col_value_and_type(
         fit_validator.source_summary_table, CLASS_VALUES_COLNAME)
     norm_const, norm_const_type = get_col_value_and_type(
         fit_validator.source_summary_table, NORMALIZING_CONST_COLNAME)
     dep_vartype = plpy.execute("SELECT {0} AS dep FROM {1}".format(
         DEPENDENT_VARTYPE_COLNAME, 
fit_validator.source_summary_table))[0]['dep']
-    dependent_varname_in_source_table = quote_ident(plpy.execute("SELECT {0} 
FROM {1}".format(
-        'dependent_varname', 
fit_validator.source_summary_table))[0]['dependent_varname'])
-    independent_varname_in_source_table = quote_ident(plpy.execute("SELECT {0} 
FROM {1}".format(
-        'independent_varname', 
fit_validator.source_summary_table))[0]['independent_varname'])
+    # Quote_ident TEXT values to be inserted into the summary table
+    dependent_varname_in_source_table = plpy.execute("SELECT {0} FROM 
{1}".format(
 
 Review comment:
  I think it's better to fetch this value, along with the other values we read 
from the summary table, in a single query and consume the results from there. I've refactored the code to do that.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to