njayaram2 commented on a change in pull request #389: DL: Convert the
keras_eval function from UDF to UDA
URL: https://github.com/apache/madlib/pull/389#discussion_r283059111
##########
File path:
src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
##########
@@ -803,6 +810,267 @@ class MadlibKerasHelperTestCase(unittest.TestCase):
self.subject.strip_trailing_nulls_from_class_values(
[None, None]))
+
+class MadlibKerasEvaluationTestCase(unittest.TestCase):
+ def setUp(self):
+ self.plpy_mock = Mock(spec='error')
+ patches = {
+ 'plpy': plpy,
+ 'utilities.minibatch_preprocessing': Mock()
+ }
+
+ self.plpy_mock_execute = MagicMock()
+ plpy.execute = self.plpy_mock_execute
+
+ self.module_patcher = patch.dict('sys.modules', patches)
+ self.module_patcher.start()
+ import madlib_keras
+ self.subject = madlib_keras
+
+ self.model = Sequential()
+ self.model.add(Conv2D(2, kernel_size=(1, 1), activation='relu',
+ input_shape=(1,1,1,), padding='same'))
+ self.model.add(Flatten())
+
+ self.compile_params = "optimizer=SGD(lr=0.01, decay=1e-6,
nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']"
+ self.model_weights = [3,4,5,6]
+ self.model_shapes = []
+ for a in self.model.get_weights():
+ self.model_shapes.append(a.shape)
+
+ self.loss = 0.5947071313858032
+ self.accuracy = 1.0
+ self.all_seg_ids = [0,1,2]
+
+ #self.model.evaluate = Mock(return_value = [self.loss, self.accuracy])
+
+ self.independent_var = [[[[0.5]]]] * 10
+ self.dependent_var = [[0,1]] * 10
+ # We test on segment 0, which has 3 buffers filled with 10 identical
+ # images each, or 30 images total
+ self.total_images_per_seg = [3*len(self.dependent_var),20,40]
+
    def tearDown(self):
        # Undo the sys.modules patching installed in setUp() so later test
        # cases import the real (unpatched) modules.
        self.module_patcher.stop()
+
+ def _test_internal_keras_eval_transition_first_buffer(self,
is_platform_pg):
+ self.subject.K.set_session = Mock()
+ self.subject.clear_keras_session = Mock()
+ self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
+ starting_image_count = 0
+ ending_image_count = len(self.dependent_var)
+
+ k = {'SD' : {}}
+ state = [0,0,0]
+
+ serialized_weights = [0, 0, 0] # not used
+ serialized_weights.extend(self.model_weights)
+ serialized_weights = np.array(serialized_weights,
dtype=np.float32).tostring()
+
+ new_state = self.subject.internal_keras_eval_transition(
+ state, self.dependent_var , self.independent_var,
self.model.to_json(), serialized_weights,
+ self.compile_params, 0, 3, self.all_seg_ids,
self.total_images_per_seg,
+ 0, **k)
+
+ agg_loss, agg_accuracy, image_count = new_state
+
+ self.assertEqual(ending_image_count, image_count)
+ # Call set_session once for gpdb (but not for postgres)
+ self.assertEqual(0 if is_platform_pg else 1,
self.subject.K.set_session.call_count)
+ # loss and accuracy should be unchanged
+ self.assertAlmostEqual(self.loss * image_count, agg_loss, 4)
+ self.assertAlmostEqual(self.accuracy * image_count, agg_accuracy, 4)
+ # Clear session and sess.close must not get called for the first buffer
+ self.assertEqual(0, self.subject.clear_keras_session.call_count)
+ self.assertTrue(k['SD']['segment_model'])
+
    def _test_internal_keras_eval_transition_middle_buffer(self, is_platform_pg):
        """Middle buffer: a compiled model already sits in SD, so set_session
        and clear_keras_session must both stay untouched while the running
        (loss, accuracy, image_count) totals grow by one buffer's worth.
        """
        #TODO should we mock tensorflow's close_session and keras'
        # clear_session instead of mocking the function `clear_keras_session`
        self.subject.K.set_session = Mock()
        self.subject.clear_keras_session = Mock()
        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)

        starting_image_count = len(self.dependent_var)
        ending_image_count = starting_image_count + len(self.dependent_var)

        k = {'SD' : {}}

        # Serialized layout: (loss, accuracy, image_count) header followed by
        # the flattened model weights, all as float32.
        model_state = [self.loss, self.accuracy, starting_image_count]
        model_state.extend(self.model_weights)
        model_state = np.array(model_state, dtype=np.float32)

        self.subject.compile_and_set_weights(self.model, self.compile_params,
                                             '/cpu:0', model_state.tostring(),
                                             self.model_shapes)

        # Aggregate state stores metric sums weighted by image count.
        state = [self.loss * starting_image_count, self.accuracy * starting_image_count, starting_image_count]
        k['SD']['segment_model'] = self.model

        new_state = self.subject.internal_keras_eval_transition(
            state, self.dependent_var , self.independent_var, self.model.to_json(), 'dummy_model_data',
            None, 0, 3, self.all_seg_ids, self.total_images_per_seg,
            0, **k)

        agg_loss, agg_accuracy, image_count = new_state

        self.assertEqual(ending_image_count, image_count)
        # set_session is only called in first buffer, not here
        self.assertEqual(0, self.subject.K.set_session.call_count)
        # Weighted metric sums must now cover the enlarged image count
        self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
        self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
        # Clear session and sess.close must not get called for the middle buffer
        self.assertEqual(0, self.subject.clear_keras_session.call_count)
