This is an automated email from the ASF dual-hosted git repository.

altay pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new f792e2e  Add helper functions for reading and writing to PubSub directly from Python (#9212)
f792e2e is described below

commit f792e2e46925ace3e0221ff6bf17fdede3383fbd
Author: Alexey Strokach <strok...@google.com>
AuthorDate: Wed Aug 7 17:16:18 2019 -0700

    Add helper functions for reading and writing to PubSub directly from Python (#9212)
    
    * Add helper functions for reading and writing to PubSub directly from Python
    
    These functions are helpful when writing tests and when working with streaming pipelines interactively (e.g. inside a Jupyter notebook).
    
    Notes:
    
    - Not sure if apache_beam/testing/test_utils.py is a better place for the helper functions than apache_beam/io/gcp/tests/utils.py?
    - google.cloud.exceptions seems to have moved to google.api_core.exceptions. Currently, google.cloud.exceptions re-imports some, but not all, of the exceptions defined in google.api_core.exceptions.
---
 sdks/python/apache_beam/io/gcp/tests/utils.py      |  57 +++++-
 sdks/python/apache_beam/io/gcp/tests/utils_test.py | 200 ++++++++++++++++++++-
 2 files changed, 249 insertions(+), 8 deletions(-)

diff --git a/sdks/python/apache_beam/io/gcp/tests/utils.py b/sdks/python/apache_beam/io/gcp/tests/utils.py
index 68d3f43..4ed9af3 100644
--- a/sdks/python/apache_beam/io/gcp/tests/utils.py
+++ b/sdks/python/apache_beam/io/gcp/tests/utils.py
@@ -25,15 +25,16 @@ import random
 import time
 
 from apache_beam.io import filesystems
+from apache_beam.io.gcp.pubsub import PubsubMessage
 from apache_beam.utils import retry
 
 # Protect against environments where bigquery library is not available.
 try:
+  from google.api_core import exceptions as gexc
   from google.cloud import bigquery
-  from google.cloud.exceptions import NotFound
 except ImportError:
+  gexc = None
   bigquery = None
-  NotFound = None
 
 
 class GcpTestIOError(retry.PermanentException):
@@ -98,7 +99,7 @@ def delete_bq_table(project, dataset_id, table_id):
   table_ref = client.dataset(dataset_id).table(table_id)
   try:
     client.delete_table(table_ref)
-  except NotFound:
+  except gexc.NotFound:
     raise GcpTestIOError('BigQuery table does not exist: %s' % table_ref)
 
 
@@ -113,3 +114,53 @@ def delete_directory(directory):
       "gs://mybucket/mydir/", "s3://...", ...)
   """
   filesystems.FileSystems.delete([directory])
+
+
+def write_to_pubsub(pub_client,
+                    topic_path,
+                    messages,
+                    with_attributes=False,
+                    chunk_size=100,
+                    delay_between_chunks=0.1):
+  for start in range(0, len(messages), chunk_size):
+    message_chunk = messages[start:start + chunk_size]
+    if with_attributes:
+      futures = [
+          pub_client.publish(topic_path, message.data, **message.attributes)
+          for message in message_chunk
+      ]
+    else:
+      futures = [
+          pub_client.publish(topic_path, message) for message in message_chunk
+      ]
+    for future in futures:
+      future.result()
+    time.sleep(delay_between_chunks)
+
+
+def read_from_pubsub(sub_client,
+                     subscription_path,
+                     with_attributes=False,
+                     number_of_elements=None,
+                     timeout=None):
+  if number_of_elements is None and timeout is None:
+    raise ValueError("Either number_of_elements or timeout must be specified.")
+  messages = []
+  start_time = time.time()
+
+  while ((number_of_elements is None or len(messages) < number_of_elements) and
+         (timeout is None or (time.time() - start_time) < timeout)):
+    try:
+      response = sub_client.pull(
+          subscription_path, max_messages=1000, retry=None, timeout=10)
+    except (gexc.RetryError, gexc.DeadlineExceeded):
+      continue
+    ack_ids = [msg.ack_id for msg in response.received_messages]
+    sub_client.acknowledge(subscription_path, ack_ids)
+    for msg in response.received_messages:
+      message = PubsubMessage._from_message(msg.message)
+      if with_attributes:
+        messages.append(message)
+      else:
+        messages.append(message.data)
+  return messages
diff --git a/sdks/python/apache_beam/io/gcp/tests/utils_test.py b/sdks/python/apache_beam/io/gcp/tests/utils_test.py
index 8af7497..c9e96d1 100644
--- a/sdks/python/apache_beam/io/gcp/tests/utils_test.py
+++ b/sdks/python/apache_beam/io/gcp/tests/utils_test.py
@@ -24,16 +24,19 @@ import unittest
 
 import mock
 
+from apache_beam.io.gcp.pubsub import PubsubMessage
 from apache_beam.io.gcp.tests import utils
-from apache_beam.testing.test_utils import patch_retry
+from apache_beam.testing import test_utils
 
 # Protect against environments where bigquery library is not available.
 try:
+  from google.api_core import exceptions as gexc
   from google.cloud import bigquery
-  from google.cloud.exceptions import NotFound
+  from google.cloud import pubsub
 except ImportError:
+  gexc = None
   bigquery = None
-  NotFound = None
+  pubsub = None
 
 
 @unittest.skipIf(bigquery is None, 'Bigquery dependencies are not installed.')
@@ -41,7 +44,7 @@ except ImportError:
 class UtilsTest(unittest.TestCase):
 
   def setUp(self):
-    patch_retry(self, utils)
+    test_utils.patch_retry(self, utils)
 
   @mock.patch.object(bigquery, 'Dataset')
   def test_create_bq_dataset(self, mock_dataset, mock_client):
@@ -68,7 +71,7 @@ class UtilsTest(unittest.TestCase):
   def test_delete_table_fails_not_found(self, mock_client):
     mock_client.return_value.dataset.return_value.table.return_value = (
         'table_ref')
-    mock_client.return_value.delete_table.side_effect = NotFound('test')
+    mock_client.return_value.delete_table.side_effect = gexc.NotFound('test')
 
     with self.assertRaisesRegexp(Exception, r'does not exist:.*table_ref'):
       utils.delete_bq_table('unused_project',
@@ -76,6 +79,193 @@ class UtilsTest(unittest.TestCase):
                             'unused_table')
 
 
+@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
+class PubSubUtilTest(unittest.TestCase):
+
+  def test_write_to_pubsub(self):
+    mock_pubsub = mock.Mock()
+    topic_path = "project/fakeproj/topics/faketopic"
+    data = b'data'
+    utils.write_to_pubsub(mock_pubsub, topic_path, [data])
+    mock_pubsub.publish.assert_has_calls(
+        [mock.call(topic_path, data),
+         mock.call().result()])
+
+  def test_write_to_pubsub_with_attributes(self):
+    mock_pubsub = mock.Mock()
+    topic_path = "project/fakeproj/topics/faketopic"
+    data = b'data'
+    attributes = {'key': 'value'}
+    message = PubsubMessage(data, attributes)
+    utils.write_to_pubsub(
+        mock_pubsub, topic_path, [message], with_attributes=True)
+    mock_pubsub.publish.assert_has_calls(
+        [mock.call(topic_path, data, **attributes),
+         mock.call().result()])
+
+  def test_write_to_pubsub_delay(self):
+    number_of_elements = 2
+    chunk_size = 1
+    mock_pubsub = mock.Mock()
+    topic_path = "project/fakeproj/topics/faketopic"
+    data = b'data'
+    with mock.patch('apache_beam.io.gcp.tests.utils.time') as mock_time:
+      utils.write_to_pubsub(
+          mock_pubsub,
+          topic_path, [data] * number_of_elements,
+          chunk_size=chunk_size,
+          delay_between_chunks=123)
+    mock_time.sleep.assert_called_with(123)
+    mock_pubsub.publish.assert_has_calls(
+        [mock.call(topic_path, data),
+         mock.call().result()] * number_of_elements)
+
+  def test_write_to_pubsub_many_chunks(self):
+    number_of_elements = 83
+    chunk_size = 11
+    mock_pubsub = mock.Mock()
+    topic_path = "project/fakeproj/topics/faketopic"
+    data_list = [
+        'data {}'.format(i).encode("utf-8") for i in range(number_of_elements)
+    ]
+    utils.write_to_pubsub(
+        mock_pubsub, topic_path, data_list, chunk_size=chunk_size)
+    call_list = []
+    for start in range(0, number_of_elements, chunk_size):
+      # Publish a batch of messages
+      call_list += [
+          mock.call(topic_path, data)
+          for data in data_list[start:start + chunk_size]
+      ]
+      # Wait for those messages to be received
+      call_list += [
+          mock.call().result() for _ in data_list[start:start + chunk_size]
+      ]
+    mock_pubsub.publish.assert_has_calls(call_list)
+
+  def test_read_from_pubsub(self):
+    mock_pubsub = mock.Mock()
+    subscription_path = "project/fakeproj/subscriptions/fakesub"
+    data = b'data'
+    ack_id = 'ack_id'
+    pull_response = test_utils.create_pull_response(
+        [test_utils.PullResponseMessage(data, ack_id=ack_id)])
+    mock_pubsub.pull.return_value = pull_response
+    output = utils.read_from_pubsub(
+        mock_pubsub, subscription_path, number_of_elements=1)
+    self.assertEqual([data], output)
+    mock_pubsub.acknowledge.assert_called_once_with(subscription_path, [ack_id])
+
+  def test_read_from_pubsub_with_attributes(self):
+    mock_pubsub = mock.Mock()
+    subscription_path = "project/fakeproj/subscriptions/fakesub"
+    data = b'data'
+    ack_id = 'ack_id'
+    attributes = {'key': 'value'}
+    message = PubsubMessage(data, attributes)
+    pull_response = test_utils.create_pull_response(
+        [test_utils.PullResponseMessage(data, attributes, ack_id=ack_id)])
+    mock_pubsub.pull.return_value = pull_response
+    output = utils.read_from_pubsub(
+        mock_pubsub,
+        subscription_path,
+        with_attributes=True,
+        number_of_elements=1)
+    self.assertEqual([message], output)
+    mock_pubsub.acknowledge.assert_called_once_with(subscription_path, [ack_id])
+
+  def test_read_from_pubsub_flaky(self):
+    number_of_elements = 10
+    mock_pubsub = mock.Mock()
+    subscription_path = "project/fakeproj/subscriptions/fakesub"
+    data = b'data'
+    ack_id = 'ack_id'
+    pull_response = test_utils.create_pull_response(
+        [test_utils.PullResponseMessage(data, ack_id=ack_id)])
+
+    class FlakyPullResponse(object):
+
+      def __init__(self, pull_response):
+        self.pull_response = pull_response
+        self._state = -1
+
+      def __call__(self, *args, **kwargs):
+        self._state += 1
+        if self._state % 3 == 0:
+          raise gexc.RetryError("", "")
+        if self._state % 3 == 1:
+          raise gexc.DeadlineExceeded("")
+        if self._state % 3 == 2:
+          return self.pull_response
+
+    mock_pubsub.pull.side_effect = FlakyPullResponse(pull_response)
+    output = utils.read_from_pubsub(
+        mock_pubsub, subscription_path, number_of_elements=number_of_elements)
+    self.assertEqual([data] * number_of_elements, output)
+    self._assert_ack_ids_equal(mock_pubsub, [ack_id] * number_of_elements)
+
+  def test_read_from_pubsub_many(self):
+    response_size = 33
+    number_of_elements = 100
+    mock_pubsub = mock.Mock()
+    subscription_path = "project/fakeproj/subscriptions/fakesub"
+    data_list = [
+        'data {}'.format(i).encode("utf-8") for i in range(number_of_elements)
+    ]
+    attributes_list = [{
+        'key': 'value {}'.format(i)
+    } for i in range(number_of_elements)]
+    ack_ids = ['ack_id_{}'.format(i) for i in range(number_of_elements)]
+    messages = [
+        PubsubMessage(data, attributes)
+        for data, attributes in zip(data_list, attributes_list)
+    ]
+    response_messages = [
+        test_utils.PullResponseMessage(data, attributes, ack_id=ack_id)
+        for data, attributes, ack_id in zip(data_list, attributes_list, ack_ids)
+    ]
+
+    class SequentialPullResponse(object):
+
+      def __init__(self, response_messages, response_size):
+        self.response_messages = response_messages
+        self.response_size = response_size
+        self._index = 0
+
+      def __call__(self, *args, **kwargs):
+        start = self._index
+        self._index += self.response_size
+        response = test_utils.create_pull_response(
+            self.response_messages[start:start + self.response_size])
+        return response
+
+    mock_pubsub.pull.side_effect = SequentialPullResponse(
+        response_messages, response_size)
+    output = utils.read_from_pubsub(
+        mock_pubsub,
+        subscription_path,
+        with_attributes=True,
+        number_of_elements=number_of_elements)
+    self.assertEqual(messages, output)
+    self._assert_ack_ids_equal(mock_pubsub, ack_ids)
+
+  def test_read_from_pubsub_invalid_arg(self):
+    sub_client = mock.Mock()
+    subscription_path = "project/fakeproj/subscriptions/fakesub"
+    with self.assertRaisesRegexp(ValueError, "number_of_elements"):
+      utils.read_from_pubsub(sub_client, subscription_path)
+    with self.assertRaisesRegexp(ValueError, "number_of_elements"):
+      utils.read_from_pubsub(
+          sub_client, subscription_path, with_attributes=True)
+
+  def _assert_ack_ids_equal(self, mock_pubsub, ack_ids):
+    actual_ack_ids = [
+        ack_id for args_list in mock_pubsub.acknowledge.call_args_list
+        for ack_id in args_list[0][1]
+    ]
+    self.assertEqual(actual_ack_ids, ack_ids)
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()

Reply via email to