This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new ab14c432dad Support for EnvoyRateLimiter in Beam Python SDK (#37135)
ab14c432dad is described below
commit ab14c432dada70f835156e78ef9582d8c0a1e906
Author: Tarun Annapareddy <[email protected]>
AuthorDate: Mon Jan 5 18:51:54 2026 +0530
Support for EnvoyRateLimiter in Beam Python SDK (#37135)
* Support for EnvoyRateLimiter in Apache Beam
* fix format issues
* fix test formatting
* Fix test and syntax
* fix lint
* Add dependency based on python version
* revert setup to separate PR
* fix lint
* fix formatting
* resolve comments
---
.../apache_beam/examples/rate_limiter_simple.py | 93 +++++++++
.../apache_beam/io/components/rate_limiter.py | 226 +++++++++++++++++++++
.../apache_beam/io/components/rate_limiter_test.py | 143 +++++++++++++
3 files changed, 462 insertions(+)
diff --git a/sdks/python/apache_beam/examples/rate_limiter_simple.py
b/sdks/python/apache_beam/examples/rate_limiter_simple.py
new file mode 100644
index 00000000000..ea469006f2b
--- /dev/null
+++ b/sdks/python/apache_beam/examples/rate_limiter_simple.py
@@ -0,0 +1,93 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A simple example demonstrating usage of the EnvoyRateLimiter in a Beam
+pipeline.
+"""
+
+import argparse
+import logging
+import time
+
+import apache_beam as beam
+from apache_beam.io.components.rate_limiter import EnvoyRateLimiter
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.utils import shared
+
+
class SampleApiDoFn(beam.DoFn):
  """Stand-in for a DoFn that calls an external API, gated by Envoy RLS.

  Every element first asks the shared EnvoyRateLimiter for permission, then
  performs a fake API call (a short sleep) and re-emits the element.
  """
  def __init__(self, rls_address, domain, descriptors):
    self.rls_address = rls_address
    self.domain = domain
    self.descriptors = descriptors
    # Handle used in setup() to obtain one limiter shared across threads.
    self._shared = shared.Shared()
    # Populated in setup(); None until then.
    self.rate_limiter = None

  def setup(self):
    # The limiter is built lazily through shared.Shared(), so only one
    # EnvoyRateLimiter is created per worker and shared across its threads.
    def build_limiter():
      logging.info("Connecting to Envoy RLS at %s", self.rls_address)
      limiter_kwargs = dict(
          service_address=self.rls_address,
          domain=self.domain,
          descriptors=self.descriptors,
          namespace='example_pipeline')
      return EnvoyRateLimiter(**limiter_kwargs)

    self.rate_limiter = self._shared.acquire(build_limiter)

  def process(self, element):
    limiter = self.rate_limiter
    limiter.throttle()

    # Pretend to call the external API for this element.
    logging.info("Processing element: %s", element)
    time.sleep(0.1)
    yield element
+
+
def parse_known_args(argv):
  """Split argv into this example's own flags and the pipeline flags.

  Args:
    argv: Command-line arguments to parse (None falls back to sys.argv[1:],
      as argparse does).

  Returns:
    A ``(known_args, remaining_args)`` tuple; ``known_args.rls_address``
    holds the Envoy Rate Limit Service address and ``remaining_args`` are
    forwarded to PipelineOptions.
  """
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
      '--rls_address',
      help='Address of the Envoy Rate Limit Service',
      default='localhost:8081')
  known_and_rest = arg_parser.parse_known_args(argv)
  return known_and_rest
+
+
def run(argv=None):
  """Build and run the example pipeline.

  Creates the integers 0..99 and passes each one through SampleApiDoFn, which
  rate-limits against the "mongo_cps" domain with a single
  ``{"database": "users"}`` descriptor.

  Args:
    argv: Command-line arguments; flags not consumed by parse_known_args are
      handed to PipelineOptions.
  """
  args, beam_args = parse_known_args(argv)
  options = PipelineOptions(beam_args)
  limited_fn = SampleApiDoFn(
      rls_address=args.rls_address,
      domain="mongo_cps",
      descriptors=[{
          "database": "users"
      }])

  with beam.Pipeline(options=options) as pipeline:
    elements = pipeline | 'Create' >> beam.Create(range(100))
    _ = elements | 'RateLimit' >> beam.ParDo(limited_fn)
+
+
if __name__ == '__main__':
  # Raise the root logger to INFO so the example's progress logs are emitted.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  run()
diff --git a/sdks/python/apache_beam/io/components/rate_limiter.py
b/sdks/python/apache_beam/io/components/rate_limiter.py
new file mode 100644
index 00000000000..3de39ddd935
--- /dev/null
+++ b/sdks/python/apache_beam/io/components/rate_limiter.py
@@ -0,0 +1,226 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Rate Limiter classes for controlling access to external resources.
+"""
+
+import abc
+import logging
+import math
+import random
+import threading
+import time
+from typing import Dict
+from typing import List
+
+import grpc
+from envoy_data_plane.envoy.extensions.common.ratelimit.v3 import
RateLimitDescriptor
+from envoy_data_plane.envoy.extensions.common.ratelimit.v3 import
RateLimitDescriptorEntry
+from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitRequest
+from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponse
+from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponseCode
+
+from apache_beam.io.components import adaptive_throttler
+from apache_beam.metrics import Metrics
+
+_LOGGER = logging.getLogger(__name__)
+
+# Total number of attempts made for a single ShouldRateLimit RPC before the
+# grpc.RpcError is re-raised to the caller.
+_RPC_MAX_RETRIES = 5
+# Seconds to sleep between consecutive failed ShouldRateLimit RPC attempts.
+_RPC_RETRY_DELAY_SECONDS = 10
+
+
class RateLimiter(abc.ABC):
  """Interface shared by client-side rate limiters.

  Concrete implementations override ``throttle``. The constructor creates the
  throttling signaler and the Beam metrics that implementations update, all
  scoped to ``namespace``.
  """
  def __init__(self, namespace: str = ""):
    # Metrics collected from the RateLimiter; metric updates are thread safe.
    self.throttling_signaler = adaptive_throttler.ThrottlingSignaler(
        namespace=namespace)

    def new_counter(metric_name):
      return Metrics.counter(namespace, metric_name)

    self.requests_counter = new_counter('RatelimitRequestsTotal')
    self.requests_allowed = new_counter('RatelimitRequestsAllowed')
    self.requests_throttled = new_counter('RatelimitRequestsThrottled')
    self.rpc_errors = new_counter('RatelimitRpcErrors')
    self.rpc_retries = new_counter('RatelimitRpcRetries')
    self.rpc_latency = Metrics.distribution(namespace, 'RatelimitRpcLatencyMs')

  @abc.abstractmethod
  def throttle(self, **kwargs) -> bool:
    """Wait for, or check, permission to issue a request.

    Args:
      **kwargs: Options specific to the RateLimiter implementation.

    Returns:
      bool: True if the request may proceed, False if the implementation gave
      up (for example after exhausting its retries).

    Raises:
      Exception: If an underlying infrastructure error occurs (e.g. RPC
        failure).
    """
+
+
class EnvoyRateLimiter(RateLimiter):
  """Rate limiter backed by an external Envoy Rate Limit Service (RLS).

  Each ``throttle`` call sends a ShouldRateLimit request for the configured
  domain and descriptors. The gRPC stub is created lazily under a lock, so one
  instance can be shared by several DoFn threads (e.g. through Shared()).
  """

  # Seconds to back off when the RLS answers OVER_LIMIT without reporting a
  # positive duration_until_reset. Without this floor the retry loop would
  # immediately re-query the RLS, spinning when block_until_allowed is True.
  _DEFAULT_OVER_LIMIT_SLEEP_SECONDS = 1.0

  def __init__(
      self,
      service_address: str,
      domain: str,
      descriptors: List[Dict[str, str]],
      timeout: float = 5.0,
      block_until_allowed: bool = True,
      retries: int = 3,
      namespace: str = ""):
    """
    Args:
      service_address: Address of the Envoy RLS (e.g., 'localhost:8081').
      domain: The rate limit domain.
      descriptors: List of descriptors; the key/value pairs of each dict
        become the entries of one RateLimitDescriptor.
      timeout: gRPC timeout in seconds for each ShouldRateLimit call.
      block_until_allowed: If True, ``throttle`` keeps retrying until the RLS
        allows the request and ``retries`` is ignored.
      retries: Number of retries to attempt after an OVER_LIMIT response;
        respected only if block_until_allowed is False.
      namespace: The namespace used for the metrics and throttling signal.
    """
    super().__init__(namespace=namespace)

    self.service_address = service_address
    self.domain = domain
    self.descriptors = descriptors
    self.retries = retries
    self.timeout = timeout
    self.block_until_allowed = block_until_allowed
    self._stub = None
    self._lock = threading.Lock()

  class RateLimitServiceStub(object):
    """
    Wrapper for gRPC stub to be compatible with envoy_data_plane messages.

    The envoy-data-plane package uses 'betterproto' which generates async stubs
    for 'grpclib'. As Beam uses standard synchronous 'grpcio',
    RateLimitServiceStub is a bridge class to use the betterproto Message types
    (RateLimitRequest) with a standard grpcio Channel.
    """
    def __init__(self, channel):
      self.ShouldRateLimit = channel.unary_unary(
          '/envoy.service.ratelimit.v3.RateLimitService/ShouldRateLimit',
          request_serializer=RateLimitRequest.SerializeToString,
          response_deserializer=RateLimitResponse.FromString,
      )

  def init_connection(self):
    """Create the gRPC channel and stub on first use."""
    if self._stub is None:
      # Acquire the lock to safeguard against multiple DoFn threads sharing
      # the same RateLimiter instance, which is the case when using Shared();
      # re-check once the lock is held so only one stub is created.
      with self._lock:
        if self._stub is None:
          channel = grpc.insecure_channel(self.service_address)
          self._stub = EnvoyRateLimiter.RateLimitServiceStub(channel)

  def _build_request(self, hits_added):
    """Convert the configured descriptors into a RateLimitRequest."""
    proto_descriptors = [
        RateLimitDescriptor(
            entries=[
                RateLimitDescriptorEntry(key=k, value=v) for k, v in d.items()
            ]) for d in self.descriptors
    ]
    return RateLimitRequest(
        domain=self.domain,
        descriptors=proto_descriptors,
        hits_addend=hits_added)

  def _should_rate_limit(self, request):
    """Issue one ShouldRateLimit RPC, retrying gRPC failures.

    Raises:
      grpc.RpcError: if all _RPC_MAX_RETRIES attempts fail.
    """
    for retry_attempt in range(_RPC_MAX_RETRIES):
      try:
        # Monotonic clock: latency must not be skewed by wall-clock jumps.
        start_time = time.monotonic()
        response = self._stub.ShouldRateLimit(request, timeout=self.timeout)
      except grpc.RpcError as e:
        if retry_attempt == _RPC_MAX_RETRIES - 1:
          _LOGGER.error(
              "[EnvoyRateLimiter] ratelimit service call failed: %s", e)
          self.rpc_errors.inc()
          raise
        self.rpc_retries.inc()
        _LOGGER.warning(
            "[EnvoyRateLimiter] ratelimit service call failed, retrying: %s",
            e)
        time.sleep(_RPC_RETRY_DELAY_SECONDS)
      else:
        self.rpc_latency.update(int((time.monotonic() - start_time) * 1000))
        return response

  def _over_limit_sleep_seconds(self, response):
    """Return how long to wait after an OVER_LIMIT response, with jitter."""
    # Multiple rules can be set in the RLS config, so wait for the longest
    # reset reported among the descriptors that are over limit.
    sleep_s = 0.0
    for status in response.statuses or []:
      if status.code != RateLimitResponseCode.OVER_LIMIT:
        continue
      # duration_until_reset is converted to a timedelta by betterproto.
      dur = status.duration_until_reset
      if dur is not None:
        sleep_s = max(sleep_s, dur.total_seconds())
    if sleep_s <= 0:
      # The RLS gave no usable reset time; avoid re-querying immediately.
      sleep_s = self._DEFAULT_OVER_LIMIT_SLEEP_SECONDS
    # Add up to 1% additive jitter to prevent a thundering herd.
    return sleep_s + random.uniform(0, 0.01 * sleep_s)

  def throttle(self, hits_added: int = 1) -> bool:
    """Calls the Envoy RLS to check for rate limits.

    After every OVER_LIMIT response this sleeps (and signals the throttled
    time) before deciding whether to retry, including after the final
    allowed retry when block_until_allowed is False.

    Args:
      hits_added: Number of hits to add to the rate limit (sent as the
        request's hits_addend).

    Returns:
      bool: True if the request is allowed. False if block_until_allowed is
      False and the request was still over limit after ``retries`` retries,
      or if the RLS answered with an unexpected code.

    Raises:
      grpc.RpcError: if the RLS could not be reached after _RPC_MAX_RETRIES
        attempts.
    """
    self.init_connection()
    request = self._build_request(hits_added)

    self.requests_counter.inc()
    attempt = 0
    while True:
      if not self.block_until_allowed and attempt > self.retries:
        return False

      response = self._should_rate_limit(request)

      if response.overall_code == RateLimitResponseCode.OK:
        self.requests_allowed.inc()
        return True

      if response.overall_code == RateLimitResponseCode.OVER_LIMIT:
        self.requests_throttled.inc()
        sleep_s = self._over_limit_sleep_seconds(response)
        _LOGGER.warning("[EnvoyRateLimiter] Throttled for %s seconds", sleep_s)
        # Signal the throttled time to the backend.
        self.throttling_signaler.signal_throttled(math.ceil(sleep_s))
        time.sleep(sleep_s)
        attempt += 1
        continue

      _LOGGER.error(
          "[EnvoyRateLimiter] Unknown code from RLS: %s",
          response.overall_code)
      return False
diff --git a/sdks/python/apache_beam/io/components/rate_limiter_test.py
b/sdks/python/apache_beam/io/components/rate_limiter_test.py
new file mode 100644
index 00000000000..7c3e7b82aad
--- /dev/null
+++ b/sdks/python/apache_beam/io/components/rate_limiter_test.py
@@ -0,0 +1,143 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import timedelta
+from unittest import mock
+
+import grpc
+from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponse
+from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponseCode
+from envoy_data_plane.envoy.service.ratelimit.v3 import
RateLimitResponseDescriptorStatus
+
+from apache_beam.io.components import rate_limiter
+
+
class EnvoyRateLimiterTest(unittest.TestCase):
  """Tests for EnvoyRateLimiter with the gRPC stub replaced by a mock."""
  def setUp(self):
    self.service_address = 'localhost:8081'
    self.domain = 'test_domain'
    self.descriptors = [{'key': 'value'}]
    self.limiter = rate_limiter.EnvoyRateLimiter(
        self.service_address,
        self.domain,
        self.descriptors,
        timeout=0.1,  # Fast timeout for tests
        block_until_allowed=False,
        retries=2,
        namespace='test_namespace')

  @mock.patch('grpc.insecure_channel')
  def test_throttle_allowed(self, mock_channel):
    # The RLS answers OK on the first call.
    mock_stub = mock.Mock()
    mock_response = RateLimitResponse(overall_code=RateLimitResponseCode.OK)
    mock_stub.ShouldRateLimit.return_value = mock_response

    # Inject the mock stub so no real channel is created.
    self.limiter._stub = mock_stub

    allowed = self.limiter.throttle()

    self.assertTrue(allowed)
    mock_stub.ShouldRateLimit.assert_called_once()

  @mock.patch('grpc.insecure_channel')
  def test_throttle_over_limit_retries_exceeded(self, mock_channel):
    # The RLS always answers OVER_LIMIT.
    mock_stub = mock.Mock()
    mock_response = RateLimitResponse(
        overall_code=RateLimitResponseCode.OVER_LIMIT)
    mock_stub.ShouldRateLimit.return_value = mock_response

    self.limiter._stub = mock_stub

    # block_until_allowed is False, so throttle() eventually returns False.
    # time.sleep is mocked so the back-off does not slow the test down.
    with mock.patch('time.sleep'):
      allowed = self.limiter.throttle()

    self.assertFalse(allowed)
    # One initial call plus `retries` (2) retries: attempts 0, 1 and 2 are
    # over limit, and the loop stops once attempt 3 exceeds retries.
    self.assertEqual(mock_stub.ShouldRateLimit.call_count, 3)

  @mock.patch('grpc.insecure_channel')
  def test_throttle_blocks_until_allowed(self, mock_channel):
    # Default blocking mode: keep retrying past `retries` until OK.
    limiter = rate_limiter.EnvoyRateLimiter(
        self.service_address,
        self.domain,
        self.descriptors,
        timeout=0.1,
        block_until_allowed=True,
        retries=0,
        namespace='test_namespace')
    mock_stub = mock.Mock()
    over_limit = RateLimitResponse(
        overall_code=RateLimitResponseCode.OVER_LIMIT)
    ok = RateLimitResponse(overall_code=RateLimitResponseCode.OK)
    mock_stub.ShouldRateLimit.side_effect = [over_limit, over_limit, ok]
    limiter._stub = mock_stub

    with mock.patch('time.sleep'):
      allowed = limiter.throttle()

    self.assertTrue(allowed)
    self.assertEqual(mock_stub.ShouldRateLimit.call_count, 3)

  @mock.patch('grpc.insecure_channel')
  def test_throttle_unknown_code(self, mock_channel):
    # An unexpected overall code is reported as "not allowed" without retry.
    mock_stub = mock.Mock()
    mock_stub.ShouldRateLimit.return_value = RateLimitResponse(
        overall_code=RateLimitResponseCode.UNKNOWN)
    self.limiter._stub = mock_stub

    with mock.patch('time.sleep') as mock_sleep:
      allowed = self.limiter.throttle()

    self.assertFalse(allowed)
    mock_stub.ShouldRateLimit.assert_called_once()
    mock_sleep.assert_not_called()

  @mock.patch('grpc.insecure_channel')
  def test_throttle_rpc_error_retry(self, mock_channel):
    # RpcError twice, then success.
    mock_stub = mock.Mock()
    mock_response = RateLimitResponse(overall_code=RateLimitResponseCode.OK)

    error = grpc.RpcError()
    mock_stub.ShouldRateLimit.side_effect = [error, error, mock_response]

    self.limiter._stub = mock_stub

    with mock.patch('time.sleep'):
      allowed = self.limiter.throttle()

    self.assertTrue(allowed)
    self.assertEqual(mock_stub.ShouldRateLimit.call_count, 3)

  @mock.patch('grpc.insecure_channel')
  def test_throttle_rpc_error_fail(self, mock_channel):
    # Persistent RpcError.
    mock_stub = mock.Mock()
    error = grpc.RpcError()
    mock_stub.ShouldRateLimit.side_effect = error

    self.limiter._stub = mock_stub

    with mock.patch('time.sleep'):
      with self.assertRaises(grpc.RpcError):
        self.limiter.throttle()

    # The inner RPC loop makes 5 attempts before re-raising.
    self.assertEqual(mock_stub.ShouldRateLimit.call_count, 5)

  @mock.patch('grpc.insecure_channel')
  @mock.patch('random.uniform', return_value=0.0)
  def test_extract_duration_from_response(self, mock_random, mock_channel):
    # OVER_LIMIT with a specific reset duration.
    mock_stub = mock.Mock()

    # Resets in 5 seconds.
    status = RateLimitResponseDescriptorStatus(
        code=RateLimitResponseCode.OVER_LIMIT,
        duration_until_reset=timedelta(seconds=5))
    mock_response = RateLimitResponse(
        overall_code=RateLimitResponseCode.OVER_LIMIT, statuses=[status])

    mock_stub.ShouldRateLimit.return_value = mock_response
    self.limiter._stub = mock_stub
    self.limiter.retries = 0  # Single attempt

    with mock.patch('time.sleep') as mock_sleep:
      allowed = self.limiter.throttle()
      # Should sleep for 5 seconds (jitter is 0.0)
      mock_sleep.assert_called_with(5.0)
    self.assertFalse(allowed)

  @mock.patch('grpc.insecure_channel')
  @mock.patch('random.uniform', return_value=0.0)
  def test_sleeps_for_longest_over_limit_duration(
      self, mock_random, mock_channel):
    # Several rules are over limit: wait for the one that resets last.
    mock_stub = mock.Mock()
    statuses = [
        RateLimitResponseDescriptorStatus(
            code=RateLimitResponseCode.OVER_LIMIT,
            duration_until_reset=timedelta(seconds=2)),
        RateLimitResponseDescriptorStatus(
            code=RateLimitResponseCode.OVER_LIMIT,
            duration_until_reset=timedelta(seconds=7)),
        # A status that is not over limit must not extend the wait.
        RateLimitResponseDescriptorStatus(
            code=RateLimitResponseCode.OK,
            duration_until_reset=timedelta(seconds=100)),
    ]
    mock_stub.ShouldRateLimit.return_value = RateLimitResponse(
        overall_code=RateLimitResponseCode.OVER_LIMIT, statuses=statuses)
    self.limiter._stub = mock_stub
    self.limiter.retries = 0  # Single attempt

    with mock.patch('time.sleep') as mock_sleep:
      self.limiter.throttle()
      mock_sleep.assert_called_with(7.0)
+
+
+if __name__ == '__main__':
+ # Allow running this test module directly with the Python interpreter.
+ unittest.main()