This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 2c4a45c43fb [FLINK-38561][python] Add retry support for async function in Python DataStream API
2c4a45c43fb is described below

commit 2c4a45c43fbd7b489228e4f5550452025e4efb3b
Author: Dian Fu <[email protected]>
AuthorDate: Tue Oct 28 09:22:07 2025 +0800

    [FLINK-38561][python] Add retry support for async function in Python DataStream API
    
    This closes #27164.
---
 flink-python/pyflink/datastream/__init__.py        |   8 +-
 .../pyflink/datastream/async_data_stream.py        |  44 +++++++-
 .../pyflink/datastream/async_retry_strategies.py   | 121 +++++++++++++++++++++
 flink-python/pyflink/datastream/functions.py       |  86 ++++++++++++++-
 .../datastream/tests/test_async_function.py        |  50 ++++++++-
 .../datastream/process/async_function/operation.py | 107 +++++++++++++++++-
 6 files changed, 399 insertions(+), 17 deletions(-)

diff --git a/flink-python/pyflink/datastream/__init__.py 
b/flink-python/pyflink/datastream/__init__.py
index 389116decb2..d5bb5576ff8 100644
--- a/flink-python/pyflink/datastream/__init__.py
+++ b/flink-python/pyflink/datastream/__init__.py
@@ -82,6 +82,8 @@ Functions used to transform a :class:`DataStream` into 
another :class:`DataStrea
     - :class:`KeyedBroadcastProcessFunction`:
       A function to be applied to a :class:`BroadcastConnectedStream` that 
connects
       :class:`BroadcastStream`, i.e. a stream with broadcast state, with a 
:class:`KeyedStream`.
+    - :class:`AsyncFunction`:
+      A function to trigger Async I/O operation.
     - :class:`RuntimeContext`:
       Contains information about the context in which functions are executed. 
Each
       parallel instance of the function will have a context through which it 
can access static
@@ -270,7 +272,7 @@ from pyflink.datastream.functions import (MapFunction, 
CoMapFunction, FlatMapFun
                                           KeyedCoProcessFunction, 
AggregateFunction, WindowFunction,
                                           ProcessWindowFunction, 
BroadcastProcessFunction,
                                           KeyedBroadcastProcessFunction, 
AsyncFunction,
-                                          ResultFuture)
+                                          ResultFuture, AsyncRetryPredicate, 
AsyncRetryStrategy)
 from pyflink.datastream.slot_sharing_group import SlotSharingGroup, MemorySize
 from pyflink.datastream.state_backend import (StateBackend, CustomStateBackend,
                                               PredefinedOptions, 
HashMapStateBackend,
@@ -343,5 +345,7 @@ __all__ = [
     'SlotSharingGroup',
     'MemorySize',
     'OutputTag',
-    'ResultFuture'
+    'ResultFuture',
+    'AsyncRetryPredicate',
+    'AsyncRetryStrategy'
 ]
diff --git a/flink-python/pyflink/datastream/async_data_stream.py 
b/flink-python/pyflink/datastream/async_data_stream.py
index 6fc6817e892..0e14533c4d8 100644
--- a/flink-python/pyflink/datastream/async_data_stream.py
+++ b/flink-python/pyflink/datastream/async_data_stream.py
@@ -18,8 +18,9 @@
 import inspect
 
 from pyflink.common import Time, TypeInformation
+from pyflink.datastream import async_retry_strategies
 from pyflink.datastream.data_stream import DataStream, 
_get_one_input_stream_operator
-from pyflink.datastream.functions import AsyncFunctionDescriptor, AsyncFunction
+from pyflink.datastream.functions import AsyncFunctionDescriptor, 
AsyncFunction, AsyncRetryStrategy
 from pyflink.java_gateway import get_gateway
 from pyflink.util.java_utils import get_j_env_configuration
 
@@ -47,14 +48,39 @@ class AsyncDataStream(object):
         :param output_type: The output data type.
         :return: The transformed DataStream.
         """
-        AsyncDataStream._validate(data_stream, async_function)
+        return AsyncDataStream.unordered_wait_with_retry(
+            data_stream, async_function, timeout, 
async_retry_strategies.NO_RETRY_STRATEGY,
+            capacity, output_type)
+
+    @staticmethod
+    def unordered_wait_with_retry(
+            data_stream: DataStream,
+            async_function: AsyncFunction,
+            timeout: Time,
+            async_retry_strategy: AsyncRetryStrategy,
+            capacity: int = 100,
+            output_type: TypeInformation = None) -> 'DataStream':
+        """
+        Adds an async function with an AsyncRetryStrategy to support retry of 
AsyncFunction to the
+        data stream. The order of output stream records may be reordered.
+
+        :param data_stream: The input data stream.
+        :param async_function: The async function.
+        :param timeout: The timeout for the asynchronous operation to complete.
+        :param async_retry_strategy: The strategy of reattempt async i/o 
operation that can be
+                                     triggered
+        :param capacity: The max number of async i/o operation that can be 
triggered.
+        :param output_type: The output data type.
+        :return: The transformed DataStream.
+        """
+        AsyncDataStream._validate(data_stream, async_function, timeout, 
async_retry_strategy)
 
         from pyflink.fn_execution import flink_fn_execution_pb2
         j_python_data_stream_function_operator, j_output_type_info = \
             _get_one_input_stream_operator(
                 data_stream,
                 AsyncFunctionDescriptor(
-                    async_function, timeout, capacity,
+                    async_function, timeout, capacity, async_retry_strategy,
                     AsyncFunctionDescriptor.OutputMode.UNORDERED),
                 flink_fn_execution_pb2.UserDefinedDataStreamFunction.PROCESS,  
# type: ignore
                 output_type)
@@ -64,11 +90,21 @@ class AsyncDataStream(object):
             j_python_data_stream_function_operator))
 
     @staticmethod
-    def _validate(data_stream: DataStream, async_function: AsyncFunction) -> 
None:
+    def _validate(data_stream: DataStream, async_function: AsyncFunction,
+                  timeout: Time, async_retry_strategy: AsyncRetryStrategy) -> 
None:
         if not inspect.iscoroutinefunction(async_function.async_invoke):
             raise Exception("Method 'async_invoke' of class '%s' should be 
declared as 'async def'."
                             % type(async_function))
 
+        if async_retry_strategy is None:
+            raise Exception("Async retry strategy should not be None.")
+        if (async_retry_strategy != async_retry_strategies.NO_RETRY_STRATEGY 
and
+                timeout.to_milliseconds() <= 0):
+            raise Exception("Timeout should be configured when do async with 
retry.")
+        if async_retry_strategy.get_retry_predicate() is None:
+            raise Exception("Retry predicate of the async retry strategy '%s' 
is None."
+                            % type(async_retry_strategy))
+
         gateway = get_gateway()
         j_conf = 
get_j_env_configuration(data_stream._j_data_stream.getExecutionEnvironment())
         python_execution_mode = (
diff --git a/flink-python/pyflink/datastream/async_retry_strategies.py 
b/flink-python/pyflink/datastream/async_retry_strategies.py
new file mode 100644
index 00000000000..50333c817fb
--- /dev/null
+++ b/flink-python/pyflink/datastream/async_retry_strategies.py
@@ -0,0 +1,121 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from typing import TypeVar, Callable, Optional, List, Generic
+
+from pyflink.datastream.functions import AsyncRetryStrategy, 
AsyncRetryPredicate
+
+__all__ = [
+    'RetryPredicate',
+    'NoRetryStrategy',
+    'FixedDelayRetryStrategy',
+    'ExponentialBackoffDelayRetryStrategy'
+]
+
+
+OUT = TypeVar('OUT')
+
+
+class RetryPredicate(AsyncRetryPredicate, Generic[OUT]):
+
+    def __init__(self,
+                 result_predicate: Optional[Callable[[List[OUT]], bool]],
+                 exception_predicate: Optional[Callable[[Exception], bool]]):
+        self._result_predicate = result_predicate
+        self._exception_predicate = exception_predicate
+
+    def result_predicate(self) -> Optional[Callable[[List[OUT]], bool]]:
+        return self._result_predicate
+
+    def exception_predicate(self) -> Optional[Callable[[Exception], bool]]:
+        return self._exception_predicate
+
+
+class NoRetryStrategy(AsyncRetryStrategy, Generic[OUT]):
+
+    def can_retry(self, current_attempts: int) -> bool:
+        return False
+
+    def get_backoff_time_millis(self, current_attempts: int) -> int:
+        return -1
+
+    def get_retry_predicate(self) -> AsyncRetryPredicate[OUT]:
+        return RetryPredicate(None, None)
+
+
+class FixedDelayRetryStrategy(AsyncRetryStrategy, Generic[OUT]):
+
+    def __init__(self,
+                 max_attempts: int,
+                 backoff_time_millis: int,
+                 result_predicate: Optional[Callable[[List[OUT]], bool]],
+                 exception_predicate: Optional[Callable[[Exception], bool]]):
+        assert max_attempts > 0, "max_attempts should be greater than zero."
+        assert backoff_time_millis > 0, "backoff_time_millis should be greater 
than zero."
+        self._max_attempts = max_attempts
+        self._backoff_time_millis = backoff_time_millis
+        self._result_predicate = result_predicate
+        self._exception_predicate = exception_predicate
+
+    def can_retry(self, current_attempts: int) -> bool:
+        return current_attempts <= self._max_attempts
+
+    def get_backoff_time_millis(self, current_attempts: int) -> int:
+        return self._backoff_time_millis
+
+    def get_retry_predicate(self) -> AsyncRetryPredicate[OUT]:
+        return RetryPredicate(self._result_predicate, 
self._exception_predicate)
+
+
+class ExponentialBackoffDelayRetryStrategy(AsyncRetryStrategy, Generic[OUT]):
+
+    def __init__(self,
+                 max_attempts: int,
+                 initial_delay: int,
+                 max_retry_delay: int,
+                 multiplier: float,
+                 result_predicate: Optional[Callable[[List[OUT]], bool]],
+                 exception_predicate: Optional[Callable[[Exception], bool]]):
+        assert max_attempts > 0, "max_attempts should be greater than zero."
+        assert initial_delay > 0, "initial_delay should be greater than zero."
+        assert max_retry_delay > 0, "max_retry_delay should be greater than 
zero."
+        assert multiplier > 0, "multiplier should be greater than zero."
+        self._max_attempts = max_attempts
+        self._max_retry_delay = max_retry_delay
+        self._multiplier = multiplier
+        self._result_predicate = result_predicate
+        self._exception_predicate = exception_predicate
+        self._initial_delay = initial_delay
+        self._last_retry_delay = initial_delay
+
+    def can_retry(self, current_attempts: int) -> bool:
+        return current_attempts <= self._max_attempts
+
+    def get_backoff_time_millis(self, current_attempts: int) -> int:
+        if current_attempts <= 1:
+            self._last_retry_delay = self._initial_delay
+            return self._last_retry_delay
+
+        backoff = int(min(self._last_retry_delay * self._multiplier, 
self._max_retry_delay))
+        self._last_retry_delay = backoff
+        return backoff
+
+    def get_retry_predicate(self) -> AsyncRetryPredicate[OUT]:
+        return RetryPredicate(self._result_predicate, 
self._exception_predicate)
+
+
+# Module-level NoRetryStrategy instance that represents "retry disabled".
+NO_RETRY_STRATEGY: AsyncRetryStrategy = NoRetryStrategy()
diff --git a/flink-python/pyflink/datastream/functions.py 
b/flink-python/pyflink/datastream/functions.py
index ea31b99db0b..f107b11b266 100644
--- a/flink-python/pyflink/datastream/functions.py
+++ b/flink-python/pyflink/datastream/functions.py
@@ -15,12 +15,11 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 
################################################################################
-
 from abc import ABC, abstractmethod
 from enum import Enum
 
 from py4j.java_gateway import JavaObject
-from typing import Union, Any, Generic, TypeVar, Iterable, List
+from typing import Union, Any, Generic, TypeVar, Iterable, List, Callable, 
Optional
 
 from pyflink.datastream.state import ValueState, ValueStateDescriptor, 
ListStateDescriptor, \
     ListState, MapStateDescriptor, MapState, ReducingStateDescriptor, 
ReducingState, \
@@ -57,7 +56,9 @@ __all__ = [
     'KeyedBroadcastProcessFunction',
     'AsyncFunction',
     'AsyncFunctionDescriptor',
-    'ResultFuture'
+    'ResultFuture',
+    'AsyncRetryPredicate',
+    'AsyncRetryStrategy',
 ]
 
 
@@ -902,6 +903,82 @@ register a timer that will trigger an action in the future.
         pass
 
 
+class AsyncRetryPredicate(ABC, Generic[OUT]):
+    """
+    Interface encapsulates an asynchronous retry predicate.
+    """
+
+    def result_predicate(self) -> Optional[Callable[[List[OUT]], bool]]:
+        """
+        An optional Python predicate function that defines a condition on 
asyncFunction's future
+        result which will trigger a later reattempt operation, will be called 
before user's
+        ResultFuture#complete.
+        """
+        pass
+
+    def exception_predicate(self) -> Optional[Callable[[Exception], bool]]:
+        """
+        An optional Python predicate function that defines a condition on 
asyncFunction's exception
+        which will trigger a later reattempt operation, will be called before 
user's
+        ResultFuture#complete_exceptionally.
+        """
+        pass
+
+
+class AsyncRetryStrategy(ABC, Generic[OUT]):
+    """
+    Interface encapsulates an asynchronous retry strategy.
+    """
+
+    def can_retry(self, current_attempts: int) -> bool:
+        """
+        Whether the next attempt can happen.
+        """
+        pass
+
+    def get_backoff_time_millis(self, current_attempts: int) -> int:
+        """
+        The delay time of next attempt.
+        """
+        pass
+
+    def get_retry_predicate(self) -> AsyncRetryPredicate[OUT]:
+        """
+        Returns the defined retry predicate.
+        """
+        pass
+
+    @staticmethod
+    def no_restart():
+        from pyflink.datastream import async_retry_strategies
+        return async_retry_strategies.NO_RETRY_STRATEGY
+
+    @staticmethod
+    def fixed_delay(
+        max_attempts: int,
+        backoff_time_millis: int,
+        result_predicate: Optional[Callable[[List[OUT]], bool]],
+        exception_predicate: Optional[Callable[[Exception], bool]]
+    ):
+        from pyflink.datastream.async_retry_strategies import 
FixedDelayRetryStrategy
+        return FixedDelayRetryStrategy(
+            max_attempts, backoff_time_millis, result_predicate, 
exception_predicate)
+
+    @staticmethod
+    def exponential_backoff(
+        max_attempts: int,
+        initial_delay: int,
+        max_retry_delay: int,
+        multiplier: float,
+        result_predicate: Optional[Callable[[List[OUT]], bool]],
+        exception_predicate: Optional[Callable[[Exception], bool]]
+    ):
+        from pyflink.datastream.async_retry_strategies import 
ExponentialBackoffDelayRetryStrategy
+        return ExponentialBackoffDelayRetryStrategy(
+            max_attempts, initial_delay, max_retry_delay, multiplier,
+            result_predicate, exception_predicate)
+
+
 class ResultFuture(Generic[OUT]):
     """
     Collects data / error in user codes while processing async i/o.
@@ -972,10 +1049,11 @@ class AsyncFunctionDescriptor(object):
         ORDERED = 0
         UNORDERED = 1
 
-    def __init__(self, async_function, timeout, capacity, output_mode):
+    def __init__(self, async_function, timeout, capacity, 
async_retry_strategy, output_mode):
         self.async_function = async_function
         self.timeout = timeout
         self.capacity = capacity
+        self.async_retry_strategy = async_retry_strategy
         self.output_mode = output_mode
 
 
diff --git a/flink-python/pyflink/datastream/tests/test_async_function.py 
b/flink-python/pyflink/datastream/tests/test_async_function.py
index 79b261b9020..30fe8d50022 100644
--- a/flink-python/pyflink/datastream/tests/test_async_function.py
+++ b/flink-python/pyflink/datastream/tests/test_async_function.py
@@ -17,10 +17,11 @@
 
################################################################################
 import asyncio
 import random
+from typing import List
 
 from pyflink.common import Types, Row, Time, Configuration, WatermarkStrategy
 from pyflink.datastream import AsyncDataStream, AsyncFunction, ResultFuture, \
-    StreamExecutionEnvironment
+    StreamExecutionEnvironment, AsyncRetryStrategy
 from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction, \
     SecondColumnTimestampAssigner
 from pyflink.java_gateway import get_gateway
@@ -226,6 +227,53 @@ class AsyncFunctionTests(PyFlinkStreamingTestCase):
         expected = ['0', '0', '0', '0', '0']
         self.assert_equals_sorted(expected, results)
 
+    def test_async_with_retry(self):
+        """
+        Runs an async function through a fixed-delay retry strategy and checks that every
+        record is emitted once, with the value produced by its third invocation.
+        """
+        self.env.set_parallelism(1)
+        ds = self.env.from_collection(
+            [(1, 1), (2, 2), (3, 3)],
+            type_info=Types.ROW_NAMED(["v1", "v2"], [Types.INT(), Types.INT()])
+        )
+
+        class MyAsyncFunction(AsyncFunction):
+
+            def __init__(self):
+                # Records that have been invoked at least once / at least twice.
+                self.retries_1 = {}
+                self.retries_2 = {}
+
+            async def async_invoke(self, value: Row, result_future: 
ResultFuture[int]):
+                await asyncio.sleep(1)
+                if value in self.retries_2:
+                    # Third invocation: v1 + v2 is even for the inputs above, so
+                    # result_predicate is False and the value is emitted.
+                    result_future.complete([value[0] + value[1]])
+                elif value in self.retries_1:
+                    # Second invocation: v1 + v2 + 1 is odd, which satisfies result_predicate.
+                    self.retries_2[value] = True
+                    result_future.complete([value[0] + value[1] + 1])
+                else:
+                    # First invocation: fail with the message exception_predicate looks for.
+                    self.retries_1[value] = True
+                    result_future.complete_exceptionally(ValueError("failed 
the first time"))
+
+            def timeout(self, value: Row, result_future: ResultFuture[int]):
+                result_future.complete([value[0] + value[1]])
+
+        # Retry whenever the first result element is odd.
+        def result_predicate(result: List[int]):
+            return result[0] % 2 == 1
+
+        # Retry whenever the error message is the one raised on the first invocation.
+        def exception_predicate(exception: Exception):
+            return "failed the first time" in str(exception)
+
+        async_retry_strategy = AsyncRetryStrategy.fixed_delay(
+            max_attempts=5,
+            backoff_time_millis=1000,
+            result_predicate=result_predicate,
+            exception_predicate=exception_predicate
+        )
+        ds = AsyncDataStream.unordered_wait_with_retry(
+            ds, MyAsyncFunction(), Time.seconds(10), async_retry_strategy, 2, 
Types.INT())
+        ds.add_sink(self.test_sink)
+        self.env.execute()
+        results = self.test_sink.get_results(False)
+        # v1 + v2 for each of the three input rows.
+        expected = ['2', '4', '6']
+        self.assert_equals_sorted(expected, results)
+
 
 class EmbeddedThreadAsyncFunctionTests(PyFlinkStreamingTestCase):
 
diff --git 
a/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py
 
b/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py
index b60c5ddf02f..467ebc05ded 100644
--- 
a/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py
+++ 
b/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py
@@ -18,10 +18,11 @@
 import asyncio
 import pickle
 import threading
-from typing import TypeVar, Generic, List, Iterable, Callable
+from datetime import datetime
+from typing import TypeVar, Generic, List, Iterable, Callable, Optional
 
 from pyflink.datastream import RuntimeContext, ResultFuture
-from pyflink.datastream.functions import AsyncFunctionDescriptor
+from pyflink.datastream.functions import AsyncFunctionDescriptor, 
AsyncRetryStrategy
 from pyflink.fn_execution.datastream.process.async_function.queue import \
     UnorderedStreamElementQueue, StreamElementQueue
 from pyflink.fn_execution.datastream.process.operations import Operation
@@ -117,6 +118,83 @@ class ResultHandler(ResultFuture, Generic[IN, OUT]):
             self._timeout_func(self._record, self)
 
 
+class RetryableResultHandler(ResultFuture, Generic[IN, OUT]):
+
+    def __init__(self,
+                 result_handler: ResultHandler[IN, OUT],
+                 async_invoke_func_runner: Callable[[IN, ResultFuture[[OUT]]], 
None],
+                 retry_strategy: AsyncRetryStrategy[OUT]):
+        self._result_handler = result_handler
+        self._async_invoke_func_runner = async_invoke_func_runner
+        self._retry_strategy = retry_strategy
+        self._retry_result_predicate = \
+            retry_strategy.get_retry_predicate().result_predicate() or (lambda 
_: False)
+        self._retry_exception_predicate = \
+            retry_strategy.get_retry_predicate().exception_predicate() or 
(lambda _: False)
+        self._retry_awaiting = AtomicBoolean(False)
+        self._current_attempts = 1
+
+    def register_timeout(self, timeout):
+        timer = threading.Timer(timeout, self._timer_triggered)
+        timer.start()
+        self._result_handler._timer = timer
+        self._start_ts = datetime.now()
+        self._timeout = timeout
+
+    def complete(self, result: List[OUT]):
+        self._process_retry(result, None)
+
+    def complete_exceptionally(self, error: Exception):
+        self._process_retry(None, error)
+
+    def _process_retry(self, result: Optional[List[OUT]], error: 
Optional[Exception]):
+        if not self._retry_awaiting.compare_and_set(False, True):
+            return
+
+        satisfy = ((result is not None and 
self._retry_result_predicate(result)) or
+                   (error is not None and 
self._retry_exception_predicate(error)))
+
+        if (not self._is_timeout() and satisfy and
+                self._retry_strategy.can_retry(self._current_attempts)):
+
+            next_backoff_time_sec = 
self._retry_strategy.get_backoff_time_millis(
+                self._current_attempts) / 1000
+            self._delayed_retry_timer = threading.Timer(next_backoff_time_sec, 
self._do_retry)
+            self._delayed_retry_timer.start()
+        else:
+            if result is not None:
+                self._result_handler.complete(result)
+            else:
+                self._result_handler.complete_exceptionally(error)
+
+    def _is_timeout(self) -> bool:
+        diff = datetime.now() - self._start_ts
+        return diff.total_seconds() > self._timeout
+
+    def _do_retry(self):
+        if self._retry_awaiting.compare_and_set(True, False):
+            self._current_attempts += 1
+            self._async_invoke_func_runner(self._result_handler._record, self)
+
+    def _cancel_retry_timer(self):
+        if self._delayed_retry_timer is not None:
+            self._delayed_retry_timer.cancel()
+            self._delayed_retry_timer = None
+
+    def _timer_triggered(self):
+        """
+        Rewrite the timeout process to deal with retry state.
+        """
+        if not self._result_handler._completed.get():
+            # cancel delayed retry timer first
+            self._cancel_retry_timer()
+
+            # force reset _retry_awaiting to prevent the handler to trigger 
retry unnecessarily
+            self._retry_awaiting.set(False)
+
+            self._result_handler._timeout_func(self._result_handler._record, 
self)
+
+
 class Emitter(threading.Thread):
 
     def __init__(self,
@@ -195,6 +273,7 @@ class AsyncOperation(Operation):
             self.timeout_func,
             self._timeout,
             capacity,
+            self._async_retry_strategy,
             output_mode
         ) = extract_async_function(
             user_defined_function_proto=serialized_fn,
@@ -202,7 +281,13 @@ class AsyncOperation(Operation):
                 serialized_fn.runtime_context, self.base_metric_group
             )
         )
-        self._retry_enabled = False
+
+        self._result_predicate = 
self._async_retry_strategy.get_retry_predicate().result_predicate()
+        self._exception_predicate = (
+            
self._async_retry_strategy.get_retry_predicate().exception_predicate())
+        self._retry_enabled = (self._result_predicate is not None or
+                               self._exception_predicate is not None)
+
         if output_mode == AsyncFunctionDescriptor.OutputMode.UNORDERED:
             self._queue = UnorderedStreamElementQueue(capacity, 
self._raise_exception_if_exists)
         else:
@@ -251,13 +336,19 @@ class AsyncOperation(Operation):
         entry = self._queue.put(windowed_value, timestamp, watermark, record)
 
         if self._retry_enabled:
-            raise NotImplementedError
+            result_handler = ResultHandler(
+                self.class_name, self.timeout_func, self._mark_exception, 
record, entry)
+            retryable_result_handler = RetryableResultHandler(
+                result_handler, self._async_invoke_func_runner, 
self._async_retry_strategy)
+            # timeout is always > 0
+            retryable_result_handler.register_timeout(self._timeout)
+            self._async_invoke_func_runner(record, retryable_result_handler)
         else:
             result_handler = ResultHandler(
                 self.class_name, self.timeout_func, self._mark_exception, 
record, entry)
             if self._timeout > 0:
                 result_handler.register_timeout(self._timeout)
-            self._async_function_runner.run_async(self.async_invoke_func, 
record, result_handler)
+            self._async_invoke_func_runner(record, result_handler)
 
     def finish(self):
         self._wait_for_in_flight_inputs_finished()
@@ -275,6 +366,9 @@ class AsyncOperation(Operation):
         if self._exception is not None:
             raise self._exception
 
+    def _async_invoke_func_runner(self, record, result_handler):
+        self._async_function_runner.run_async(self.async_invoke_func, record, 
result_handler)
+
 
 def extract_async_function(user_defined_function_proto, runtime_context: 
RuntimeContext):
     """
@@ -289,6 +383,7 @@ def extract_async_function(user_defined_function_proto, 
runtime_context: Runtime
     class_name = type(async_function)
     timeout = async_function_descriptor.timeout.to_milliseconds() / 1000
     capacity = async_function_descriptor.capacity
+    async_retry_strategy = async_function_descriptor.async_retry_strategy
     output_mode = async_function_descriptor.output_mode
 
     def open_func():
@@ -303,4 +398,4 @@ def extract_async_function(user_defined_function_proto, 
runtime_context: Runtime
     timeout_func = async_function.timeout
 
     return (class_name, open_func, close_func, async_invoke_func, 
timeout_func, timeout, capacity,
-            output_mode)
+            async_retry_strategy, output_mode)

Reply via email to