This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4789a1f0568 Support Reschedule mode for `BaseSensorOperator` with Task SDK (#48193)
4789a1f0568 is described below

commit 4789a1f05681a534250012a20a13ddbd3367c70c
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Mar 25 00:34:59 2025 +0530

    Support Reschedule mode for `BaseSensorOperator` with Task SDK (#48193)
    
    closes https://github.com/apache/airflow/issues/45580
    
    closes https://github.com/apache/airflow/issues/48088
    
    
    Example of a working DAG:
    <img width="1707" alt="image" src="https://github.com/user-attachments/assets/bb904395-b088-464e-b518-4335f6466cd4" />
    
    DAG Used:
    
    ```py
    import pendulum
    
    from airflow.decorators import dag, task
    from airflow.sdk import Variable
    
    
    @dag()
    def example_sensor_decorator():
    
        @task.sensor(poke_interval=5, timeout=60, mode="reschedule")
        def wait_for_upstream():
    
            if Variable.get("sensor_val", "") == "xcom_value":
                return True
            return False
    
        @task
        def dummy_operator() -> None:
            pass
    
        wait_for_upstream() >> dummy_operator()
    
    
    tutorial_etl_dag = example_sensor_decorator()
    ```
    
    **Change**:
    `BaseSensorOperator` needed some way to know the `start_date` of the first
    reschedule for a specific `ti_id`. Currently, it gets that by querying the
    metadata DB directly. To do so, it needs to know the first `try_number`, which
    is derived from `retries` and `max_retries`.
    
    
https://github.com/apache/airflow/blob/d264a8616ff79b0ca6d2fcb2c0140695314db910/airflow-core/src/airflow/sensors/base.py#L220-L230
    
    I needed some way to know `retries` in the API server to compute the
    `try_number` used to look up the Task Reschedule.
    
    The options I considered were:
    
    1) Pass `retries` in `TIEnterRunningPayload`.
    2) Get `retries` from the `serialized_dag` table.
    3) Fetch it lazily, only when needed, since it is only required for sensors
       (the approach taken in this PR).
---
 .../api_fastapi/execution_api/routes/__init__.py   |  4 +
 .../execution_api/routes/task_instances.py         | 17 ++++
 .../execution_api/routes/task_reschedules.py       | 56 ++++++++++++++
 airflow-core/src/airflow/models/taskinstance.py    | 24 ++++++
 airflow-core/src/airflow/sensors/base.py           | 46 +----------
 .../execution_api/routes/test_task_instances.py    | 90 ++++++++++++++++++++++
 airflow-core/tests/unit/sensors/test_base.py       | 22 ------
 task-sdk/src/airflow/sdk/api/client.py             |  6 ++
 task-sdk/src/airflow/sdk/execution_time/comms.py   | 13 ++++
 .../src/airflow/sdk/execution_time/supervisor.py   |  4 +
 .../src/airflow/sdk/execution_time/task_runner.py  | 29 +++++++
 task-sdk/src/airflow/sdk/types.py                  |  2 +
 task-sdk/tests/conftest.py                         |  2 +
 task-sdk/tests/task_sdk/api/test_client.py         | 30 +++++++-
 .../task_sdk/execution_time/test_supervisor.py     | 11 +++
 .../task_sdk/execution_time/test_task_runner.py    | 28 +++++++
 tests/sdk/__init__.py                              | 16 ++++
 tests/sdk/api/__init__.py                          | 16 ++++
 tests/sdk/execution_time/__init__.py               | 16 ++++
 19 files changed, 365 insertions(+), 67 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
index 6610d5552f2..6b6c29b2bab 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
@@ -27,6 +27,7 @@ from airflow.api_fastapi.execution_api.routes import (
     dag_runs,
     health,
     task_instances,
+    task_reschedules,
     variables,
     xcoms,
 )
