This is an automated email from the ASF dual-hosted git repository.

ahmedabualsaud pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c9e036e40e4 [Python BQ] Retry get_table for quota errors (#28820)
c9e036e40e4 is described below

commit c9e036e40e4f1d21f33ce6829fdf919c934da7de
Author: Ahmed Abualsaud <65791736+ahmedab...@users.noreply.github.com>
AuthorDate: Mon Jan 8 15:52:31 2024 -0500

    [Python BQ] Retry get_table for quota errors (#28820)
    
    * retry get_table on quota errors
    
    * add tests
    
    * only retry on transient reasons
---
 sdks/python/apache_beam/io/gcp/bigquery_test.py  | 201 ++++++++++++++++++++++-
 sdks/python/apache_beam/io/gcp/bigquery_tools.py |   2 +-
 sdks/python/apache_beam/utils/retry.py           |  33 +++-
 3 files changed, 229 insertions(+), 7 deletions(-)

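For readers skimming the change: get_table()'s retry decorator now uses a
filter that retries HTTP 429 responses and only those HTTP 403 responses whose
error reason is transient. A minimal standalone sketch of that decision logic
follows (the helper name is illustrative, not Beam API; the payload shape is
the standard BigQuery REST error format):

    import json

    _RETRYABLE_REASONS = ["rateLimitExceeded", "internalError", "backendError"]

    def is_transient_403(content):
      # content is the 403 response body: either an already-parsed dict or a
      # JSON string shaped like
      # {"error": {"errors": [{"reason": ..., "message": ...}]}}.
      try:
        if not isinstance(content, dict):
          content = json.loads(content)
        return content["error"]["errors"][0]["reason"] in _RETRYABLE_REASONS
      except (KeyError, IndexError, TypeError, ValueError):
        # Unparseable or unexpectedly shaped payload: default to not retrying.
        return False
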
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py
index 035edffc03f..e53204a5ebc 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py
@@ -43,6 +43,7 @@ import apache_beam as beam
 from apache_beam.internal import pickler
 from apache_beam.internal.gcp.json_value import to_json_value
 from apache_beam.io.filebasedsink_test import _TestCaseWithTempDirCleanUp
+from apache_beam.io.filesystems import FileSystems
 from apache_beam.io.gcp import bigquery as beam_bq
 from apache_beam.io.gcp import bigquery_tools
 from apache_beam.io.gcp.bigquery import ReadFromBigQuery
@@ -82,11 +83,13 @@ from apache_beam.transforms.display_test import DisplayDataItemMatcher
 try:
   from apache_beam.io.gcp.internal.clients.bigquery import bigquery_v2_client
   from apitools.base.py.exceptions import HttpError
+  from apitools.base.py.exceptions import HttpForbiddenError
   from google.cloud import bigquery as gcp_bigquery
   from google.api_core import exceptions
 except ImportError:
   gcp_bigquery = None
   HttpError = None
+  HttpForbiddenError = None
   exceptions = None
 # pylint: enable=wrong-import-order, wrong-import-position
 
@@ -323,7 +326,9 @@ class TestJsonToDictCoder(unittest.TestCase):
     self.assertEqual(expected_row, actual)
 
 
-@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
+@unittest.skipIf(
+    HttpError is None or HttpForbiddenError is None,
+    'GCP dependencies are not installed')
 class TestReadFromBigQuery(unittest.TestCase):
   @classmethod
   def setUpClass(cls):
@@ -454,6 +459,200 @@ class TestReadFromBigQuery(unittest.TestCase):
     mock_insert.assert_called()
     self.assertIn(error_message, exc.exception.args[0])
 
+  @parameterized.expand([
+      # first attempt returns an HTTP 500 blank error and retries
+      # second attempt returns an HTTP 408 blank error and retries
+      # third attempt passes
+      param(
+          responses=[
+              HttpForbiddenError(
+                  response={'status': 500}, content="something", url="")
+              if HttpForbiddenError else None,
+              HttpForbiddenError(
+                  response={'status': 408}, content="blank", url="")
+              if HttpForbiddenError else None
+          ],
+          expected_retries=2),
+      # first attempt returns a 403 rateLimitExceeded error
+      # second attempt returns a 429 blank error
+      # third attempt returns an HTTP 403 rateLimitExceeded error
+      # fourth attempt passes
+      param(
+          responses=[
+              exceptions.Forbidden(
+                  "some message",
+                  errors=({
+                      "message": "transient", "reason": "rateLimitExceeded"
+                  }, )) if exceptions else None,
+              exceptions.ResourceExhausted("some message")
+              if exceptions else None,
+              HttpForbiddenError(
+                  response={'status': 403},
+                  content={
+                      "error": {
+                          "errors": [{
+                              "message": "transient",
+                              "reason": "rateLimitExceeded"
+                          }]
+                      }
+                  },
+                  url="") if HttpForbiddenError else None,
+          ],
+          expected_retries=3),
+  ])
+  def test_get_table_transient_exception(self, responses, expected_retries):
+    class DummyTable:
+      class DummySchema:
+        fields = []
+
+      numBytes = 5
+      schema = DummySchema()
+
+    with mock.patch('time.sleep'), \
+            mock.patch.object(bigquery_v2_client.BigqueryV2.TablesService,
+                              'Get') as mock_get_table, \
+            mock.patch.object(BigQueryWrapper,
+                              'wait_for_bq_job'), \
+            mock.patch.object(BigQueryWrapper,
+                              'perform_extract_job'), \
+            mock.patch.object(FileSystems,
+                              'match'), \
+            mock.patch.object(FileSystems,
+                              'delete'), \
+            beam.Pipeline() as p:
+      call_counter = 0
+
+      def store_callback(unused_request):
+        nonlocal call_counter
+        if call_counter < len(responses):
+          exception = responses[call_counter]
+          call_counter += 1
+          raise exception
+        else:
+          call_counter += 1
+          return DummyTable()
+
+      mock_get_table.side_effect = store_callback
+      _ = p | beam.io.ReadFromBigQuery(
+          table="project.dataset.table", gcs_location="gs://some_bucket")
+
+    # ReadFromBigQuery export mode calls get_table() twice: once to get
+    # metadata (numBytes) and once to retrieve the table's schema.
+    # Any additional calls are retries.
+    self.assertEqual(expected_retries, mock_get_table.call_count - 2)
+
+  @parameterized.expand([
+      # first attempt returns an HTTP 429 with a transient reason and retries
+      # second attempt returns an HTTP 403 with a non-transient reason and fails
+      param(
+          responses=[
+              HttpForbiddenError(
+                  response={'status': 429},
+                  content={
+                      "error": {
+                          "errors": [{
+                              "message": "transient",
+                              "reason": "rateLimitExceeded"
+                          }]
+                      }
+                  },
+                  url="") if HttpForbiddenError else None,
+              HttpForbiddenError(
+                  response={'status': 403},
+                  content={
+                      "error": {
+                          "errors": [{
+                              "message": "transient", "reason": "accessDenied"
+                          }]
+                      }
+                  },
+                  url="") if HttpForbiddenError else None
+          ],
+          expected_retries=1),
+      # first attempt returns a transient 403 error and retries
+      # second attempt returns a 403 error with bad contents and fails
+      param(
+          responses=[
+              HttpForbiddenError(
+                  response={'status': 403},
+                  content={
+                      "error": {
+                          "errors": [{
+                              "message": "transient",
+                              "reason": "rateLimitExceeded"
+                          }]
+                      }
+                  },
+                  url="") if HttpForbiddenError else None,
+              HttpError(
+                  response={'status': 403}, content="bad contents", url="")
+              if HttpError else None
+          ],
+          expected_retries=1),
+      # first attempt returns a transient 403 error and retries
+      # second attempt returns a 429 error and retries
+      # third attempt returns a 403 with a non-transient reason and fails
+      param(
+          responses=[
+              exceptions.Forbidden(
+                  "some error",
+                  errors=({
+                      "message": "transient", "reason": "rateLimitExceeded"
+                  }, )) if exceptions else None,
+              exceptions.ResourceExhausted("some transient error")
+              if exceptions else None,
+              exceptions.Forbidden(
+                  "some error",
+                  errors=({
+                      "message": "transient", "reason": "accessDenied"
+                  }, )) if exceptions else None,
+          ],
+          expected_retries=2),
+  ])
+  def test_get_table_non_transient_exception(self, responses, expected_retries):
+    class DummyTable:
+      class DummySchema:
+        fields = []
+
+      numBytes = 5
+      schema = DummySchema()
+
+    with mock.patch('time.sleep'), \
+            mock.patch.object(bigquery_v2_client.BigqueryV2.TablesService,
+                              'Get') as mock_get_table, \
+            mock.patch.object(BigQueryWrapper,
+                              'wait_for_bq_job'), \
+            mock.patch.object(BigQueryWrapper,
+                              'perform_extract_job'), \
+            mock.patch.object(FileSystems,
+                              'match'), \
+            mock.patch.object(FileSystems,
+                              'delete'), \
+            self.assertRaises(Exception), \
+            beam.Pipeline() as p:
+      call_counter = 0
+
+      def store_callback(unused_request):
+        nonlocal call_counter
+        if call_counter < len(responses):
+          exception = responses[call_counter]
+          call_counter += 1
+          raise exception
+        else:
+          call_counter += 1
+          return DummyTable()
+
+      mock_get_table.side_effect = store_callback
+      _ = p | beam.io.ReadFromBigQuery(
+          table="project.dataset.table", gcs_location="gs://some_bucket")
+
+    # ReadFromBigQuery export mode calls get_table() twice: once to get
+    # metadata (numBytes) and once to retrieve the table's schema. However,
+    # the second call is never reached because this test always fails first.
+    # After the first call, any additional calls are retries.
+    self.assertEqual(expected_retries, mock_get_table.call_count - 1)
+
   @parameterized.expand([
       param(
           exception_type=exceptions.BadRequest if exceptions else None,
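
As the comments in the tests note, export-mode ReadFromBigQuery calls
get_table() twice on success, so the assertions subtract that baseline from
mock_get_table.call_count. For reference, a sketch of the two 403 payload
shapes the tests exercise (field values are illustrative):

    # Transient 403: the reason is retryable, so the filter returns True.
    transient_403 = {
        "error": {
            "errors": [{
                "message": "Exceeded rate limits",  # free-form text
                "reason": "rateLimitExceeded",
            }]
        }
    }

    # Non-transient 403: accessDenied is permanent, so the call fails fast.
    non_transient_403 = {
        "error": {
            "errors": [{
                "message": "Permission denied",
                "reason": "accessDenied",
            }]
        }
    }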
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py
index 41ef57b5b27..fab5a861159 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py
@@ -751,7 +751,7 @@ class BigQueryWrapper(object):
 
   @retry.with_exponential_backoff(
       num_retries=MAX_RETRIES,
-      retry_filter=retry.retry_on_server_errors_and_timeout_filter)
+      retry_filter=retry.retry_on_server_errors_timeout_or_quota_issues_filter)
   def get_table(self, project_id, dataset_id, table_id):
     """Lookup a table's metadata object.
 
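
The decorator swap above is the entire production change: get_table() keeps
its exponential backoff, and only the filter that decides which exceptions are
retryable changes. A hedged sketch of how the decorator composes (MAX_RETRIES
here is illustrative; the real constant lives in bigquery_tools.py):

    from apache_beam.utils import retry

    MAX_RETRIES = 3  # illustrative value

    @retry.with_exponential_backoff(
        num_retries=MAX_RETRIES,
        retry_filter=retry.retry_on_server_errors_timeout_or_quota_issues_filter)
    def get_table(project_id, dataset_id, table_id):
      # Each exception raised here is passed to retry_filter; only when it
      # returns True does the decorator back off and call again, up to
      # num_retries times.
      ...
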
diff --git a/sdks/python/apache_beam/utils/retry.py b/sdks/python/apache_beam/utils/retry.py
index 6eed2900b9a..485fc9d627e 100644
--- a/sdks/python/apache_beam/utils/retry.py
+++ b/sdks/python/apache_beam/utils/retry.py
@@ -28,6 +28,7 @@ needed right now use a @retry.no_retries decorator.
 # pytype: skip-file
 
 import functools
+import json
 import logging
 import random
 import sys
@@ -57,6 +58,7 @@ else:
 # pylint: enable=wrong-import-order, wrong-import-position
 
 _LOGGER = logging.getLogger(__name__)
+_RETRYABLE_REASONS = ["rateLimitExceeded", "internalError", "backendError"]
 
 
 class PermanentException(Exception):
@@ -166,17 +168,38 @@ def retry_on_server_errors_and_timeout_filter(exception):
 
 
 def retry_on_server_errors_timeout_or_quota_issues_filter(exception):
-  """Retry on server, timeout and 403 errors.
+  """Retry on server, timeout, 429, and some 403 errors.
 
-  403 errors can be accessDenied, billingNotEnabled, and also quotaExceeded,
-  rateLimitExceeded."""
+  403 errors from BigQuery include both non-transient errors (accessDenied,
+  billingNotEnabled) and transient errors (rateLimitExceeded). Only
+  transient errors are retried."""
   if HttpError is not None and isinstance(exception, HttpError):
-    if exception.status_code == 403:
+    if exception.status_code == 429:
       return True
+    if exception.status_code == 403:
+      try:
+        # attempt to extract the reason and check if it's retryable
+        content = exception.content
+        if not isinstance(content, dict):
+          content = json.loads(exception.content)
+        return content["error"]["errors"][0]["reason"] in _RETRYABLE_REASONS
+      except (KeyError, IndexError, TypeError, ValueError) as e:
+        _LOGGER.warning(
+            "Could not determine if HttpError is transient. "
+            "Will not retry: %s",
+            e)
+      return False
   if GoogleAPICallError is not None and isinstance(exception,
                                                    GoogleAPICallError):
-    if exception.code == 403:
+    if exception.code == 429:
       return True
+    if exception.code == 403:
+      if not hasattr(exception, "errors") or len(exception.errors) == 0:
+        # default to not retrying
+        return False
+
+      reason = exception.errors[0]["reason"]
+      return reason in _RETRYABLE_REASONS
   if S3ClientError is not None and isinstance(exception, S3ClientError):
     if exception.code == 403:
       return True

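To see the new filter end to end, here is a small sanity-check sketch using
the same exception constructors as the tests above (assumes apitools and
google-api-core are installed):

    from apitools.base.py.exceptions import HttpForbiddenError
    from google.api_core import exceptions
    from apache_beam.utils import retry

    quota_filter = retry.retry_on_server_errors_timeout_or_quota_issues_filter

    # HTTP 403 with a retryable reason: retried.
    assert quota_filter(
        HttpForbiddenError(
            response={'status': 403},
            content={'error': {'errors': [{'reason': 'rateLimitExceeded',
                                           'message': 'transient'}]}},
            url=''))

    # HTTP 429 (ResourceExhausted): always retried.
    assert quota_filter(exceptions.ResourceExhausted('quota'))

    # HTTP 403 with a permanent reason (accessDenied): not retried.
    assert not quota_filter(
        exceptions.Forbidden(
            'denied', errors=({'reason': 'accessDenied', 'message': 'x'},)))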