+
    def _test_internal_keras_eval_transition_last_buffer(self, is_platform_pg):
        """Last buffer: totals reach total_images_per_seg[0], after which the
        keras session must be torn down on gpdb (but kept alive on postgres).
        """
        #TODO should we mock tensorflow's close_session and keras'
        # clear_session instead of mocking the function `clear_keras_session`
        self.subject.K.set_session = Mock()
        self.subject.clear_keras_session = Mock()
        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)

        # Two of the three buffers have already been processed.
        starting_image_count = 2*len(self.dependent_var)
        ending_image_count = starting_image_count + len(self.dependent_var)
        k = {'SD' : {}}

        # Serialized layout: (loss, accuracy, image_count) header followed by
        # the flattened model weights, all as float32.
        model_state = [self.loss, self.accuracy, starting_image_count]
        model_state.extend(self.model_weights)
        model_state = np.array(model_state, dtype=np.float32)

        self.subject.compile_and_set_weights(self.model, self.compile_params,
                                             '/cpu:0', model_state.tostring(),
                                             self.model_shapes)

        # Aggregate state stores metric sums weighted by image count.
        state = [self.loss * starting_image_count, self.accuracy * starting_image_count, starting_image_count]

        k['SD']['segment_model'] = self.model
        new_state = self.subject.internal_keras_eval_transition(
            state, self.dependent_var , self.independent_var, self.model.to_json(), 'dummy_model_data',
            None, 0, 3, self.all_seg_ids, self.total_images_per_seg,
            0, **k)

        agg_loss, agg_accuracy, image_count = new_state

        self.assertEqual(ending_image_count, image_count)
        # set_session is only called in first buffer, not here
        self.assertEqual(0, self.subject.K.set_session.call_count)
        # Weighted metric sums must now cover every image on the segment
        self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
        self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
        # Clear session and sess.close must get called for the last buffer in gpdb,
        # but not in postgres
        self.assertEqual(0 if is_platform_pg else 1,
                         self.subject.clear_keras_session.call_count)
+
+ def test_internal_keras_eval_transition_first_buffer_pg(self):
+ self._test_internal_keras_eval_transition_first_buffer(True)
+
+ def test_internal_keras_eval_transition_first_buffer_gpdb(self):
+ self._test_internal_keras_eval_transition_first_buffer(False)
+
+ def test_internal_keras_eval_transition_middle_buffer_pg(self):
+ self._test_internal_keras_eval_transition_middle_buffer(True)
+
+ def test_internal_keras_eval_transition_middle_buffer_gpdb(self):
+ self._test_internal_keras_eval_transition_middle_buffer(False)
+
+ def test_internal_keras_eval_transition_last_buffer_pg(self):
+ self._test_internal_keras_eval_transition_last_buffer(True)
+
+ def test_internal_keras_eval_transition_last_buffer_gpdb(self):
+ self._test_internal_keras_eval_transition_last_buffer(False)
+
+ def test_internal_keras_eval_merge(self):
+ image_count = self.total_images_per_seg[0]
+ state1 = [3.0*self.loss, 3.0*self.accuracy, image_count]
+ state1 = state1
+ state2 = [2.0*self.loss, 2.0*self.accuracy, image_count+30]
+ state2 = state2
+ merged_state = self.subject.internal_keras_eval_merge(state1,state2)
+ state = merged_state
+ agg_loss = state[0]
+ agg_accuracy = state[1]
+ image_count_total = state[2]
Review comment:
We can directly use `merged_state` instead of `state`, and remove the
`state` variable.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services