khannaekta commented on a change in pull request #511:
URL: https://github.com/apache/madlib/pull/511#discussion_r467212576



##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
##########
@@ -175,9 +175,21 @@ pb=dill.dumps(test_custom_fn)
 return pb
 $$ language plpythonu;
 
+CREATE OR REPLACE FUNCTION custom_function_one_object()

Review comment:
       We can probably move this to the setup file 
`madlib_keras_custom_function.setup.sql_in` and call that file here.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -294,6 +294,11 @@ class FitMultipleModel():
 
     def populate_object_map(self):
         builtin_losses = dir(losses)
+        builtin_metrics = dir(metrics)
+        builtin_metrics.append('accuracy')
+        builtin_metrics.append('acc')
+        builtin_metrics.append('crossentropy')
+        builtin_metrics.append('ce')

Review comment:
       We can probably move these to a common file, as these values are used 
by `MstLoaderInputValidator._validate_compile_and_fit_params()` and 
`get_custom_functions_list()`.
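
A minimal sketch of what such a shared helper could look like. The names 
`METRIC_ALIASES` and `get_builtin_metric_names` are hypothetical and not part of 
this PR; only the four alias strings are taken from the diff above. A stand-in 
module replaces `keras.metrics` so the sketch runs without Keras installed.

```python
import types

# The four strings appended in the diff above. The constant name is hypothetical.
METRIC_ALIASES = ('accuracy', 'acc', 'crossentropy', 'ce')


def get_builtin_metric_names(metrics_module):
    """Return every name exposed by the module, followed by the alias strings."""
    return list(dir(metrics_module)) + list(METRIC_ALIASES)


# Stand-in for keras.metrics (an assumption made only for this sketch).
fake_metrics = types.ModuleType('fake_metrics')
fake_metrics.mean_squared_error = lambda y_true, y_pred: 0.0

builtin_metrics = get_builtin_metric_names(fake_metrics)
```

Each caller would then import the helper from one module instead of repeating 
the four `append` calls.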

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
##########
@@ -407,7 +419,17 @@ def get_custom_functions_list(compile_params):
     """
     compile_dict = convert_string_of_args_to_dict(compile_params)
     builtin_losses = dir(losses)
+    builtin_metrics = dir(metrics)
+    builtin_metrics.append('accuracy')
+    builtin_metrics.append('acc')
+    builtin_metrics.append('crossentropy')
+    builtin_metrics.append('ce')

Review comment:
       These should also be defined in a single common file.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
##########
@@ -325,8 +326,19 @@ def compile_model(model, compile_params, custom_function_map=None):
     optimizers = get_optimizers()
     (opt_name,final_args,compile_dict) = parse_and_validate_compile_params(compile_params)
     if custom_function_map is not None:
-        map=dill.loads(custom_function_map)
-        compile_dict['loss']=map[compile_dict['loss']]
+        local_map=dill.loads(custom_function_map)
+
+        compile_dict['loss']=local_map[compile_dict['loss']] \
+            if compile_dict['loss'] in local_map else compile_dict['loss']
+
+        new_metrics = []
+        for i in compile_dict['metrics']:

Review comment:
       Just a note: we currently support only one metric, so we don't need to 
iterate here, though the loop is fine too. I'll leave it up to you whether to 
change it; this is good as is.
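
For illustration, a hedged sketch of the lookup pattern under discussion. The 
quoted hunk is cut off at the `for` line, so this is not the PR's code: 
`resolve_custom_functions` is a hypothetical name, and the fallback for metrics 
is assumed to mirror the one the diff shows for `loss`. It handles a 
one-element metrics list and a longer one the same way.

```python
def resolve_custom_functions(compile_dict, local_map):
    """Swap loss/metric names for custom callables when the map has them.

    Names missing from local_map (for example builtin Keras names) are kept
    unchanged, matching the fallback shown for 'loss' in the diff.
    """
    resolved = dict(compile_dict)  # shallow copy so the caller's dict is untouched
    loss = resolved['loss']
    resolved['loss'] = local_map.get(loss, loss)
    resolved['metrics'] = [local_map.get(m, m)
                           for m in resolved.get('metrics', [])]
    return resolved


def my_custom_loss(y_true, y_pred):
    # Placeholder custom function, used only to exercise the sketch.
    return 0.0


example = resolve_custom_functions(
    {'loss': 'my_custom_loss', 'metrics': ['accuracy']},
    {'my_custom_loss': my_custom_loss},
)
```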





