This is an automated email from the ASF dual-hosted git repository.

jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 35969b34a37 Introduce more abstract handling of throttling signaling 
(#35984)
35969b34a37 is described below

commit 35969b34a3734c495d4be9e0e99e399c7c9ed231
Author: Jack McCluskey <[email protected]>
AuthorDate: Tue Sep 2 16:43:18 2025 -0400

    Introduce more abstract handling of throttling signaling (#35984)
    
    * Introduce more abstract handling of throttling signaling
    
    * linting
    
    * gemini suggestions
    
    * review suggestions
    
    * Update sdks/python/apache_beam/io/components/adaptive_throttler.py
    
    Co-authored-by: Danny McCormick <[email protected]>
    
    * linting
    
    ---------
    
    Co-authored-by: Danny McCormick <[email protected]>
---
 CHANGES.md                                         |  1 +
 .../io/components/adaptive_throttler.py            | 92 ++++++++++++++++++++++
 sdks/python/apache_beam/ml/inference/base.py       | 43 +++++-----
 3 files changed, 112 insertions(+), 24 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 74b3b835dbc..73bc6850b57 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -82,6 +82,7 @@
   Beam now supports data enrichment capabilities using SQL databases, with 
built-in support for:
   - Managed PostgreSQL, MySQL, and Microsoft SQL Server instances on CloudSQL
   - Unmanaged SQL database instances not hosted on CloudSQL (e.g., self-hosted 
or on-premises databases)
+* [Python] Added the `ReactiveThrottler` and `ThrottlingSignaler` classes to 
streamline throttling behavior in DoFns and expose throttling mechanisms to 
users ([#35984](https://github.com/apache/beam/pull/35984)).
 * Added a pipeline option to specify the processing timeout for a single 
element by any PTransform (Java/Python/Go) 
([#35174](https://github.com/apache/beam/issues/35174)).
   - When specified, the SDK harness automatically restarts if an element takes 
too long to process. Beam runner may then retry processing of the same work 
item.
   - Use the `--element_processing_timeout_minutes` option to reduce the chance 
of having stalled pipelines due to unexpected cases of slow processing, where 
slowness might not happen again if processing of the same element is retried.
diff --git a/sdks/python/apache_beam/io/components/adaptive_throttler.py 
b/sdks/python/apache_beam/io/components/adaptive_throttler.py
index f6290636073..3c22891ee8a 100644
--- a/sdks/python/apache_beam/io/components/adaptive_throttler.py
+++ b/sdks/python/apache_beam/io/components/adaptive_throttler.py
@@ -21,9 +21,32 @@
 
 # pytype: skip-file
 
+import logging
 import random
+import time
 
 from apache_beam.io.components import util
+from apache_beam.metrics.metric import Metrics
+
+_SECONDS_TO_MILLISECONDS = 1_000
+
+
+class ThrottlingSignaler(object):
+  """A class that handles signaling throttling of remote requests to the
+  SDK harness.
+  """
+  def __init__(self, namespace: str = ""):
+    self.throttling_metric = Metrics.counter(
+        namespace, "cumulativeThrottlingSeconds")
+
+  def signal_throttled(self, seconds: int):
+    """Signals to the runner that requests have been throttled for some amount
+    of time.
+
+    Args:
+      seconds: int, duration of throttling in seconds.
+    """
+    self.throttling_metric.inc(seconds)
 
 
 class AdaptiveThrottler(object):
@@ -94,3 +117,72 @@ class AdaptiveThrottler(object):
       now: int, time in ms since the epoch
     """
     self._successful_requests.add(now, 1)
+
+
+class ReactiveThrottler(AdaptiveThrottler):
+  """An AdaptiveThrottler that also logs throttling events and signals them
+  to the SDK harness using the provided namespace.
+
+  To use it, instantiate one ReactiveThrottler per PTransform. Before each
+  remote call to a service, call throttle(), which blocks while the
+  throttler decides to pre-emptively throttle the request. That decision is
+  based on the failure rate of preceding calls: higher failure rates lead to
+  longer periods of throttling, giving the service time to recover. After
+  throttle() returns, record the request time in milliseconds since the
+  epoch and execute the request. On success, pass that timestamp to
+  successful_request() to report the success. The flow looks like this:
+
+  def remote_call():
+    throttler.throttle()
+
+    try:
+      timestamp = time.time() * 1000  # milliseconds since the epoch
+      result = make_request()
+      throttler.successful_request(timestamp)
+      return result
+    except Exception:
+      # do any error handling you want to do
+      raise
+  """
+  def __init__(
+      self,
+      window_ms: int,
+      bucket_ms: int,
+      overload_ratio: float,
+      namespace: str = '',
+      throttle_delay_secs: int = 5):
+    """Initializes the ReactiveThrottler.
+
+    Args:
+      window_ms: int, length of history to consider, in ms, to set
+        throttling.
+      bucket_ms: int, granularity of time buckets that we store data in, in
+        ms.
+      overload_ratio: float, the target ratio between requests sent and
+        successful requests. This is "K" in the formula in
+        https://landing.google.com/sre/book/chapters/handling-overload.html.
+      namespace: str, the namespace to use for logging and for signaling
+        that throttling is occurring.
+      throttle_delay_secs: int, the number of seconds to sleep each time a
+        request is pre-emptively throttled.
+    """
+    self.throttling_signaler = ThrottlingSignaler(namespace=namespace)
+    self.logger = logging.getLogger(namespace)
+    self.throttle_delay_secs = throttle_delay_secs
+    super().__init__(
+        window_ms=window_ms, bucket_ms=bucket_ms, 
overload_ratio=overload_ratio)
+
+  def throttle(self):
+    """Blocks until the underlying AdaptiveThrottler stops pre-emptively
+    throttling requests.
+
+    While throttled, sleeps throttle_delay_secs at a time, logging each delay
+    and reporting it to the SDK harness as throttled time. Call this
+    immediately before every call to a remote service.
+    """
+    while self.throttle_request(time.time() * _SECONDS_TO_MILLISECONDS):
+      self.logger.info(
+          "Delaying request for %d seconds due to previous failures",
+          self.throttle_delay_secs)
+      time.sleep(self.throttle_delay_secs)
+      self.throttling_signaler.signal_throttled(self.throttle_delay_secs)
diff --git a/sdks/python/apache_beam/ml/inference/base.py 
b/sdks/python/apache_beam/ml/inference/base.py
index 4881fb74ef7..2e1c4963f11 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -55,8 +55,7 @@ from typing import TypeVar
 from typing import Union
 
 import apache_beam as beam
-from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
-from apache_beam.metrics.metric import Metrics
+from apache_beam.io.components.adaptive_throttler import ReactiveThrottler
 from apache_beam.utils import multi_process_shared
 from apache_beam.utils import retry
 from apache_beam.utils import shared
@@ -354,14 +353,16 @@ class RemoteModelHandler(ABC, ModelHandler[ExampleT, 
PredictionT, ModelT]):
       window_ms: int = 1 * _MILLISECOND_TO_SECOND,
       bucket_ms: int = 1 * _MILLISECOND_TO_SECOND,
       overload_ratio: float = 2):
-    """Initializes metrics tracking + an AdaptiveThrottler class for enabling
-    client-side throttling for remote calls to an inference service.
+    """Initializes a ReactiveThrottler to enable client-side throttling of
+    remote calls to an inference service. Calls to the service are also
+    wrapped with retry logic.
+
     See https://s.apache.org/beam-client-side-throttling for more details
     on the configuration of the throttling and retry
     mechanics.
 
     Args:
-      namespace: the metrics and logging namespace 
+      namespace: the metrics and logging namespace
       num_retries: the maximum number of times to retry a request on retriable
         errors before failing
       throttle_delay_secs: the amount of time to throttle when the client-side
@@ -372,19 +373,18 @@ class RemoteModelHandler(ABC, ModelHandler[ExampleT, 
PredictionT, ModelT]):
       window_ms: length of history to consider, in ms, to set throttling.
       bucket_ms: granularity of time buckets that we store data in, in ms.
       overload_ratio: the target ratio between requests sent and successful
-        requests. This is "K" in the formula in 
+        requests. This is "K" in the formula in
         https://landing.google.com/sre/book/chapters/handling-overload.html.
     """
-    # Configure AdaptiveThrottler and throttling metrics for client-side
-    # throttling behavior.
-    self.throttled_secs = Metrics.counter(
-        namespace, "cumulativeThrottlingSeconds")
-    self.throttler = AdaptiveThrottler(
-        window_ms=window_ms, bucket_ms=bucket_ms, 
overload_ratio=overload_ratio)
+    # Configure ReactiveThrottler for client-side throttling behavior.
+    self.throttler = ReactiveThrottler(
+        window_ms=window_ms,
+        bucket_ms=bucket_ms,
+        overload_ratio=overload_ratio,
+        namespace=namespace,
+        throttle_delay_secs=throttle_delay_secs)
     self.logger = logging.getLogger(namespace)
-
     self.num_retries = num_retries
-    self.throttle_delay_secs = throttle_delay_secs
     self.retry_filter = retry_filter
 
   def __init_subclass__(cls):
@@ -434,12 +434,7 @@ class RemoteModelHandler(ABC, ModelHandler[ExampleT, 
PredictionT, ModelT]):
     Returns:
       An Iterable of Predictions.
     """
-    while self.throttler.throttle_request(time.time() * 
_MILLISECOND_TO_SECOND):
-      self.logger.info(
-          "Delaying request for %d seconds due to previous failures",
-          self.throttle_delay_secs)
-      time.sleep(self.throttle_delay_secs)
-      self.throttled_secs.inc(self.throttle_delay_secs)
+    self.throttler.throttle()
 
     try:
       req_time = time.time()
@@ -1642,7 +1637,7 @@ class _ModelRoutingStrategy():
 
 class _ModelStatus():
   """A class holding any metadata about a model required by RunInference.
-  
+
     Currently, this only includes whether or not the model is valid. Uses the
     model tag to map models to metadata.
   """
@@ -1656,7 +1651,7 @@ class _ModelStatus():
 
   def try_mark_current_model_invalid(self, min_model_life_seconds):
     """Mark the current model invalid.
-    
+
       Since we don't have sufficient information to say which model is being
       marked invalid, but there may be multiple active models, we will mark all
       models currently in use as inactive so that they all get reloaded. To
@@ -1678,7 +1673,7 @@ class _ModelStatus():
 
   def get_valid_tag(self, tag: str) -> str:
     """Takes in a proposed valid tag and returns a valid one.
-    
+
       Will always return a valid tag. If the passed in tag is valid, this
       function will simply return it, otherwise it will deterministically
       generate a new tag to use instead. The new tag will be the original tag
@@ -1747,7 +1742,7 @@ def load_model_status(
 
 class _SharedModelWrapper():
   """A router class to map incoming calls to the correct model.
-  
+
     This allows us to round robin calls to models sitting in different
     processes so that we can more efficiently use resources (e.g. GPUs).
   """

Reply via email to