This is an automated email from the ASF dual-hosted git repository.

nkak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 6c1f737f6991868167af780b269aa2fc029a5369
Author: Nikhil Kak <n...@pivotal.io>
AuthorDate: Mon Apr 29 15:02:44 2019 -0700

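    DL: Fix pg bug for getting device name
    
    JIRA: MADLIB-1308
    
    On PostgreSQL, get_device_name_and_set_cuda_env() built the device
    list passed to set_cuda_env() with ','.join() over the integers from
    range(gpus_per_host), which raises TypeError because str.join()
    accepts only strings. Convert each GPU index with str() before joining.
    
    Also create a new test class, MadlibKerasWrapperTestCase, for the
    madlib_keras_wrapper functions.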
    
    Co-authored-by: Orhan Kislal <okis...@apache.org>
---
 .../deep_learning/madlib_keras_wrapper.py_in       |  2 +-
 .../test/unit_tests/test_madlib_keras.py_in        | 35 +++++++++++++---------
 2 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index 4a11d18..7e5b10a 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -63,7 +63,7 @@ def get_device_name_and_set_cuda_env(gpus_per_host, seg):
     if gpus_per_host > 0:
         device_name = '/gpu:0'
         if is_platform_pg():
-            cuda_visible_dev = ','.join([i for i in range(gpus_per_host)])
+            cuda_visible_dev = ','.join([str(i) for i in range(gpus_per_host)])
         else:
             cuda_visible_dev = str(seg % gpus_per_host)
         set_cuda_env(cuda_visible_dev)
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index e315a31..0bdd4d8 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -295,6 +295,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         k = {}
         self.assertEqual('dummy_state',
                          self.subject.fit_transition('dummy_state', None , 
[0], 1, 2,
+
                                                      [0,1,2], [3,3,3], 
'dummy_model_json', "foo", "bar", 0, 4,
                                                      'dummy_prev_state', **k))
         self.assertEqual('dummy_state',
@@ -386,6 +387,26 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         result = self.subject.fit_final(None)
         self.assertEqual(result, None)
 
+
+class MadlibKerasWrapperTestCase(unittest.TestCase):
+    def setUp(self):
+        self.plpy_mock = Mock(spec='error')
+        patches = {
+            'plpy': plpy
+        }
+
+        self.plpy_mock_execute = MagicMock()
+        plpy.execute = self.plpy_mock_execute
+
+        self.module_patcher = patch.dict('sys.modules', patches)
+        self.module_patcher.start()
+
+        import madlib_keras_wrapper
+        self.subject = madlib_keras_wrapper
+
+    def tearDown(self):
+        self.module_patcher.stop()
+
     def test_get_device_name_and_set_cuda_env_postgres(self):
         self.subject.is_platform_pg = Mock(return_value = True)
 
@@ -414,20 +435,6 @@ class MadlibKerasFitTestCase(unittest.TestCase):
             gpus_per_host, seg_id))
         self.assertEqual('-1', os.environ['CUDA_VISIBLE_DEVICES'])
 
-    def test_fit_transition_first_tuple_none_ind_var_dep_var(self):
-        k = {}
-        self.assertEqual('dummy_state',
-            self.subject.fit_transition('dummy_state', None , [0], 1, 2,
-            [0,1,2], [3,3,3], 'dummy_model_json', "foo", "bar", False,
-            'dummy_prev_state', **k))
-        self.assertEqual('dummy_state',
-            self.subject.fit_transition('dummy_state', [[0.5]], None, 1, 2,
-            [0,1,2], [3,3,3], 'dummy_model_json', "foo", "bar", False,
-            'dummy_prev_state', **k))
-        self.assertEqual('dummy_state',
-            self.subject.fit_transition('dummy_state', None, None, 1, 2,
-            [0,1,2], [3,3,3], 'dummy_model_json', "foo", "bar", False,
-            'dummy_prev_state', **k))
 
     def test_split_and_strip(self):
         self.assertEqual(('a','b'), self.subject.split_and_strip(' a = b '))

Reply via email to