@@ -42,6 +43,9 @@ authenticated_router.include_router(asset_events.router, 
prefix="/asset-events",
 authenticated_router.include_router(connections.router, prefix="/connections", 
tags=["Connections"])
 authenticated_router.include_router(dag_runs.router, prefix="/dag-runs", 
tags=["Dag Runs"])
 authenticated_router.include_router(task_instances.router, 
prefix="/task-instances", tags=["Task Instances"])
+authenticated_router.include_router(
+    task_reschedules.router, prefix="/task-reschedules", tags=["Task 
Reschedules"]
+)
 authenticated_router.include_router(variables.router, prefix="/variables", 
tags=["Variables"])
 authenticated_router.include_router(xcoms.router, prefix="/xcoms", 
tags=["XComs"])
 
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index a4f20b18cdf..afc514f08f1 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -357,6 +357,23 @@ def ti_update_state(
         )
         updated_state = TaskInstanceState.DEFERRED
     elif isinstance(ti_patch_payload, TIRescheduleStatePayload):
+        # Quick check for poke_interval isn't immediately over MySQL's 
TIMESTAMP limit.
+        # This check is only rudimentary to catch trivial user errors, e.g. 
mistakenly
+        # set the value to milliseconds instead of seconds. There's another 
check when
+        # we actually try to reschedule to ensure database coherence.
+        if session.get_bind().dialect.name == "mysql":
+            # As documented in 
https://dev.mysql.com/doc/refman/5.7/en/datetime.html.
+            _MYSQL_TIMESTAMP_MAX = timezone.datetime(2038, 1, 19, 3, 14, 7)
+            if ti_patch_payload.reschedule_date > _MYSQL_TIMESTAMP_MAX:
+                raise HTTPException(
+                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                    detail={
+                        "reason": "invalid_reschedule_date",
+                        "message": f"Cannot reschedule to 
{ti_patch_payload.reschedule_date.isoformat()} "
+                        f"since it is over MySQL's TIMESTAMP storage limit.",
+                    },
+                )
+
         task_instance = session.get(TI, ti_id_str)
         actual_start_date = timezone.utcnow()
         session.add(
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py
new file mode 100644
index 00000000000..f31ba38ccbd
--- /dev/null
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Annotated
+from uuid import UUID
+
+from fastapi import HTTPException, Query, status
+from sqlalchemy import select
+
+from airflow.api_fastapi.common.db.common import SessionDep
+from airflow.api_fastapi.common.router import AirflowRouter
+from airflow.api_fastapi.common.types import UtcDateTime
+from airflow.models.taskreschedule import TaskReschedule
+
+router = AirflowRouter(
+    responses={
+        status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
+        status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
+    },
+)
+
+
[email protected]("/{task_instance_id}/start_date")
+def get_start_date(
+    task_instance_id: UUID, session: SessionDep, try_number: Annotated[int, 
Query()] = 1
+) -> UtcDateTime:
+    start_date = session.scalar(
+        select(TaskReschedule)
+        .where(
+            TaskReschedule.ti_id == str(task_instance_id),
+            TaskReschedule.try_number >= try_number,
+        )
+        .order_by(TaskReschedule.id.asc())
+        .with_only_columns(TaskReschedule.start_date)
+        .limit(1)
+    )
+    if start_date is None:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
+
+    return start_date
diff --git a/airflow-core/src/airflow/models/taskinstance.py 
b/airflow-core/src/airflow/models/taskinstance.py
index 107454c97d7..1d6a540073e 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -3688,6 +3688,30 @@ class TaskInstance(Base, LoggingMixin):
         }
         return asset_unique_keys - active_asset_unique_keys
 
+    def get_first_reschedule_date(self, context: Context) -> datetime | None:
+        """Get the first reschedule date for the task instance."""
+        # TODO: AIP-72: Remove this after `ti.run` is migrated to use Task SDK
+        max_tries: int = self.max_tries or 0
+
+        if TYPE_CHECKING:
+            assert isinstance(self.task, BaseOperator)
+
+        retries: int = self.task.retries or 0
+        first_try_number = max_tries - retries + 1
+
+        with create_session() as session:
+            start_date = session.scalar(
+                select(TaskReschedule)
+                .where(
+                    TaskReschedule.ti_id == str(self.id),
+                    TaskReschedule.try_number >= first_try_number,
+                )
+                .order_by(TaskReschedule.id.asc())
+                .with_only_columns(TaskReschedule.start_date)
+                .limit(1)
+            )
+        return start_date
+
 
 def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> 
MappedTaskGroup | None:
     """Given two operators, find their innermost common mapped task group."""
diff --git a/airflow-core/src/airflow/sensors/base.py 
b/airflow-core/src/airflow/sensors/base.py
index ad69eb2b914..167eafa45f4 100644
--- a/airflow-core/src/airflow/sensors/base.py
+++ b/airflow-core/src/airflow/sensors/base.py
@@ -26,8 +26,6 @@ from collections.abc import Iterable
 from datetime import timedelta
 from typing import TYPE_CHECKING, Any, Callable
 
-from sqlalchemy import select
-
 from airflow import settings
 from airflow.configuration import conf
 from airflow.exceptions import (
@@ -42,10 +40,8 @@ from airflow.exceptions import (
 )
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.models.baseoperator import BaseOperator
-from airflow.models.taskreschedule import TaskReschedule
 from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
 from airflow.utils import timezone
-from airflow.utils.session import create_session
 
 if TYPE_CHECKING:
     from airflow.sdk.definitions.context import Context
@@ -199,17 +195,6 @@ class BaseSensorOperator(BaseOperator):
                 f".{self.task_id}'; received '{self.mode}'."
             )
 
-        # Quick check for poke_interval isn't immediately over MySQL's 
TIMESTAMP limit.
-        # This check is only rudimentary to catch trivial user errors, e.g. 
mistakenly
-        # set the value to milliseconds instead of seconds. There's another 
check when
-        # we actually try to reschedule to ensure database coherence.
-        if self.reschedule and _is_metadatabase_mysql():
-            if timezone.utcnow() + 
datetime.timedelta(seconds=self.poke_interval) > _MYSQL_TIMESTAMP_MAX:
-                raise AirflowException(
-                    f"Cannot set poke_interval to {self.poke_interval} seconds 
in reschedule "
-                    f"mode since it will take reschedule time over MySQL's 
TIMESTAMP limit."
-                )
-
     def poke(self, context: Context) -> bool | PokeReturnValue:
         """Override when deriving this class."""
         raise AirflowException("Override me.")
@@ -219,29 +204,8 @@ class BaseSensorOperator(BaseOperator):
 
         if self.reschedule:
             ti = context["ti"]
-            max_tries: int = ti.max_tries or 0
-            retries: int = self.retries or 0
-
-            # If reschedule, use the start date of the first try (first try 
can be either the very
-            # first execution of the task, or the first execution after the 
task was cleared).
-            # If the first try's record was not saved due to the Exception 
occurred and the following
-            # transaction rollback, the next available attempt should be taken
-            # to prevent falling in the endless rescheduling
-            first_try_number = max_tries - retries + 1
-            with create_session() as session:
-                start_date = session.scalar(
-                    select(TaskReschedule)
-                    .where(
-                        TaskReschedule.ti_id == str(ti.id),
-                        TaskReschedule.try_number >= first_try_number,
-                    )
-                    .order_by(TaskReschedule.id.asc())
-                    .with_only_columns(TaskReschedule.start_date)
-                    .limit(1)
-                )
-            if not start_date:
-                start_date = timezone.utcnow()
-            started_at = start_date
+            first_reschedule_date = ti.get_first_reschedule_date(context)
+            started_at = start_date = first_reschedule_date or 
timezone.utcnow()
 
             def run_duration() -> float:
                 # If we are in reschedule mode, then we have to compute diff
@@ -255,7 +219,6 @@ class BaseSensorOperator(BaseOperator):
                 return time.monotonic() - start_monotonic
 
         poke_count = 1
-        log_dag_id = self.dag.dag_id if self.has_dag() else ""
 
         xcom_value = None
         while True:
@@ -301,11 +264,6 @@ class BaseSensorOperator(BaseOperator):
             if self.reschedule:
                 next_poke_interval = self._get_next_poke_interval(started_at, 
run_duration, poke_count)
                 reschedule_date = timezone.utcnow() + 
timedelta(seconds=next_poke_interval)
-                if _is_metadatabase_mysql() and reschedule_date > 
_MYSQL_TIMESTAMP_MAX:
-                    raise AirflowSensorTimeout(
-                        f"Cannot reschedule DAG {log_dag_id} to 
{reschedule_date.isoformat()} "
-                        f"since it is over MySQL's TIMESTAMP storage limit."
-                    )
                 raise AirflowRescheduleException(reschedule_date)
             else:
                 time.sleep(self._get_next_poke_interval(started_at, 
run_duration, poke_count))
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_task_instances.py
index 578a019e252..2ab0f943e6d 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_task_instances.py
@@ -712,6 +712,37 @@ class TestTIUpdateState:
         assert trs[0].task_instance.map_index == -1
         assert trs[0].duration == 129600
 
+    @pytest.mark.backend("mysql")
+    def test_ti_update_state_reschedule_mysql_limit(
+        self, client, session, create_task_instance, time_machine
+    ):
+        """Test that the reschedule date is validated against MySQL's 
TIMESTAMP limit."""
+        instant = timezone.datetime(2024, 10, 30)
+        time_machine.move_to(instant, tick=False)
+
+        ti = create_task_instance(
+            task_id="test_ti_update_state_reschedule_mysql_limit",
+            state=State.RUNNING,
+            session=session,
+        )
+        ti.start_date = instant
+        session.commit()
+
+        # Date beyond MySQL's TIMESTAMP limit (2038-01-19 03:14:07)
+        future_date = timezone.datetime(2038, 1, 19, 3, 14, 8)
+
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/state",
+            json={
+                "state": TaskInstanceState.UP_FOR_RESCHEDULE,
+                "reschedule_date": future_date.isoformat(),
+                "end_date": DEFAULT_END_DATE.isoformat(),
+            },
+        )
+
+        assert response.status_code == 422
+        assert response.json()["detail"]["reason"] == "invalid_reschedule_date"
+
     def test_ti_update_state_handle_retry(self, client, session, 
create_task_instance):
         ti = create_task_instance(
             task_id="test_ti_update_state_to_retry",
@@ -1129,3 +1160,62 @@ class TestPreviousDagRun:
             "start_date": None,
             "end_date": None,
         }
