This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch release-2.71
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/release-2.71 by this push:
     new 0c4d81bc49e Cherry Pick RateLimiter SDK changes to Beam 2.71 release (#37306)
0c4d81bc49e is described below

commit 0c4d81bc49e5244ef324baead1f6a9dd430e099f
Author: Tarun Annapareddy <[email protected]>
AuthorDate: Wed Jan 14 22:00:47 2026 +0530

    Cherry Pick RateLimiter SDK changes to Beam 2.71 release (#37306)
    
    * Support for RateLimiter in Beam Remote Model Handler  (#37218)
    
    * Support for EnvoyRateLimiter in Apache Beam
    
    * fix format issues
    
    * fix test formatting
    
    * Fix test and syntax
    
    * fix lint
    
    * Add dependency based on python version
    
    * revert setup to separate pr
    
    * fix lint
    
    * fix formatting
    
    * resolve comments
    
    * Support Ratelimiter through RemoteModelHandler
    
    * fix lint
    
    * fix lint
    
    * fix comments
    
    * Add custom RateLimited Exception
    
    * fix doc
    
    * fix test
    
    * fix lint
    
    * update RateLimiter execution function name (#37287)
    
    * Catch breaking import error (#37295)
    
    * Catch Import Error
    
    * import order
---
 .../examples/inference/rate_limiter_vertex_ai.py   | 85 ++++++++++++++++++++++
 .../apache_beam/examples/rate_limiter_simple.py    |  2 +-
 .../apache_beam/io/components/rate_limiter.py      | 44 +++++++++--
 .../apache_beam/io/components/rate_limiter_test.py | 24 +++---
 sdks/python/apache_beam/ml/inference/base.py       | 30 +++++++-
 sdks/python/apache_beam/ml/inference/base_test.py  | 61 ++++++++++++++++
 6 files changed, 225 insertions(+), 21 deletions(-)

diff --git a/sdks/python/apache_beam/examples/inference/rate_limiter_vertex_ai.py b/sdks/python/apache_beam/examples/inference/rate_limiter_vertex_ai.py
new file mode 100644
index 00000000000..11ec02fbd54
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/rate_limiter_vertex_ai.py
@@ -0,0 +1,85 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A simple example demonstrating usage of the EnvoyRateLimiter with Vertex AI.
+"""
+
+import argparse
+import logging
+
+import apache_beam as beam
+from apache_beam.io.components.rate_limiter import EnvoyRateLimiter
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+
+
+def run(argv=None):
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--project',
+      dest='project',
+      help='The Google Cloud project ID for Vertex AI.')
+  parser.add_argument(
+      '--location',
+      dest='location',
+      help='The Google Cloud location (e.g. us-central1) for Vertex AI.')
+  parser.add_argument(
+      '--endpoint_id',
+      dest='endpoint_id',
+      help='The ID of the Vertex AI endpoint.')
+  parser.add_argument(
+      '--rls_address',
+      dest='rls_address',
+      help='The address of the Envoy Rate Limit Service (e.g. localhost:8081).')
+
+  known_args, pipeline_args = parser.parse_known_args(argv)
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = True
+
+  # Initialize the EnvoyRateLimiter
+  rate_limiter = EnvoyRateLimiter(
+      service_address=known_args.rls_address,
+      domain="mongo_cps",
+      descriptors=[{
+          "database": "users"
+      }],
+      namespace='example_pipeline')
+
+  # Initialize the VertexAIModelHandler with the rate limiter
+  model_handler = VertexAIModelHandlerJSON(
+      endpoint_id=known_args.endpoint_id,
+      project=known_args.project,
+      location=known_args.location,
+      rate_limiter=rate_limiter)
+
+  # Input features for the model
+  features = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],
+              [10.0, 11.0, 12.0], [13.0, 14.0, 15.0]]
+
+  with beam.Pipeline(options=pipeline_options) as p:
+    _ = (
+        p
+        | 'CreateInputs' >> beam.Create(features)
+        | 'RunInference' >> RunInference(model_handler)
+        | 'PrintPredictions' >> beam.Map(logging.info))
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  run()
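
The example above uses the limiter's default blocking behavior. For readers
adapting it, here is a minimal sketch of the non-blocking variant, assuming only
the constructor arguments visible in the rate_limiter.py diff further down
(block_until_allowed, retries, namespace); the address and descriptor values are
illustrative:

    from apache_beam.io.components.rate_limiter import EnvoyRateLimiter

    # With block_until_allowed=False the limiter stops retrying once its retry
    # budget is spent and allow() returns False; RemoteModelHandler then raises
    # the RateLimitExceeded error added to base.py below.
    non_blocking_limiter = EnvoyRateLimiter(
        service_address='localhost:8081',      # illustrative RLS address
        domain='mongo_cps',                     # must match the RLS config
        descriptors=[{'database': 'users'}],
        block_until_allowed=False,
        retries=3,
        namespace='example_pipeline')
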
diff --git a/sdks/python/apache_beam/examples/rate_limiter_simple.py b/sdks/python/apache_beam/examples/rate_limiter_simple.py
index ea469006f2b..8cdf1166aad 100644
--- a/sdks/python/apache_beam/examples/rate_limiter_simple.py
+++ b/sdks/python/apache_beam/examples/rate_limiter_simple.py
@@ -53,7 +53,7 @@ class SampleApiDoFn(beam.DoFn):
     self.rate_limiter = self._shared.acquire(init_limiter)
 
   def process(self, element):
-    self.rate_limiter.throttle()
+    self.rate_limiter.allow()
 
     # Process the element mock API call
     logging.info("Processing element: %s", element)
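
For context on the one-line rename above (throttle() to allow()), a minimal
sketch of the DoFn pattern that hunk sits in: one limiter instance is created
lazily, shared across threads on a worker via shared.Shared, and consulted once
per element. Only the lines shown in the diff are verbatim; the constructor
arguments and the _shared/init_limiter wiring outside the hunk are assumptions
based on the example file:

    import logging

    import apache_beam as beam
    from apache_beam.io.components.rate_limiter import EnvoyRateLimiter
    from apache_beam.utils import shared

    class SampleApiDoFn(beam.DoFn):
      def __init__(self, rls_address):
        self._shared = shared.Shared()
        self._rls_address = rls_address

      def setup(self):
        def init_limiter():
          return EnvoyRateLimiter(
              service_address=self._rls_address,
              domain='example',                 # illustrative domain
              descriptors=[{'key': 'value'}],   # illustrative descriptor
              namespace='rate_limiter_simple')

        self.rate_limiter = self._shared.acquire(init_limiter)

      def process(self, element):
        self.rate_limiter.allow()  # renamed from throttle() in this commit

        # Process the element (mock API call).
        logging.info("Processing element: %s", element)
        yield element
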
diff --git a/sdks/python/apache_beam/io/components/rate_limiter.py b/sdks/python/apache_beam/io/components/rate_limiter.py
index 3de39ddd935..2dc8a5340fd 100644
--- a/sdks/python/apache_beam/io/components/rate_limiter.py
+++ b/sdks/python/apache_beam/io/components/rate_limiter.py
@@ -61,8 +61,13 @@ class RateLimiter(abc.ABC):
     self.rpc_latency = Metrics.distribution(namespace, 'RatelimitRpcLatencyMs')
 
   @abc.abstractmethod
-  def throttle(self, **kwargs) -> bool:
-    """Check if request should be throttled.
+  def allow(self, **kwargs) -> bool:
+    """Applies rate limiting to the request.
+
+    This method checks if the request is permitted by the rate limiting policy.
+    Depending on the implementation and configuration, it may block (sleep)
+    until the request is allowed, or return False once the retry limit is
+    exceeded.
 
     Args:
       **kwargs: Keyword arguments specific to the RateLimiter implementation.
@@ -78,8 +83,12 @@ class RateLimiter(abc.ABC):
 
 
 class EnvoyRateLimiter(RateLimiter):
-  """
-  Rate limiter implementation that uses an external Envoy Rate Limit Service.
+  """Rate limiter implementation that uses an external Envoy Rate Limit 
Service.
+
+  This limiter connects to a gRPC Envoy Rate Limit Service (RLS) to determine
+  whether a request should be allowed. It supports defining a domain and a
+  list of descriptors that correspond to the rate limit configuration in the
+  RLS.
   """
   def __init__(
       self,
@@ -89,7 +98,7 @@ class EnvoyRateLimiter(RateLimiter):
       timeout: float = 5.0,
       block_until_allowed: bool = True,
       retries: int = 3,
-      namespace: str = ""):
+      namespace: str = ''):
     """
     Args:
       service_address: Address of the Envoy RLS (e.g., 'localhost:8081').
@@ -139,8 +148,16 @@ class EnvoyRateLimiter(RateLimiter):
           channel = grpc.insecure_channel(self.service_address)
           self._stub = EnvoyRateLimiter.RateLimitServiceStub(channel)
 
-  def throttle(self, hits_added: int = 1) -> bool:
-    """Calls the Envoy RLS to check for rate limits.
+  def allow(self, hits_added: int = 1) -> bool:
+    """Calls the Envoy RLS to apply rate limits.
+
+    Sends a rate limit request to the configured Envoy Rate Limit Service.
+    If 'block_until_allowed' is True, this method will sleep and retry
+    if the limit is exceeded, effectively blocking until the request is
+    permitted.
+
+    If 'block_until_allowed' is False, it will return False after the retry
+    limit is exceeded.
 
     Args:
       hits_added: Number of hits to add to the rate limit.
@@ -224,3 +241,16 @@ class EnvoyRateLimiter(RateLimiter):
             response.overall_code)
         break
     return throttled
+
+  def __getstate__(self):
+    state = self.__dict__.copy()
+    if '_lock' in state:
+      del state['_lock']
+    if '_stub' in state:
+      del state['_stub']
+    return state
+
+  def __setstate__(self, state):
+    self.__dict__.update(state)
+    self._lock = threading.Lock()
+    self._stub = None
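
The abstract allow() contract above (return a bool, optionally blocking until
the request is permitted) and the __getstate__/__setstate__ pattern for
dropping non-picklable members are all a custom limiter needs to honor. A
minimal sketch of a purely local implementation; the FixedIntervalRateLimiter
name and its sleep-based logic are illustrative and not part of this commit:

    import threading
    import time

    from apache_beam.io.components.rate_limiter import RateLimiter

    class FixedIntervalRateLimiter(RateLimiter):
      """Illustrative limiter: at most one permit per `interval` seconds."""
      def __init__(self, interval: float, namespace: str = ''):
        super().__init__(namespace=namespace)
        self._interval = interval
        self._next_allowed = 0.0
        self._lock = threading.Lock()

      def allow(self, **kwargs) -> bool:
        # Blocks until the next slot opens, mirroring the blocking behavior
        # described in the abstract allow() docstring above.
        with self._lock:
          now = time.monotonic()
          if now < self._next_allowed:
            time.sleep(self._next_allowed - now)
          self._next_allowed = time.monotonic() + self._interval
        return True

      # Like EnvoyRateLimiter, drop the non-picklable lock when the limiter is
      # serialized with the pipeline, and recreate it on deserialization.
      def __getstate__(self):
        state = self.__dict__.copy()
        state.pop('_lock', None)
        return state

      def __setstate__(self, state):
        self.__dict__.update(state)
        self._lock = threading.Lock()
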
diff --git a/sdks/python/apache_beam/io/components/rate_limiter_test.py b/sdks/python/apache_beam/io/components/rate_limiter_test.py
index 7c3e7b82aad..24d30a1c5c9 100644
--- a/sdks/python/apache_beam/io/components/rate_limiter_test.py
+++ b/sdks/python/apache_beam/io/components/rate_limiter_test.py
@@ -42,7 +42,7 @@ class EnvoyRateLimiterTest(unittest.TestCase):
         namespace='test_namespace')
 
   @mock.patch('grpc.insecure_channel')
-  def test_throttle_allowed(self, mock_channel):
+  def test_allow_success(self, mock_channel):
     # Mock successful OK response
     mock_stub = mock.Mock()
     mock_response = RateLimitResponse(overall_code=RateLimitResponseCode.OK)
@@ -51,13 +51,13 @@ class EnvoyRateLimiterTest(unittest.TestCase):
     # Inject mock stub
     self.limiter._stub = mock_stub
 
-    throttled = self.limiter.throttle()
+    allowed = self.limiter.allow()
 
-    self.assertTrue(throttled)
+    self.assertTrue(allowed)
     mock_stub.ShouldRateLimit.assert_called_once()
 
   @mock.patch('grpc.insecure_channel')
-  def test_throttle_over_limit_retries_exceeded(self, mock_channel):
+  def test_allow_over_limit_retries_exceeded(self, mock_channel):
     # Mock OVER_LIMIT response
     mock_stub = mock.Mock()
     mock_response = RateLimitResponse(
@@ -69,9 +69,9 @@ class EnvoyRateLimiterTest(unittest.TestCase):
 
     # We mock time.sleep to run fast
     with mock.patch('time.sleep'):
-      throttled = self.limiter.throttle()
+      allowed = self.limiter.allow()
 
-    self.assertFalse(throttled)
+    self.assertFalse(allowed)
     # Should be called 1 (initial) + 2 (retries) + 1 (last check > retries
     # logic depends on loop)
     # Logic: attempt starts at 0.
@@ -83,7 +83,7 @@ class EnvoyRateLimiterTest(unittest.TestCase):
     self.assertEqual(mock_stub.ShouldRateLimit.call_count, 3)
 
   @mock.patch('grpc.insecure_channel')
-  def test_throttle_rpc_error_retry(self, mock_channel):
+  def test_allow_rpc_error_retry(self, mock_channel):
     # Mock RpcError then Success
     mock_stub = mock.Mock()
     mock_response = RateLimitResponse(overall_code=RateLimitResponseCode.OK)
@@ -95,13 +95,13 @@ class EnvoyRateLimiterTest(unittest.TestCase):
     self.limiter._stub = mock_stub
 
     with mock.patch('time.sleep'):
-      throttled = self.limiter.throttle()
+      allowed = self.limiter.allow()
 
-    self.assertTrue(throttled)
+    self.assertTrue(allowed)
     self.assertEqual(mock_stub.ShouldRateLimit.call_count, 3)
 
   @mock.patch('grpc.insecure_channel')
-  def test_throttle_rpc_error_fail(self, mock_channel):
+  def test_allow_rpc_error_fail(self, mock_channel):
     # Mock Persistent RpcError
     mock_stub = mock.Mock()
     error = grpc.RpcError()
@@ -111,7 +111,7 @@ class EnvoyRateLimiterTest(unittest.TestCase):
 
     with mock.patch('time.sleep'):
       with self.assertRaises(grpc.RpcError):
-        self.limiter.throttle()
+        self.limiter.allow()
 
     # The inner loop tries 5 times for connection errors
     self.assertEqual(mock_stub.ShouldRateLimit.call_count, 5)
@@ -134,7 +134,7 @@ class EnvoyRateLimiterTest(unittest.TestCase):
     self.limiter.retries = 0  # Single attempt
 
     with mock.patch('time.sleep') as mock_sleep:
-      self.limiter.throttle()
+      self.limiter.allow()
       # Should sleep for 5 seconds (jitter is 0.0)
       mock_sleep.assert_called_with(5.0)
 
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index d79565ee24d..e0f870669f7 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -60,6 +60,11 @@ from apache_beam.utils import multi_process_shared
 from apache_beam.utils import retry
 from apache_beam.utils import shared
 
+try:
+  from apache_beam.io.components.rate_limiter import RateLimiter
+except ImportError:
+  RateLimiter = None
+
 try:
   # pylint: disable=wrong-import-order, wrong-import-position
   import resource
@@ -102,6 +107,11 @@ PredictionResult.inference.__doc__ = """Results for the inference on the model
 PredictionResult.model_id.__doc__ = """Model ID used to run the prediction."""
 
 
+class RateLimitExceeded(RuntimeError):
+  """RateLimit Exceeded to process a batch of requests."""
+  pass
+
+
 class ModelMetadata(NamedTuple):
   model_id: str
   model_name: str
@@ -349,7 +359,8 @@ class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]):
       *,
       window_ms: int = 1 * _MILLISECOND_TO_SECOND,
       bucket_ms: int = 1 * _MILLISECOND_TO_SECOND,
-      overload_ratio: float = 2):
+      overload_ratio: float = 2,
+      rate_limiter: Optional[RateLimiter] = None):
     """Initializes a ReactiveThrottler class for enabling
     client-side throttling for remote calls to an inference service. Also wraps
     provided calls to the service with retry logic.
@@ -372,6 +383,7 @@ class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]):
       overload_ratio: the target ratio between requests sent and successful
         requests. This is "K" in the formula in
         https://landing.google.com/sre/book/chapters/handling-overload.html.
+      rate_limiter: A RateLimiter object for setting a global rate limit.
     """
     # Configure ReactiveThrottler for client-side throttling behavior.
     self.throttler = ReactiveThrottler(
@@ -383,6 +395,9 @@ class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]):
     self.logger = logging.getLogger(namespace)
     self.num_retries = num_retries
     self.retry_filter = retry_filter
+    self._rate_limiter = rate_limiter
+    self._shared_rate_limiter = None
+    self._shared_handle = shared.Shared()
 
   def __init_subclass__(cls):
     if cls.load_model is not RemoteModelHandler.load_model:
@@ -431,6 +446,19 @@ class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]):
     Returns:
       An Iterable of Predictions.
     """
+    if self._rate_limiter:
+      if self._shared_rate_limiter is None:
+
+        def init_limiter():
+          return self._rate_limiter
+
+        self._shared_rate_limiter = self._shared_handle.acquire(init_limiter)
+
+      if not self._shared_rate_limiter.allow(hits_added=len(batch)):
+        raise RateLimitExceeded(
+            "Rate Limit Exceeded, "
+            "Could not process this batch.")
+
     self.throttler.throttle()
 
     try:
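
To make the base.py wiring above concrete: any RemoteModelHandler subclass
picks up the new rate_limiter constructor argument without extra plumbing, and
run_inference charges the limiter once per batch via
allow(hits_added=len(batch)). A minimal sketch, assuming only the signatures
shown in this diff; EchoModelHandler and its trivial client are illustrative:

    from apache_beam.io.components.rate_limiter import EnvoyRateLimiter
    from apache_beam.ml.inference.base import RemoteModelHandler

    class EchoModelHandler(RemoteModelHandler):
      """Illustrative handler whose 'inference' just echoes the batch."""
      def create_client(self):
        return object()  # stands in for a real service client

      def request(self, batch, model, inference_args=None):
        return list(batch)

    limiter = EnvoyRateLimiter(
        service_address='localhost:8081',  # illustrative RLS address
        domain='example',
        descriptors=[{'key': 'value'}])

    # If allow() refuses the batch, run_inference raises RateLimitExceeded
    # (see the hunk above); otherwise the request proceeds as before.
    handler = EchoModelHandler(namespace='echo_example', rate_limiter=limiter)
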
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 574e71de89c..381bf545660 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -2071,6 +2071,67 @@ class RunInferenceRemoteTest(unittest.TestCase):
             responses.append(model.predict(example))
           return responses
 
+  def test_run_inference_with_rate_limiter(self):
+    class FakeRateLimiter(base.RateLimiter):
+      def __init__(self):
+        super().__init__(namespace='test_namespace')
+
+      def allow(self, hits_added=1):
+        self.requests_counter.inc()
+        return True
+
+    limiter = FakeRateLimiter()
+
+    with TestPipeline() as pipeline:
+      examples = [1, 5]
+
+      class ConcreteRemoteModelHandler(base.RemoteModelHandler):
+        def create_client(self):
+          return FakeModel()
+
+        def request(self, batch, model, inference_args=None):
+          return [model.predict(example) for example in batch]
+
+      model_handler = ConcreteRemoteModelHandler(
+          rate_limiter=limiter, namespace='test_namespace')
+
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(model_handler)
+
+      expected = [2, 6]
+      assert_that(actual, equal_to(expected))
+
+      result = pipeline.run()
+      result.wait_until_finish()
+
+      metrics_filter = MetricsFilter().with_name(
+          'RatelimitRequestsTotal').with_namespace('test_namespace')
+      metrics = result.metrics().query(metrics_filter)
+      self.assertGreaterEqual(metrics['counters'][0].committed, 0)
+
+  def test_run_inference_with_rate_limiter_exceeded(self):
+    class FakeRateLimiter(base.RateLimiter):
+      def __init__(self):
+        super().__init__(namespace='test_namespace')
+
+      def allow(self, hits_added=1):
+        return False
+
+    class ConcreteRemoteModelHandler(base.RemoteModelHandler):
+      def create_client(self):
+        return FakeModel()
+
+      def request(self, batch, model, inference_args=None):
+        return [model.predict(example) for example in batch]
+
+    model_handler = ConcreteRemoteModelHandler(
+        rate_limiter=FakeRateLimiter(),
+        namespace='test_namespace',
+        num_retries=0)
+
+    with self.assertRaises(base.RateLimitExceeded):
+      model_handler.run_inference([1], FakeModel())
+
 
 if __name__ == '__main__':
   unittest.main()
