This is an automated email from the ASF dual-hosted git repository.

nkak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit fe42e7f5ec2fe1c1d5cc069dd01929c4131ac4d8
Author: Ekta Khanna <ekha...@vmware.com>
AuthorDate: Wed Jan 27 16:30:29 2021 -0800

    DL: Fix misc bugs
    
    JIRA: MADLIB-1464
    
    1. When validating the validation table, we were passing the source
    table's independent_varname (instead of val_ind_var) to the
    validate_input_shape function.
    
    2. Add a "not supported" error message for multiple dependent and
    independent variables in fit_multiple
    
    3. PredictBYOM: Uncomment code and test for validating
    class_values(validate_class_values)
    
    4. Add error message for the case when fit and fit_multiple are called
    with an old version of preprocessed data.
    
    Co-authored-by: Ekta Khanna <ekha...@vmware.com>
---
 .../deep_learning/madlib_keras_predict.py_in       |  4 +-
 .../deep_learning/madlib_keras_validator.py_in     | 18 +++++---
 .../test/unit_tests/test_madlib_keras.py_in        | 53 +++++++++++++++++-----
 3 files changed, 56 insertions(+), 19 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 0e5b1b9..d23d765 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -337,8 +337,8 @@ class PredictBYOM(BasePredict):
         # are traversed in order. It won't work for multi-io and prone to breaking
         # in the regular case.
 
-        # InputValidator.validate_class_values(
-        #     self.module_name, self.class_values, self.pred_type, self.model_arch)
+        InputValidator.validate_class_values(
+            self.module_name, self.class_values, self.pred_type, self.model_arch)
         InputValidator.validate_input_shape(
             self.test_table, self.independent_varname,
             get_input_shape(self.model_arch), 1)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 21eff15..439d9d9 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -289,6 +289,11 @@ class FitCommonValidator(object):
 
         self.dependent_varname = self.src_summary_dict['dependent_varname']
         self.independent_varname = self.src_summary_dict['independent_varname']
+        if not isinstance(self.dependent_varname, list) or \
+                not isinstance(self.independent_varname, list):
+            #TODO improve error message
+            plpy.error("Input table '{0}' has not been preprocessed properly. "
+                       "Please run input preprocessor again.".format(self.source_table))
         self.dep_shape_cols = [add_postfix(i, "_shape") for i in self.dependent_varname]
         self.ind_shape_cols = [add_postfix(i, "_shape") for i in self.independent_varname]
 
@@ -406,7 +411,7 @@ class FitCommonValidator(object):
                                input_shape, 2, True)
         if self.validation_table:
             InputValidator.validate_input_shape(
-                self.validation_table,  self.independent_varname,
+                self.validation_table,  self.val_ind_var,
                 input_shape, 2, True)
 
 
@@ -459,11 +464,12 @@ class FitMultipleInputValidator(FitCommonValidator):
                                                         use_gpus,
                                                         accessible_gpus_for_seg,
                                                         self.module_name,
-                                                        self.object_table,
-                                                        val_dep_var,
-                                                        val_ind_var)
-        self.output_model_info_table = add_postfix(output_model_table,
-                                                   '_info')
+                                                        self.object_table)
+        _assert(len(self.dependent_varname) == 1
+                or len(self.independent_varname) == 1,
+                "Multiple dependent and independent variables not supported "
+                "for madlib_keras_fit_multiple_model!")
+        self.output_model_info_table = add_postfix(output_model_table, '_info')
 
         if warm_start:
             input_tbl_valid(self.output_model_info_table, self.module_name)
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 928b753..5ef4517 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -907,14 +907,14 @@ class MadlibKerasPredictBYOMTestCase(unittest.TestCase):
         self.assertIn('invalid_pred_type', str(error.exception))
 
         # The validation for this test has been disabled
-        # with self.assertRaises(plpy.PLPYException) as error:
-        #     self.module.PredictBYOM('schema_madlib', 'model_arch_table',
-        #                              'model_id', 'test_table', 'id_col',
-        #                              'independent_varname', 'output_table',
-        #                              self.pred_type, self.use_gpus,
-        #                              ["foo", "bar", "baaz"], self.normalizing_const,
-        #                              self.dependent_count)
-        # self.assertIn('class values', str(error.exception).lower())
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.module.PredictBYOM('schema_madlib', 'model_arch_table',
+                                     'model_id', 'test_table', 'id_col',
+                                     'independent_varname', 'output_table',
+                                     self.pred_type, self.use_gpus,
+                                     ["foo", "bar", "baaz"], self.normalizing_const,
+                                     self.dependent_count)
+        self.assertIn('class values', str(error.exception).lower())
 
         with self.assertRaises(plpy.PLPYException) as error:
             self.module.PredictBYOM('schema_madlib', 'model_arch_table',
@@ -1313,6 +1313,37 @@ class MadlibKerasFitCommonValidatorTestCase(unittest.TestCase):
             'module_name', None)
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
 
+    def test_validator_dep_indep_type_not_array(self):
+        # only dep is not array
+        self.subject.FitCommonValidator.get_source_summary_table_dict = \
+            Mock(return_value={'dependent_varname':'a',
+                               'independent_varname':['b']})
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.subject.FitCommonValidator(
+                'test_table', 'val_table', 'model_table', 5, None, False, False, [0],
+                'module_name', None)
+        self.assertIn('not been preprocessed properly', str(error.exception))
+
+        # only indep is not array
+        self.subject.FitCommonValidator.get_source_summary_table_dict = \
+            Mock(return_value={'dependent_varname':['a'],
+                               'independent_varname':'b'})
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.subject.FitCommonValidator(
+                'test_table', 'val_table', 'model_table', 5, None, False, False, [0],
+                'module_name', None)
+        self.assertIn('not been preprocessed properly', str(error.exception))
+
+        # both indep and dep are not arrays
+        self.subject.FitCommonValidator.get_source_summary_table_dict = \
+            Mock(return_value={'dependent_varname':'a',
+                               'independent_varname':'b'})
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.subject.FitCommonValidator(
+                'test_table', 'val_table', 'model_table', 5, None, False, False, [0],
+                'module_name', None)
+        self.assertIn('not been preprocessed properly', str(error.exception))
+
 
 class InputValidatorTestCase(unittest.TestCase):
     def setUp(self):
@@ -1391,9 +1422,9 @@ class InputValidatorTestCase(unittest.TestCase):
 
     def test_validate_input_shape_shapes_match(self):
         # minibatched data
-        # self.plpy_mock_execute.return_value = [{'shape': [1,32,32,3]}]
-        # self.subject.validate_input_shape(
-        #     self.test_table, [self.ind_var], [[32,32,3]], 2, True)
+        self.plpy_mock_execute.return_value = [{'shape': [1,32,32,3]}]
+        self.subject.validate_input_shape(
+            self.test_table, [self.ind_var], [[32,32,3]], 2, True)
         # non-minibatched data
         self.plpy_mock_execute.return_value = [{'shape': [32,32,3]}]
         self.subject.validate_input_shape(

Reply via email to