This is an automated email from the ASF dual-hosted git repository.

altay pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
     new f792e2e  Add helper functions for reading and writing to PubSub directly from Python (#9212)
f792e2e is described below

commit f792e2e46925ace3e0221ff6bf17fdede3383fbd
Author: Alexey Strokach <strok...@google.com>
AuthorDate: Wed Aug 7 17:16:18 2019 -0700

    Add helper functions for reading and writing to PubSub directly from Python (#9212)

    * Add helper functions for reading and writing to PubSub directly from Python

    These functions are helpful when writing tests and when working with
    streaming pipelines interactively (e.g. inside a Jupyter notebook).

    Notes:
    - Not sure if apache_beam/testing/test_utils.py is a better place for the
      helper functions than apache_beam/io/gcp/tests/utils.py?
    - google.cloud.exceptions seems to have moved to google.api_core.exceptions.
      Currently, google.cloud.exceptions re-imports some, but not all, of the
      exceptions defined in google.api_core.exceptions.
---
 sdks/python/apache_beam/io/gcp/tests/utils.py      |  57 +++++-
 sdks/python/apache_beam/io/gcp/tests/utils_test.py | 200 ++++++++++++++++++++-
 2 files changed, 249 insertions(+), 8 deletions(-)

diff --git a/sdks/python/apache_beam/io/gcp/tests/utils.py b/sdks/python/apache_beam/io/gcp/tests/utils.py
index 68d3f43..4ed9af3 100644
--- a/sdks/python/apache_beam/io/gcp/tests/utils.py
+++ b/sdks/python/apache_beam/io/gcp/tests/utils.py
@@ -25,15 +25,16 @@ import random
 import time
 
 from apache_beam.io import filesystems
+from apache_beam.io.gcp.pubsub import PubsubMessage
 from apache_beam.utils import retry
 
 # Protect against environments where bigquery library is not available.
 try:
+  from google.api_core import exceptions as gexc
   from google.cloud import bigquery
-  from google.cloud.exceptions import NotFound
 except ImportError:
+  gexc = None
   bigquery = None
-  NotFound = None
 
 
 class GcpTestIOError(retry.PermanentException):
@@ -98,7 +99,7 @@ def delete_bq_table(project, dataset_id, table_id):
   table_ref = client.dataset(dataset_id).table(table_id)
   try:
     client.delete_table(table_ref)
-  except NotFound:
+  except gexc.NotFound:
     raise GcpTestIOError('BigQuery table does not exist: %s' % table_ref)
 
 
@@ -113,3 +114,53 @@ def delete_directory(directory):
       "gs://mybucket/mydir/", "s3://...", ...)
""" filesystems.FileSystems.delete([directory]) + + +def write_to_pubsub(pub_client, + topic_path, + messages, + with_attributes=False, + chunk_size=100, + delay_between_chunks=0.1): + for start in range(0, len(messages), chunk_size): + message_chunk = messages[start:start + chunk_size] + if with_attributes: + futures = [ + pub_client.publish(topic_path, message.data, **message.attributes) + for message in message_chunk + ] + else: + futures = [ + pub_client.publish(topic_path, message) for message in message_chunk + ] + for future in futures: + future.result() + time.sleep(delay_between_chunks) + + +def read_from_pubsub(sub_client, + subscription_path, + with_attributes=False, + number_of_elements=None, + timeout=None): + if number_of_elements is None and timeout is None: + raise ValueError("Either number_of_elements or timeout must be specified.") + messages = [] + start_time = time.time() + + while ((number_of_elements is None or len(messages) < number_of_elements) and + (timeout is None or (time.time() - start_time) < timeout)): + try: + response = sub_client.pull( + subscription_path, max_messages=1000, retry=None, timeout=10) + except (gexc.RetryError, gexc.DeadlineExceeded): + continue + ack_ids = [msg.ack_id for msg in response.received_messages] + sub_client.acknowledge(subscription_path, ack_ids) + for msg in response.received_messages: + message = PubsubMessage._from_message(msg.message) + if with_attributes: + messages.append(message) + else: + messages.append(message.data) + return messages diff --git a/sdks/python/apache_beam/io/gcp/tests/utils_test.py b/sdks/python/apache_beam/io/gcp/tests/utils_test.py index 8af7497..c9e96d1 100644 --- a/sdks/python/apache_beam/io/gcp/tests/utils_test.py +++ b/sdks/python/apache_beam/io/gcp/tests/utils_test.py @@ -24,16 +24,19 @@ import unittest import mock +from apache_beam.io.gcp.pubsub import PubsubMessage from apache_beam.io.gcp.tests import utils -from apache_beam.testing.test_utils import patch_retry +from apache_beam.testing import test_utils # Protect against environments where bigquery library is not available. 
 try:
+  from google.api_core import exceptions as gexc
   from google.cloud import bigquery
-  from google.cloud.exceptions import NotFound
+  from google.cloud import pubsub
 except ImportError:
+  gexc = None
   bigquery = None
-  NotFound = None
+  pubsub = None
 
 
 @unittest.skipIf(bigquery is None, 'Bigquery dependencies are not installed.')
@@ -41,7 +44,7 @@ except ImportError:
 class UtilsTest(unittest.TestCase):
 
   def setUp(self):
-    patch_retry(self, utils)
+    test_utils.patch_retry(self, utils)
 
   @mock.patch.object(bigquery, 'Dataset')
   def test_create_bq_dataset(self, mock_dataset, mock_client):
@@ -68,7 +71,7 @@ class UtilsTest(unittest.TestCase):
   def test_delete_table_fails_not_found(self, mock_client):
     mock_client.return_value.dataset.return_value.table.return_value = (
         'table_ref')
-    mock_client.return_value.delete_table.side_effect = NotFound('test')
+    mock_client.return_value.delete_table.side_effect = gexc.NotFound('test')
 
     with self.assertRaisesRegexp(Exception, r'does not exist:.*table_ref'):
       utils.delete_bq_table('unused_project',
@@ -76,6 +79,193 @@ class UtilsTest(unittest.TestCase):
                             'unused_table')
 
 
+@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
+class PubSubUtilTest(unittest.TestCase):
+
+  def test_write_to_pubsub(self):
+    mock_pubsub = mock.Mock()
+    topic_path = "project/fakeproj/topics/faketopic"
+    data = b'data'
+    utils.write_to_pubsub(mock_pubsub, topic_path, [data])
+    mock_pubsub.publish.assert_has_calls(
+        [mock.call(topic_path, data),
+         mock.call().result()])
+
+  def test_write_to_pubsub_with_attributes(self):
+    mock_pubsub = mock.Mock()
+    topic_path = "project/fakeproj/topics/faketopic"
+    data = b'data'
+    attributes = {'key': 'value'}
+    message = PubsubMessage(data, attributes)
+    utils.write_to_pubsub(
+        mock_pubsub, topic_path, [message], with_attributes=True)
+    mock_pubsub.publish.assert_has_calls(
+        [mock.call(topic_path, data, **attributes),
+         mock.call().result()])
+
+  def test_write_to_pubsub_delay(self):
+    number_of_elements = 2
+    chunk_size = 1
+    mock_pubsub = mock.Mock()
+    topic_path = "project/fakeproj/topics/faketopic"
+    data = b'data'
+    with mock.patch('apache_beam.io.gcp.tests.utils.time') as mock_time:
+      utils.write_to_pubsub(
+          mock_pubsub,
+          topic_path, [data] * number_of_elements,
+          chunk_size=chunk_size,
+          delay_between_chunks=123)
+    mock_time.sleep.assert_called_with(123)
+    mock_pubsub.publish.assert_has_calls(
+        [mock.call(topic_path, data),
+         mock.call().result()] * number_of_elements)
+
+  def test_write_to_pubsub_many_chunks(self):
+    number_of_elements = 83
+    chunk_size = 11
+    mock_pubsub = mock.Mock()
+    topic_path = "project/fakeproj/topics/faketopic"
+    data_list = [
+        'data {}'.format(i).encode("utf-8") for i in range(number_of_elements)
+    ]
+    utils.write_to_pubsub(
+        mock_pubsub, topic_path, data_list, chunk_size=chunk_size)
+    call_list = []
+    for start in range(0, number_of_elements, chunk_size):
+      # Publish a batch of messages
+      call_list += [
+          mock.call(topic_path, data)
+          for data in data_list[start:start + chunk_size]
+      ]
+      # Wait for those messages to be received
+      call_list += [
+          mock.call().result() for _ in data_list[start:start + chunk_size]
+      ]
+    mock_pubsub.publish.assert_has_calls(call_list)
+
+  def test_read_from_pubsub(self):
+    mock_pubsub = mock.Mock()
+    subscription_path = "project/fakeproj/subscriptions/fakesub"
+    data = b'data'
+    ack_id = 'ack_id'
+    pull_response = test_utils.create_pull_response(
+        [test_utils.PullResponseMessage(data, ack_id=ack_id)])
+    mock_pubsub.pull.return_value = pull_response
+    output = utils.read_from_pubsub(
+        mock_pubsub, subscription_path, number_of_elements=1)
+    self.assertEqual([data], output)
+    mock_pubsub.acknowledge.assert_called_once_with(subscription_path,
+                                                    [ack_id])
+
+  def test_read_from_pubsub_with_attributes(self):
+    mock_pubsub = mock.Mock()
+    subscription_path = "project/fakeproj/subscriptions/fakesub"
+    data = b'data'
+    ack_id = 'ack_id'
+    attributes = {'key': 'value'}
+    message = PubsubMessage(data, attributes)
+    pull_response = test_utils.create_pull_response(
+        [test_utils.PullResponseMessage(data, attributes, ack_id=ack_id)])
+    mock_pubsub.pull.return_value = pull_response
+    output = utils.read_from_pubsub(
+        mock_pubsub,
+        subscription_path,
+        with_attributes=True,
+        number_of_elements=1)
+    self.assertEqual([message], output)
+    mock_pubsub.acknowledge.assert_called_once_with(subscription_path,
+                                                    [ack_id])
+
+  def test_read_from_pubsub_flaky(self):
+    number_of_elements = 10
+    mock_pubsub = mock.Mock()
+    subscription_path = "project/fakeproj/subscriptions/fakesub"
+    data = b'data'
+    ack_id = 'ack_id'
+    pull_response = test_utils.create_pull_response(
+        [test_utils.PullResponseMessage(data, ack_id=ack_id)])
+
+    class FlakyPullResponse(object):
+
+      def __init__(self, pull_response):
+        self.pull_response = pull_response
+        self._state = -1
+
+      def __call__(self, *args, **kwargs):
+        self._state += 1
+        if self._state % 3 == 0:
+          raise gexc.RetryError("", "")
+        if self._state % 3 == 1:
+          raise gexc.DeadlineExceeded("")
+        if self._state % 3 == 2:
+          return self.pull_response
+
+    mock_pubsub.pull.side_effect = FlakyPullResponse(pull_response)
+    output = utils.read_from_pubsub(
+        mock_pubsub, subscription_path, number_of_elements=number_of_elements)
+    self.assertEqual([data] * number_of_elements, output)
+    self._assert_ack_ids_equal(mock_pubsub, [ack_id] * number_of_elements)
+
+  def test_read_from_pubsub_many(self):
+    response_size = 33
+    number_of_elements = 100
+    mock_pubsub = mock.Mock()
+    subscription_path = "project/fakeproj/subscriptions/fakesub"
+    data_list = [
+        'data {}'.format(i).encode("utf-8") for i in range(number_of_elements)
+    ]
+    attributes_list = [{
+        'key': 'value {}'.format(i)
+    } for i in range(number_of_elements)]
+    ack_ids = ['ack_id_{}'.format(i) for i in range(number_of_elements)]
+    messages = [
+        PubsubMessage(data, attributes)
+        for data, attributes in zip(data_list, attributes_list)
+    ]
+    response_messages = [
+        test_utils.PullResponseMessage(data, attributes, ack_id=ack_id)
+        for data, attributes, ack_id in zip(data_list, attributes_list, ack_ids)
+    ]
+
+    class SequentialPullResponse(object):
+
+      def __init__(self, response_messages, response_size):
+        self.response_messages = response_messages
+        self.response_size = response_size
+        self._index = 0
+
+      def __call__(self, *args, **kwargs):
+        start = self._index
+        self._index += self.response_size
+        response = test_utils.create_pull_response(
+            self.response_messages[start:start + self.response_size])
+        return response
+
+    mock_pubsub.pull.side_effect = SequentialPullResponse(
+        response_messages, response_size)
+    output = utils.read_from_pubsub(
+        mock_pubsub,
+        subscription_path,
+        with_attributes=True,
+        number_of_elements=number_of_elements)
+    self.assertEqual(messages, output)
+    self._assert_ack_ids_equal(mock_pubsub, ack_ids)
+
+  def test_read_from_pubsub_invalid_arg(self):
+    sub_client = mock.Mock()
+    subscription_path = "project/fakeproj/subscriptions/fakesub"
+    with self.assertRaisesRegexp(ValueError, "number_of_elements"):
+      utils.read_from_pubsub(sub_client, subscription_path)
+
+    with self.assertRaisesRegexp(ValueError, "number_of_elements"):
+      utils.read_from_pubsub(
+          sub_client, subscription_path, with_attributes=True)
+
+  def _assert_ack_ids_equal(self, mock_pubsub, ack_ids):
+    actual_ack_ids = [
+        ack_id for args_list in mock_pubsub.acknowledge.call_args_list
+        for ack_id in args_list[0][1]
+    ]
+    self.assertEqual(actual_ack_ids, ack_ids)
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
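
Editor's note: the commit message describes these helpers as useful for tests and for
interactive work with streaming pipelines. Below is a minimal usage sketch (not part
of the commit). It assumes a pre-2.0 google-cloud-pubsub client, matching the
positional pull/acknowledge calls in the diff above, and the project, topic, and
subscription names ("my-project", "my-topic", "my-sub") are hypothetical placeholders.

# A minimal usage sketch, assuming google-cloud-pubsub < 2.0 is installed and
# that "my-project", "my-topic", and "my-sub" name existing (hypothetical)
# resources.
from google.cloud import pubsub

from apache_beam.io.gcp.pubsub import PubsubMessage
from apache_beam.io.gcp.tests import utils

pub_client = pubsub.PublisherClient()
topic_path = pub_client.topic_path('my-project', 'my-topic')
sub_client = pubsub.SubscriberClient()
subscription_path = sub_client.subscription_path('my-project', 'my-sub')

# Publish two messages with attributes, blocking on each publish future.
utils.write_to_pubsub(
    pub_client,
    topic_path,
    [PubsubMessage(b'hello', {'k': 'v'}),
     PubsubMessage(b'world', {'k': 'v'})],
    with_attributes=True)

# Pull (and ack) until two messages arrive or 60 seconds elapse.
messages = utils.read_from_pubsub(
    sub_client,
    subscription_path,
    with_attributes=True,
    number_of_elements=2,
    timeout=60)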