+
+
+class TestGetRescheduleStartDate:
+    def test_get_start_date(self, client, session, create_task_instance):
+        ti = create_task_instance(
+            task_id="test_ti_update_state_reschedule_mysql_limit",
+            state=State.RUNNING,
+            start_date=timezone.datetime(2024, 1, 1),
+            session=session,
+        )
+        tr = TaskReschedule(
+            task_instance_id=ti.id,
+            try_number=1,
+            start_date=timezone.datetime(2024, 1, 1),
+            end_date=timezone.datetime(2024, 1, 1, 1),
+            reschedule_date=timezone.datetime(2024, 1, 1, 2),
+        )
+        session.add(tr)
+        session.commit()
+
+        response = 
client.get(f"/execution/task-reschedules/{ti.id}/start_date")
+        assert response.status_code == 200
+        assert response.json() == "2024-01-01T00:00:00Z"
+
+    def test_get_start_date_not_found(self, client):
+        ti_id = "0182e924-0f1e-77e6-ab50-e977118bc139"
+        response = 
client.get(f"/execution/task-reschedules/{ti_id}/start_date")
+        assert response.status_code == 404
+
+    def test_get_start_date_with_try_number(self, client, session, 
create_task_instance):
+        # Create multiple reschedules
+        dates = [
+            timezone.datetime(2024, 1, 1),
+            timezone.datetime(2024, 1, 2),
+            timezone.datetime(2024, 1, 3),
+        ]
+
+        ti = create_task_instance(
+            task_id="test_get_start_date_with_try_number",
+            state=State.RUNNING,
+            start_date=timezone.datetime(2024, 1, 1),
+            session=session,
+        )
+
+        for i, date in enumerate(dates, 1):
+            tr = TaskReschedule(
+                task_instance_id=ti.id,
+                try_number=i,
+                start_date=date,
+                end_date=date.replace(hour=1),
+                reschedule_date=date.replace(hour=2),
+            )
+            session.add(tr)
+        session.commit()
+
+        # Test getting start date for try_number 2
+        response = 
client.get(f"/execution/task-reschedules/{ti.id}/start_date?try_number=2")
+        assert response.status_code == 200
+        assert response.json() == "2024-01-02T00:00:00Z"
diff --git a/airflow-core/tests/unit/sensors/test_base.py 
b/airflow-core/tests/unit/sensors/test_base.py
index f1bedee703e..4b61b648d3d 100644
--- a/airflow-core/tests/unit/sensors/test_base.py
+++ b/airflow-core/tests/unit/sensors/test_base.py
@@ -739,28 +739,6 @@ class TestBaseSensor:
             for idx, expected in enumerate([2, 6, 13, 30, 30, 30, 30, 30]):
                 assert sensor._get_next_poke_interval(started_at, 
run_duration, idx) == expected
 
