This is an automated email from the ASF dual-hosted git repository.
jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 35969b34a37 Introduce more abstract handling of throttling signaling
(#35984)
35969b34a37 is described below
commit 35969b34a3734c495d4be9e0e99e399c7c9ed231
Author: Jack McCluskey <[email protected]>
AuthorDate: Tue Sep 2 16:43:18 2025 -0400
Introduce more abstract handling of throttling signaling (#35984)
* Introduce more abstract handling of throttling signaling
* linting
* gemini suggestions
* review suggestions
* Update sdks/python/apache_beam/io/components/adaptive_throttler.py
Co-authored-by: Danny McCormick <[email protected]>
* linting
---------
Co-authored-by: Danny McCormick <[email protected]>
---
CHANGES.md | 1 +
.../io/components/adaptive_throttler.py | 92 ++++++++++++++++++++++
sdks/python/apache_beam/ml/inference/base.py | 43 +++++-----
3 files changed, 112 insertions(+), 24 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 74b3b835dbc..73bc6850b57 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -82,6 +82,7 @@
Beam now supports data enrichment capabilities using SQL databases, with
built-in support for:
- Managed PostgreSQL, MySQL, and Microsoft SQL Server instances on CloudSQL
- Unmanaged SQL database instances not hosted on CloudSQL (e.g., self-hosted
or on-premises databases)
+* [Python] Added the `ReactiveThrottler` and `ThrottlingSignaler` classes to
streamline throttling behavior in DoFns and expose throttling mechanisms for users
([#35984](https://github.com/apache/beam/pull/35984))
* Added a pipeline option to specify the processing timeout for a single
element by any PTransform (Java/Python/Go)
([#35174](https://github.com/apache/beam/issues/35174)).
- When specified, the SDK harness automatically restarts if an element takes
too long to process. Beam runner may then retry processing of the same work
item.
- Use the `--element_processing_timeout_minutes` option to reduce the chance
of having stalled pipelines due to unexpected cases of slow processing, where
slowness might not happen again if processing of the same element is retried.
diff --git a/sdks/python/apache_beam/io/components/adaptive_throttler.py
b/sdks/python/apache_beam/io/components/adaptive_throttler.py
index f6290636073..3c22891ee8a 100644
--- a/sdks/python/apache_beam/io/components/adaptive_throttler.py
+++ b/sdks/python/apache_beam/io/components/adaptive_throttler.py
@@ -21,9 +21,32 @@
# pytype: skip-file
+import logging
import random
+import time
from apache_beam.io.components import util
+from apache_beam.metrics.metric import Metrics
+
+_SECONDS_TO_MILLISECONDS = 1_000
+
+
+class ThrottlingSignaler(object):
+ """A class that handles signaling throttling of remote requests to the
+ SDK harness.
+ """
+ def __init__(self, namespace: str = ""):
+ self.throttling_metric = Metrics.counter(
+ namespace, "cumulativeThrottlingSeconds")
+
+ def signal_throttled(self, seconds: int):
+ """Signals to the runner that requests have been throttled for some amount
+ of time.
+
+ Args:
+ seconds: int, duration of throttling in seconds.
+ """
+ self.throttling_metric.inc(seconds)
class AdaptiveThrottler(object):
@@ -94,3 +117,72 @@ class AdaptiveThrottler(object):
now: int, time in ms since the epoch
"""
self._successful_requests.add(now, 1)
+
+
+class ReactiveThrottler(AdaptiveThrottler):
+ """ A wrapper around the AdaptiveThrottler that also handles logging and
+ signaling throttling to the SDK harness using the provided namespace.
+
+ For usage, instantiate one instance of a ReactiveThrottler class for a
+ PTransform. When making remote calls to a service, preface that call with
+ the throttle() method to potentially preemptively throttle the request.
+ This will throttle future calls based on the failure rate of preceding calls,
+ with higher failure rates leading to longer periods of throttling to allow
+ system recovery. Capture the timestamp of the attempted request, then execute
+ the request code. On a success, call successful_request(timestamp) to report
+ the success to the throttler. This flow looks like the following:
+
+ def remote_call():
+ throttler.throttle()
+
+ try:
+ timestamp = time.time()
+ result = make_request()
+ throttler.successful_request(timestamp)
+ return result
+ except Exception as e:
+ # do any error handling you want to do
+ raise
+ """
+ def __init__(
+ self,
+ window_ms: int,
+ bucket_ms: int,
+ overload_ratio: float,
+ namespace: str = '',
+ throttle_delay_secs: int = 5):
+ """Initializes the ReactiveThrottler.
+
+ Args:
+ window_ms: int, length of history to consider, in ms, to set
+ throttling.
+ bucket_ms: int, granularity of time buckets that we store data in, in
+ ms.
+ overload_ratio: float, the target ratio between requests sent and
+ successful requests. This is "K" in the formula in
+ https://landing.google.com/sre/book/chapters/handling-overload.html.
+ namespace: str, the namespace to use for logging and signaling
+ throttling is occurring
+ throttle_delay_secs: int, the amount of time in seconds to wait
+ after preemptively throttled requests
+ """
+ self.throttling_signaler = ThrottlingSignaler(namespace=namespace)
+ self.logger = logging.getLogger(namespace)
+ self.throttle_delay_secs = throttle_delay_secs
+ super().__init__(
+ window_ms=window_ms, bucket_ms=bucket_ms,
overload_ratio=overload_ratio)
+
+ def throttle(self):
+ """ Stops request code from advancing while the underlying
+ AdaptiveThrottler is signaling to preemptively throttle the request.
+ Automatically handles logging the throttling and signaling to the SDK
+ harness that the request is being throttled. This should be called in any
+ context where a remote service is being contacted, prior to the call being
+ performed.
+ """
+ while self.throttle_request(time.time() * _SECONDS_TO_MILLISECONDS):
+ self.logger.info(
+ "Delaying request for %d seconds due to previous failures",
+ self.throttle_delay_secs)
+ time.sleep(self.throttle_delay_secs)
+ self.throttling_signaler.signal_throttled(self.throttle_delay_secs)
diff --git a/sdks/python/apache_beam/ml/inference/base.py
b/sdks/python/apache_beam/ml/inference/base.py
index 4881fb74ef7..2e1c4963f11 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -55,8 +55,7 @@ from typing import TypeVar
from typing import Union
import apache_beam as beam
-from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
-from apache_beam.metrics.metric import Metrics
+from apache_beam.io.components.adaptive_throttler import ReactiveThrottler
from apache_beam.utils import multi_process_shared
from apache_beam.utils import retry
from apache_beam.utils import shared
@@ -354,14 +353,16 @@ class RemoteModelHandler(ABC, ModelHandler[ExampleT,
PredictionT, ModelT]):
window_ms: int = 1 * _MILLISECOND_TO_SECOND,
bucket_ms: int = 1 * _MILLISECOND_TO_SECOND,
overload_ratio: float = 2):
- """Initializes metrics tracking + an AdaptiveThrottler class for enabling
- client-side throttling for remote calls to an inference service.
+ """Initializes a ReactiveThrottler class for enabling
+ client-side throttling for remote calls to an inference service. Also wraps
+ provided calls to the service with retry logic.
+
See https://s.apache.org/beam-client-side-throttling for more details
on the configuration of the throttling and retry
mechanics.
Args:
- namespace: the metrics and logging namespace
+ namespace: the metrics and logging namespace
num_retries: the maximum number of times to retry a request on retriable
errors before failing
throttle_delay_secs: the amount of time to throttle when the client-side
@@ -372,19 +373,18 @@ class RemoteModelHandler(ABC, ModelHandler[ExampleT,
PredictionT, ModelT]):
window_ms: length of history to consider, in ms, to set throttling.
bucket_ms: granularity of time buckets that we store data in, in ms.
overload_ratio: the target ratio between requests sent and successful
- requests. This is "K" in the formula in
+ requests. This is "K" in the formula in
https://landing.google.com/sre/book/chapters/handling-overload.html.
"""
- # Configure AdaptiveThrottler and throttling metrics for client-side
- # throttling behavior.
- self.throttled_secs = Metrics.counter(
- namespace, "cumulativeThrottlingSeconds")
- self.throttler = AdaptiveThrottler(
- window_ms=window_ms, bucket_ms=bucket_ms,
overload_ratio=overload_ratio)
+ # Configure ReactiveThrottler for client-side throttling behavior.
+ self.throttler = ReactiveThrottler(
+ window_ms=window_ms,
+ bucket_ms=bucket_ms,
+ overload_ratio=overload_ratio,
+ namespace=namespace,
+ throttle_delay_secs=throttle_delay_secs)
self.logger = logging.getLogger(namespace)
-
self.num_retries = num_retries
- self.throttle_delay_secs = throttle_delay_secs
self.retry_filter = retry_filter
def __init_subclass__(cls):
@@ -434,12 +434,7 @@ class RemoteModelHandler(ABC, ModelHandler[ExampleT,
PredictionT, ModelT]):
Returns:
An Iterable of Predictions.
"""
- while self.throttler.throttle_request(time.time() *
_MILLISECOND_TO_SECOND):
- self.logger.info(
- "Delaying request for %d seconds due to previous failures",
- self.throttle_delay_secs)
- time.sleep(self.throttle_delay_secs)
- self.throttled_secs.inc(self.throttle_delay_secs)
+ self.throttler.throttle()
try:
req_time = time.time()
@@ -1642,7 +1637,7 @@ class _ModelRoutingStrategy():
class _ModelStatus():
"""A class holding any metadata about a model required by RunInference.
-
+
Currently, this only includes whether or not the model is valid. Uses the
model tag to map models to metadata.
"""
@@ -1656,7 +1651,7 @@ class _ModelStatus():
def try_mark_current_model_invalid(self, min_model_life_seconds):
"""Mark the current model invalid.
-
+
Since we don't have sufficient information to say which model is being
marked invalid, but there may be multiple active models, we will mark all
models currently in use as inactive so that they all get reloaded. To
@@ -1678,7 +1673,7 @@ class _ModelStatus():
def get_valid_tag(self, tag: str) -> str:
"""Takes in a proposed valid tag and returns a valid one.
-
+
Will always return a valid tag. If the passed in tag is valid, this
function will simply return it, otherwise it will deterministically
generate a new tag to use instead. The new tag will be the original tag
@@ -1747,7 +1742,7 @@ def load_model_status(
class _SharedModelWrapper():
"""A router class to map incoming calls to the correct model.
-
+
This allows us to round robin calls to models sitting in different
processes so that we can more efficiently use resources (e.g. GPUs).
"""