njayaram2 commented on a change in pull request #388: DL: Add new param
metrics_compute_frequency to madlib_keras_fit()
URL: https://github.com/apache/madlib/pull/388#discussion_r283447550
##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -313,6 +316,61 @@ def fit(schema_madlib, source_table,
model,model_arch_table,
#TODO add a unit test for this in a future PR
reset_cuda_env(original_cuda_env)
+def compute_loss_and_metrics(schema_madlib, table, dependent_varname,
+ independent_varname, compile_params, model_arch,
+ model_state, gpus_per_host, segments_per_host,
+ seg_ids_val, rows_per_seg_val,
+ gp_segment_id_col, metrics_list, loss_list,
+ curr_iter, dataset_name):
+ """
+ Compute the loss and metric using a given model (model_state) on the
+ given dataset (table).
+ """
+ start_val = time.time()
+ evaluate_result = get_loss_acc_from_keras_eval(schema_madlib,
+ table,
+ dependent_varname,
+ independent_varname,
+ compile_params,
+ model_arch, model_state,
+ gpus_per_host,
+ segments_per_host,
+ seg_ids_val,
+ rows_per_seg_val,
+ gp_segment_id_col)
+ end_val = time.time()
+ plpy.info("Time for evaluation in iteration {0}: {1} sec.".format(
+ curr_iter + 1, end_val - start_val))
+ if len(evaluate_result) < 2:
+ plpy.error('Calling evaluate on table {0} returned < 2 '
+ 'metrics. Expected both loss and a metric.'.format(
+ table))
+ loss = evaluate_result[0]
+ metric = evaluate_result[1]
+ plpy.info("{0} set metric after iteration {1}: {2}.".
+ format(dataset_name, curr_iter + 1, metric))
+ plpy.info("{0} set loss after iteration {1}: {2}.".
+ format(dataset_name, curr_iter + 1, loss))
+ metrics_list.append(metric)
+ loss_list.append(loss)
+
+def should_compute_metrics_this_iter(
Review comment:
Will change it in the fit function too; otherwise it will be inconsistent.
For instance, all our log messages print `i+1` when reporting iteration numbers,
and `compute_loss_and_metrics` also assumes a zero-indexed `i` value. Will
instead run the for loop over `range(1, num_iterations+1)` and make it
consistent everywhere.