-    @pytest.mark.backend("mysql")
-    def test_reschedule_poke_interval_too_long_on_mysql(self, make_sensor):
-        with pytest.raises(AirflowException) as ctx:
-            make_sensor(poke_interval=863998946, mode="reschedule", 
return_value="irrelevant")
-        assert str(ctx.value) == (
-            "Cannot set poke_interval to 863998946.0 seconds in reschedule 
mode "
-            "since it will take reschedule time over MySQL's TIMESTAMP limit."
-        )
-
-    @pytest.mark.backend("mysql")
-    def test_reschedule_date_too_late_on_mysql(self, make_sensor):
-        sensor, _ = make_sensor(poke_interval=60 * 60 * 24, mode="reschedule", 
return_value=False)
-
-        # A few hours until TIMESTAMP's limit, the next poke will take us over.
-        with time_machine.travel(datetime(2038, 1, 19, tzinfo=timezone.utc), 
tick=False):
-            with pytest.raises(AirflowSensorTimeout) as ctx:
-                self._run(sensor)
-        assert str(ctx.value) == (
-            "Cannot reschedule DAG unit_test_dag to 2038-01-20T00:00:00+00:00 "
-            "since it is over MySQL's TIMESTAMP storage limit."
-        )
-
     def test_reschedule_and_retry_timeout(self, make_sensor, time_machine, 
session):
         """
         Test mode="reschedule", retries and timeout configurations interact 
correctly.
diff --git a/task-sdk/src/airflow/sdk/api/client.py 
b/task-sdk/src/airflow/sdk/api/client.py
index 1822748ad72..6dc5123530b 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -64,6 +64,7 @@ from airflow.sdk.execution_time.comms import (
     OKResponse,
     RuntimeCheckOnTask,
     SkipDownstreamTasks,
+    TaskRescheduleStartDate,
 )
 from airflow.utils.net import get_hostname
 from airflow.utils.platform import getuser
@@ -209,6 +210,11 @@ class TaskInstanceOperations:
                 return OKResponse(ok=True)
             raise
 
+    def get_reschedule_start_date(self, id: uuid.UUID, try_number: int = 1) -> 
TaskRescheduleStartDate:
+        """Get the start date of a task reschedule via the API server."""
+        resp = self.client.get(f"task-reschedules/{id}/start_date", 
params={"try_number": try_number})
+        return TaskRescheduleStartDate.model_construct(start_date=resp.json())
+
 
 class ConnectionOperations:
     __slots__ = ("client",)
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py 
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index de20055b989..20867b22ca5 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -216,6 +216,11 @@ class 
PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse):
         return cls(**prev_dag_run.model_dump(exclude_defaults=True), 
type="PrevSuccessfulDagRunResult")
 
 
+class TaskRescheduleStartDate(BaseModel):
+    start_date: datetime
+    type: Literal["TaskRescheduleStartDate"] = "TaskRescheduleStartDate"
+
+
 class ErrorResponse(BaseModel):
     error: ErrorType = ErrorType.GENERIC_ERROR
     detail: dict | None = None
@@ -236,6 +241,7 @@ ToTask = Annotated[
         ErrorResponse,
         PrevSuccessfulDagRunResult,
         StartupDetails,
+        TaskRescheduleStartDate,
         VariableResult,
         XComResult,
         XComCountResponse,
@@ -435,6 +441,12 @@ class GetPrevSuccessfulDagRun(BaseModel):
     type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"
 
 
+class GetTaskRescheduleStartDate(BaseModel):
+    ti_id: UUID
+    try_number: int = 1
+    type: Literal["GetTaskRescheduleStartDate"] = "GetTaskRescheduleStartDate"
+
+
 ToSupervisor = Annotated[
     Union[
         SucceedTask,
@@ -446,6 +458,7 @@ ToSupervisor = Annotated[
         GetConnection,
         GetDagRunState,
         GetPrevSuccessfulDagRun,
+        GetTaskRescheduleStartDate,
         GetVariable,
         GetXCom,
         GetXComCount,
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 2bc88563e1a..33ac30a1f6f 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -75,6 +75,7 @@ from airflow.sdk.execution_time.comms import (
     GetConnection,
     GetDagRunState,
     GetPrevSuccessfulDagRun,
+    GetTaskRescheduleStartDate,
     GetVariable,
     GetXCom,
     GetXComCount,
@@ -962,6 +963,9 @@ class ActivitySubprocess(WatchedSubprocess):
         elif isinstance(msg, GetDagRunState):
             dr_resp = self.client.dag_runs.get_state(msg.dag_id, msg.run_id)
             resp = 
DagRunStateResult.from_api_response(dr_resp).model_dump_json().encode()
+        elif isinstance(msg, GetTaskRescheduleStartDate):
+            tr_resp = 
self.client.task_instances.get_reschedule_start_date(msg.ti_id, msg.try_number)
+            resp = tr_resp.model_dump_json().encode()
         else:
             log.error("Unhandled request", msg=msg)
             return
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 31d44f1116f..1feaa494034 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -58,6 +58,7 @@ from airflow.sdk.execution_time.comms import (
     DeferTask,
     ErrorResponse,
     GetDagRunState,
+    GetTaskRescheduleStartDate,
     OKResponse,
     RescheduleTask,
     RetryTask,
@@ -66,6 +67,7 @@ from airflow.sdk.execution_time.comms import (
     SkipDownstreamTasks,
     StartupDetails,
     SucceedTask,
+    TaskRescheduleStartDate,
     TaskState,
     ToSupervisor,
     ToTask,
@@ -361,6 +363,33 @@ class RuntimeTaskInstance(TaskInstance):
         # TODO: Implement this method
         return None
 
+    def get_first_reschedule_date(self, context: Context) -> datetime | None:
+        """Get the first reschedule date for the task instance."""
+        if context.get("task_reschedule_count", 0) == 0:
+            # If the task has not been rescheduled, there is no need to ask 
the supervisor
+            return None
+
+        max_tries: int = self.max_tries
+        retries: int = self.task.retries or 0
+        first_try_number = max_tries - retries + 1
+
+        log = structlog.get_logger(logger_name="task")
+
+        log.debug("Requesting first reschedule date from supervisor")
+
+        SUPERVISOR_COMMS.send_request(
+            log=log, msg=GetTaskRescheduleStartDate(ti_id=self.id, 
try_number=first_try_number)
+        )
+        response = SUPERVISOR_COMMS.get_message()
+
+        if TYPE_CHECKING:
+            assert isinstance(response, TaskRescheduleStartDate)
+
+        start_date = response.start_date
+        log.debug("First reschedule date from supervisor: %s", start_date)
+
+        return start_date
+
 
 def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: 
int | None = None) -> None:
     # Private function, as we don't want to expose the ability to manually set 
`mapped_length` to SDK
diff --git a/task-sdk/src/airflow/sdk/types.py 
b/task-sdk/src/airflow/sdk/types.py
index 5c1c4b88c88..6c3c40f2ab9 100644
--- a/task-sdk/src/airflow/sdk/types.py
+++ b/task-sdk/src/airflow/sdk/types.py
@@ -83,6 +83,8 @@ class RuntimeTaskInstanceProtocol(Protocol):
 
     def get_template_context(self) -> Context: ...
 
+    def get_first_reschedule_date(self, first_try_number) -> datetime | None: 
...
+
 
 class OutletEventAccessorProtocol(Protocol, attrs.AttrsInstance):
     """Protocol for managing access to a specific outlet event accessor."""
diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py
index ee176c09c41..d1b6f29a038 100644
--- a/task-sdk/tests/conftest.py
+++ b/task-sdk/tests/conftest.py
@@ -364,6 +364,7 @@ def create_runtime_ti(mocked_parse, make_ti_context):
         try_number: int = 1,
         map_index: int | None = -1,
         upstream_map_indexes: dict[str, int] | None = None,
+        task_reschedule_count: int = 0,
         ti_id=None,
         conf=None,
     ) -> RuntimeTaskInstance:
@@ -382,6 +383,7 @@ def create_runtime_ti(mocked_parse, make_ti_context):
             start_date=start_date,
             run_type=run_type,
             conf=conf,
+            task_reschedule_count=task_reschedule_count,
         )
 
         if upstream_map_indexes is not None:
diff --git a/task-sdk/tests/task_sdk/api/test_client.py 
b/task-sdk/tests/task_sdk/api/test_client.py
index b7f48d0d3a5..5549178a794 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -35,7 +35,13 @@ from airflow.sdk.api.datamodels._generated import (
     XComResponse,
 )
 from airflow.sdk.exceptions import ErrorType
-from airflow.sdk.execution_time.comms import DeferTask, ErrorResponse, 
OKResponse, RescheduleTask
+from airflow.sdk.execution_time.comms import (
+    DeferTask,
+    ErrorResponse,
+    OKResponse,
+    RescheduleTask,
+    TaskRescheduleStartDate,
+)
 from airflow.utils import timezone
 from airflow.utils.state import TerminalTIState
 
@@ -813,3 +819,25 @@ class TestDagRunOperations:
         result = client.dag_runs.get_state(dag_id="test_state", 
run_id="test_run_id")
 
         assert result == DagRunStateResponse(state=DagRunState.RUNNING)
+
+
+class TestTaskRescheduleOperations:
+    def test_get_start_date(self):
+        """Test that the client can get the start date of a task reschedule"""
+        ti_id = uuid6.uuid7()
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == f"/task-reschedules/{ti_id}/start_date":
+                assert request.url.params.get("try_number") == "1"
+
+                return httpx.Response(
+                    status_code=200,
+                    json="2024-01-01T00:00:00Z",
+                )
+            return httpx.Response(status_code=422)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.task_instances.get_reschedule_start_date(id=ti_id, 
try_number=1)
+
+        assert isinstance(result, TaskRescheduleStartDate)
+        assert result.start_date == "2024-01-01T00:00:00Z"
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py 
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 5d06fe7992b..9751932b8e8 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -69,6 +69,7 @@ from airflow.sdk.execution_time.comms import (
     GetConnection,
     GetDagRunState,
     GetPrevSuccessfulDagRun,
+    GetTaskRescheduleStartDate,
     GetVariable,
     GetXCom,
     OKResponse,
@@ -79,6 +80,7 @@ from airflow.sdk.execution_time.comms import (
     SetRenderedFields,
     SetXCom,
     SucceedTask,
+    TaskRescheduleStartDate,
     TaskState,
     TriggerDagRun,
     VariableResult,
@@ -1344,6 +1346,15 @@ class TestHandleRequest:
                 DagRunStateResult(state=DagRunState.RUNNING),
                 id="get_dag_run_state",
             ),
+            pytest.param(
+                GetTaskRescheduleStartDate(ti_id=TI_ID),
+                
b'{"start_date":"2024-10-31T12:00:00Z","type":"TaskRescheduleStartDate"}\n',
+                "task_instances.get_reschedule_start_date",
+                (TI_ID, 1),
+                {},
+                
TaskRescheduleStartDate(start_date=timezone.parse("2024-10-31T12:00:00Z")),
+                id="get_task_reschedule_start_date",
+            ),
         ],
     )
     def test_handle_requests(
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py 
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 8f59eddbd11..b5a69538a56 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -77,6 +77,7 @@ from airflow.sdk.execution_time.comms import (
     SkipDownstreamTasks,
     StartupDetails,
     SucceedTask,
+    TaskRescheduleStartDate,
     TaskState,
     TriggerDagRun,
     VariableResult,
@@ -1417,6 +1418,33 @@ class TestRuntimeTaskInstance:
             log=mock.ANY,
         )
 
+    @pytest.mark.parametrize(
+        ["task_reschedule_count", "expected_date"],
+        [
+            (
+                0,
+                None,
+            ),
+            (
+                1,
+                timezone.datetime(2025, 1, 1),
+            ),
+        ],
+    )
+    def test_get_first_reschedule_date(
+        self, create_runtime_ti, mock_supervisor_comms, task_reschedule_count, expected_date
+    ):
+        """Test that the first reschedule date is fetched from the Supervisor."""
+        task = BaseOperator(task_id="hello")
+        runtime_ti = create_runtime_ti(task=task, task_reschedule_count=task_reschedule_count)
+
+        mock_supervisor_comms.get_message.return_value = TaskRescheduleStartDate(
+            start_date=timezone.datetime(2025, 1, 1)
+        )
+
+        context = runtime_ti.get_template_context()
+        assert runtime_ti.get_first_reschedule_date(context=context) == expected_date
+
 
 class TestXComAfterTaskExecution:
     @pytest.mark.parametrize(
diff --git a/tests/sdk/__init__.py b/tests/sdk/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/tests/sdk/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/sdk/api/__init__.py b/tests/sdk/api/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/tests/sdk/api/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/sdk/execution_time/__init__.py b/tests/sdk/execution_time/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/tests/sdk/execution_time/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.


Reply via email to