This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4789a1f0568 Support Reschedule mode for `BaseSensorOperator` with Task
SDK (#48193)
4789a1f0568 is described below
commit 4789a1f05681a534250012a20a13ddbd3367c70c
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Mar 25 00:34:59 2025 +0530
Support Reschedule mode for `BaseSensorOperator` with Task SDK (#48193)
closes https://github.com/apache/airflow/issues/45580
closes https://github.com/apache/airflow/issues/48088
Example of a working DAG:
<img width="1707" alt="image"
src="https://github.com/user-attachments/assets/bb904395-b088-464e-b518-4335f6466cd4"
/>
DAG Used:
```py
import pendulum
from airflow.decorators import dag, task
from airflow.sdk import Variable
@dag()
def example_sensor_decorator():
@task.sensor(poke_interval=5, timeout=60, mode="reschedule")
def wait_for_upstream():
if Variable.get("sensor_val", "") == "xcom_value":
return True
return False
@task
def dummy_operator() -> None:
pass
wait_for_upstream() >> dummy_operator()
tutorial_etl_dag = example_sensor_decorator()
```
**Change**:
`BaseSensorOperator` needed some way to know the `start_date` of the first
reschedule for a specific ti_id. Currently, it gets that from the DB. For that
it needs to know the first `try_number`, which is derived from `retries` and
`max_tries`:
https://github.com/apache/airflow/blob/d264a8616ff79b0ca6d2fcb2c0140695314db910/airflow-core/src/airflow/sensors/base.py#L220-L230
I needed some way to know "retries" in the API server to get the
`try_number` for Task Reschedule.
The options I had considered were:
1) Passing `retries` in `TIEnterRunningPayload`
2) Get `retries` from `serialized_dag` table.
3) Fetching it lazily (deferred), since it is only required for sensors (this
is the approach taken in this PR).
---
.../api_fastapi/execution_api/routes/__init__.py | 4 +
.../execution_api/routes/task_instances.py | 17 ++++
.../execution_api/routes/task_reschedules.py | 56 ++++++++++++++
airflow-core/src/airflow/models/taskinstance.py | 24 ++++++
airflow-core/src/airflow/sensors/base.py | 46 +----------
.../execution_api/routes/test_task_instances.py | 90 ++++++++++++++++++++++
airflow-core/tests/unit/sensors/test_base.py | 22 ------
task-sdk/src/airflow/sdk/api/client.py | 6 ++
task-sdk/src/airflow/sdk/execution_time/comms.py | 13 ++++
.../src/airflow/sdk/execution_time/supervisor.py | 4 +
.../src/airflow/sdk/execution_time/task_runner.py | 29 +++++++
task-sdk/src/airflow/sdk/types.py | 2 +
task-sdk/tests/conftest.py | 2 +
task-sdk/tests/task_sdk/api/test_client.py | 30 +++++++-
.../task_sdk/execution_time/test_supervisor.py | 11 +++
.../task_sdk/execution_time/test_task_runner.py | 28 +++++++
tests/sdk/__init__.py | 16 ++++
tests/sdk/api/__init__.py | 16 ++++
tests/sdk/execution_time/__init__.py | 16 ++++
19 files changed, 365 insertions(+), 67 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
index 6610d5552f2..6b6c29b2bab 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
@@ -27,6 +27,7 @@ from airflow.api_fastapi.execution_api.routes import (
dag_runs,
health,
task_instances,
+ task_reschedules,
variables,
xcoms,
)
@@ -42,6 +43,9 @@ authenticated_router.include_router(asset_events.router,
prefix="/asset-events",
authenticated_router.include_router(connections.router, prefix="/connections",
tags=["Connections"])
authenticated_router.include_router(dag_runs.router, prefix="/dag-runs",
tags=["Dag Runs"])
authenticated_router.include_router(task_instances.router,
prefix="/task-instances", tags=["Task Instances"])
+authenticated_router.include_router(
+ task_reschedules.router, prefix="/task-reschedules", tags=["Task
Reschedules"]
+)
authenticated_router.include_router(variables.router, prefix="/variables",
tags=["Variables"])
authenticated_router.include_router(xcoms.router, prefix="/xcoms",
tags=["XComs"])
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index a4f20b18cdf..afc514f08f1 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -357,6 +357,23 @@ def ti_update_state(
)
updated_state = TaskInstanceState.DEFERRED
elif isinstance(ti_patch_payload, TIRescheduleStatePayload):
+ # Quick check for poke_interval isn't immediately over MySQL's
TIMESTAMP limit.
+ # This check is only rudimentary to catch trivial user errors, e.g.
mistakenly
+ # set the value to milliseconds instead of seconds. There's another
check when
+ # we actually try to reschedule to ensure database coherence.
+ if session.get_bind().dialect.name == "mysql":
+ # As documented in
https://dev.mysql.com/doc/refman/5.7/en/datetime.html.
+ _MYSQL_TIMESTAMP_MAX = timezone.datetime(2038, 1, 19, 3, 14, 7)
+ if ti_patch_payload.reschedule_date > _MYSQL_TIMESTAMP_MAX:
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail={
+ "reason": "invalid_reschedule_date",
+ "message": f"Cannot reschedule to
{ti_patch_payload.reschedule_date.isoformat()} "
+ f"since it is over MySQL's TIMESTAMP storage limit.",
+ },
+ )
+
task_instance = session.get(TI, ti_id_str)
actual_start_date = timezone.utcnow()
session.add(
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py
new file mode 100644
index 00000000000..f31ba38ccbd
--- /dev/null
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Annotated
+from uuid import UUID
+
+from fastapi import HTTPException, Query, status
+from sqlalchemy import select
+
+from airflow.api_fastapi.common.db.common import SessionDep
+from airflow.api_fastapi.common.router import AirflowRouter
+from airflow.api_fastapi.common.types import UtcDateTime
+from airflow.models.taskreschedule import TaskReschedule
+
+router = AirflowRouter(
+ responses={
+ status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
+ status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
+ },
+)
+
+
[email protected]("/{task_instance_id}/start_date")
+def get_start_date(
+ task_instance_id: UUID, session: SessionDep, try_number: Annotated[int,
Query()] = 1
+) -> UtcDateTime:
+ start_date = session.scalar(
+ select(TaskReschedule)
+ .where(
+ TaskReschedule.ti_id == str(task_instance_id),
+ TaskReschedule.try_number >= try_number,
+ )
+ .order_by(TaskReschedule.id.asc())
+ .with_only_columns(TaskReschedule.start_date)
+ .limit(1)
+ )
+ if start_date is None:
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
+
+ return start_date
diff --git a/airflow-core/src/airflow/models/taskinstance.py
b/airflow-core/src/airflow/models/taskinstance.py
index 107454c97d7..1d6a540073e 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -3688,6 +3688,30 @@ class TaskInstance(Base, LoggingMixin):
}
return asset_unique_keys - active_asset_unique_keys
+ def get_first_reschedule_date(self, context: Context) -> datetime | None:
+ """Get the first reschedule date for the task instance."""
+ # TODO: AIP-72: Remove this after `ti.run` is migrated to use Task SDK
+ max_tries: int = self.max_tries or 0
+
+ if TYPE_CHECKING:
+ assert isinstance(self.task, BaseOperator)
+
+ retries: int = self.task.retries or 0
+ first_try_number = max_tries - retries + 1
+
+ with create_session() as session:
+ start_date = session.scalar(
+ select(TaskReschedule)
+ .where(
+ TaskReschedule.ti_id == str(self.id),
+ TaskReschedule.try_number >= first_try_number,
+ )
+ .order_by(TaskReschedule.id.asc())
+ .with_only_columns(TaskReschedule.start_date)
+ .limit(1)
+ )
+ return start_date
+
def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) ->
MappedTaskGroup | None:
"""Given two operators, find their innermost common mapped task group."""
diff --git a/airflow-core/src/airflow/sensors/base.py
b/airflow-core/src/airflow/sensors/base.py
index ad69eb2b914..167eafa45f4 100644
--- a/airflow-core/src/airflow/sensors/base.py
+++ b/airflow-core/src/airflow/sensors/base.py
@@ -26,8 +26,6 @@ from collections.abc import Iterable
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable
-from sqlalchemy import select
-
from airflow import settings
from airflow.configuration import conf
from airflow.exceptions import (
@@ -42,10 +40,8 @@ from airflow.exceptions import (
)
from airflow.executors.executor_loader import ExecutorLoader
from airflow.models.baseoperator import BaseOperator
-from airflow.models.taskreschedule import TaskReschedule
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
from airflow.utils import timezone
-from airflow.utils.session import create_session
if TYPE_CHECKING:
from airflow.sdk.definitions.context import Context
@@ -199,17 +195,6 @@ class BaseSensorOperator(BaseOperator):
f".{self.task_id}'; received '{self.mode}'."
)
- # Quick check for poke_interval isn't immediately over MySQL's
TIMESTAMP limit.
- # This check is only rudimentary to catch trivial user errors, e.g.
mistakenly
- # set the value to milliseconds instead of seconds. There's another
check when
- # we actually try to reschedule to ensure database coherence.
- if self.reschedule and _is_metadatabase_mysql():
- if timezone.utcnow() +
datetime.timedelta(seconds=self.poke_interval) > _MYSQL_TIMESTAMP_MAX:
- raise AirflowException(
- f"Cannot set poke_interval to {self.poke_interval} seconds
in reschedule "
- f"mode since it will take reschedule time over MySQL's
TIMESTAMP limit."
- )
-
def poke(self, context: Context) -> bool | PokeReturnValue:
"""Override when deriving this class."""
raise AirflowException("Override me.")
@@ -219,29 +204,8 @@ class BaseSensorOperator(BaseOperator):
if self.reschedule:
ti = context["ti"]
- max_tries: int = ti.max_tries or 0
- retries: int = self.retries or 0
-
- # If reschedule, use the start date of the first try (first try
can be either the very
- # first execution of the task, or the first execution after the
task was cleared).
- # If the first try's record was not saved due to the Exception
occurred and the following
- # transaction rollback, the next available attempt should be taken
- # to prevent falling in the endless rescheduling
- first_try_number = max_tries - retries + 1
- with create_session() as session:
- start_date = session.scalar(
- select(TaskReschedule)
- .where(
- TaskReschedule.ti_id == str(ti.id),
- TaskReschedule.try_number >= first_try_number,
- )
- .order_by(TaskReschedule.id.asc())
- .with_only_columns(TaskReschedule.start_date)
- .limit(1)
- )
- if not start_date:
- start_date = timezone.utcnow()
- started_at = start_date
+ first_reschedule_date = ti.get_first_reschedule_date(context)
+ started_at = start_date = first_reschedule_date or
timezone.utcnow()
def run_duration() -> float:
# If we are in reschedule mode, then we have to compute diff
@@ -255,7 +219,6 @@ class BaseSensorOperator(BaseOperator):
return time.monotonic() - start_monotonic
poke_count = 1
- log_dag_id = self.dag.dag_id if self.has_dag() else ""
xcom_value = None
while True:
@@ -301,11 +264,6 @@ class BaseSensorOperator(BaseOperator):
if self.reschedule:
next_poke_interval = self._get_next_poke_interval(started_at,
run_duration, poke_count)
reschedule_date = timezone.utcnow() +
timedelta(seconds=next_poke_interval)
- if _is_metadatabase_mysql() and reschedule_date >
_MYSQL_TIMESTAMP_MAX:
- raise AirflowSensorTimeout(
- f"Cannot reschedule DAG {log_dag_id} to
{reschedule_date.isoformat()} "
- f"since it is over MySQL's TIMESTAMP storage limit."
- )
raise AirflowRescheduleException(reschedule_date)
else:
time.sleep(self._get_next_poke_interval(started_at,
run_duration, poke_count))
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_task_instances.py
index 578a019e252..2ab0f943e6d 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_task_instances.py
@@ -712,6 +712,37 @@ class TestTIUpdateState:
assert trs[0].task_instance.map_index == -1
assert trs[0].duration == 129600
+ @pytest.mark.backend("mysql")
+ def test_ti_update_state_reschedule_mysql_limit(
+ self, client, session, create_task_instance, time_machine
+ ):
+ """Test that the reschedule date is validated against MySQL's
TIMESTAMP limit."""
+ instant = timezone.datetime(2024, 10, 30)
+ time_machine.move_to(instant, tick=False)
+
+ ti = create_task_instance(
+ task_id="test_ti_update_state_reschedule_mysql_limit",
+ state=State.RUNNING,
+ session=session,
+ )
+ ti.start_date = instant
+ session.commit()
+
+ # Date beyond MySQL's TIMESTAMP limit (2038-01-19 03:14:07)
+ future_date = timezone.datetime(2038, 1, 19, 3, 14, 8)
+
+ response = client.patch(
+ f"/execution/task-instances/{ti.id}/state",
+ json={
+ "state": TaskInstanceState.UP_FOR_RESCHEDULE,
+ "reschedule_date": future_date.isoformat(),
+ "end_date": DEFAULT_END_DATE.isoformat(),
+ },
+ )
+
+ assert response.status_code == 422
+ assert response.json()["detail"]["reason"] == "invalid_reschedule_date"
+
def test_ti_update_state_handle_retry(self, client, session,
create_task_instance):
ti = create_task_instance(
task_id="test_ti_update_state_to_retry",
@@ -1129,3 +1160,62 @@ class TestPreviousDagRun:
"start_date": None,
"end_date": None,
}
+
+
+class TestGetRescheduleStartDate:
+ def test_get_start_date(self, client, session, create_task_instance):
+ ti = create_task_instance(
+ task_id="test_ti_update_state_reschedule_mysql_limit",
+ state=State.RUNNING,
+ start_date=timezone.datetime(2024, 1, 1),
+ session=session,
+ )
+ tr = TaskReschedule(
+ task_instance_id=ti.id,
+ try_number=1,
+ start_date=timezone.datetime(2024, 1, 1),
+ end_date=timezone.datetime(2024, 1, 1, 1),
+ reschedule_date=timezone.datetime(2024, 1, 1, 2),
+ )
+ session.add(tr)
+ session.commit()
+
+ response =
client.get(f"/execution/task-reschedules/{ti.id}/start_date")
+ assert response.status_code == 200
+ assert response.json() == "2024-01-01T00:00:00Z"
+
+ def test_get_start_date_not_found(self, client):
+ ti_id = "0182e924-0f1e-77e6-ab50-e977118bc139"
+ response =
client.get(f"/execution/task-reschedules/{ti_id}/start_date")
+ assert response.status_code == 404
+
+ def test_get_start_date_with_try_number(self, client, session,
create_task_instance):
+ # Create multiple reschedules
+ dates = [
+ timezone.datetime(2024, 1, 1),
+ timezone.datetime(2024, 1, 2),
+ timezone.datetime(2024, 1, 3),
+ ]
+
+ ti = create_task_instance(
+ task_id="test_get_start_date_with_try_number",
+ state=State.RUNNING,
+ start_date=timezone.datetime(2024, 1, 1),
+ session=session,
+ )
+
+ for i, date in enumerate(dates, 1):
+ tr = TaskReschedule(
+ task_instance_id=ti.id,
+ try_number=i,
+ start_date=date,
+ end_date=date.replace(hour=1),
+ reschedule_date=date.replace(hour=2),
+ )
+ session.add(tr)
+ session.commit()
+
+ # Test getting start date for try_number 2
+ response =
client.get(f"/execution/task-reschedules/{ti.id}/start_date?try_number=2")
+ assert response.status_code == 200
+ assert response.json() == "2024-01-02T00:00:00Z"
diff --git a/airflow-core/tests/unit/sensors/test_base.py
b/airflow-core/tests/unit/sensors/test_base.py
index f1bedee703e..4b61b648d3d 100644
--- a/airflow-core/tests/unit/sensors/test_base.py
+++ b/airflow-core/tests/unit/sensors/test_base.py
@@ -739,28 +739,6 @@ class TestBaseSensor:
for idx, expected in enumerate([2, 6, 13, 30, 30, 30, 30, 30]):
assert sensor._get_next_poke_interval(started_at,
run_duration, idx) == expected
- @pytest.mark.backend("mysql")
- def test_reschedule_poke_interval_too_long_on_mysql(self, make_sensor):
- with pytest.raises(AirflowException) as ctx:
- make_sensor(poke_interval=863998946, mode="reschedule",
return_value="irrelevant")
- assert str(ctx.value) == (
- "Cannot set poke_interval to 863998946.0 seconds in reschedule
mode "
- "since it will take reschedule time over MySQL's TIMESTAMP limit."
- )
-
- @pytest.mark.backend("mysql")
- def test_reschedule_date_too_late_on_mysql(self, make_sensor):
- sensor, _ = make_sensor(poke_interval=60 * 60 * 24, mode="reschedule",
return_value=False)
-
- # A few hours until TIMESTAMP's limit, the next poke will take us over.
- with time_machine.travel(datetime(2038, 1, 19, tzinfo=timezone.utc),
tick=False):
- with pytest.raises(AirflowSensorTimeout) as ctx:
- self._run(sensor)
- assert str(ctx.value) == (
- "Cannot reschedule DAG unit_test_dag to 2038-01-20T00:00:00+00:00 "
- "since it is over MySQL's TIMESTAMP storage limit."
- )
-
def test_reschedule_and_retry_timeout(self, make_sensor, time_machine,
session):
"""
Test mode="reschedule", retries and timeout configurations interact
correctly.
diff --git a/task-sdk/src/airflow/sdk/api/client.py
b/task-sdk/src/airflow/sdk/api/client.py
index 1822748ad72..6dc5123530b 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -64,6 +64,7 @@ from airflow.sdk.execution_time.comms import (
OKResponse,
RuntimeCheckOnTask,
SkipDownstreamTasks,
+ TaskRescheduleStartDate,
)
from airflow.utils.net import get_hostname
from airflow.utils.platform import getuser
@@ -209,6 +210,11 @@ class TaskInstanceOperations:
return OKResponse(ok=True)
raise
+ def get_reschedule_start_date(self, id: uuid.UUID, try_number: int = 1) ->
TaskRescheduleStartDate:
+ """Get the start date of a task reschedule via the API server."""
+ resp = self.client.get(f"task-reschedules/{id}/start_date",
params={"try_number": try_number})
+ return TaskRescheduleStartDate.model_construct(start_date=resp.json())
+
class ConnectionOperations:
__slots__ = ("client",)
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index de20055b989..20867b22ca5 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -216,6 +216,11 @@ class
PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse):
return cls(**prev_dag_run.model_dump(exclude_defaults=True),
type="PrevSuccessfulDagRunResult")
+class TaskRescheduleStartDate(BaseModel):
+ start_date: datetime
+ type: Literal["TaskRescheduleStartDate"] = "TaskRescheduleStartDate"
+
+
class ErrorResponse(BaseModel):
error: ErrorType = ErrorType.GENERIC_ERROR
detail: dict | None = None
@@ -236,6 +241,7 @@ ToTask = Annotated[
ErrorResponse,
PrevSuccessfulDagRunResult,
StartupDetails,
+ TaskRescheduleStartDate,
VariableResult,
XComResult,
XComCountResponse,
@@ -435,6 +441,12 @@ class GetPrevSuccessfulDagRun(BaseModel):
type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"
+class GetTaskRescheduleStartDate(BaseModel):
+ ti_id: UUID
+ try_number: int = 1
+ type: Literal["GetTaskRescheduleStartDate"] = "GetTaskRescheduleStartDate"
+
+
ToSupervisor = Annotated[
Union[
SucceedTask,
@@ -446,6 +458,7 @@ ToSupervisor = Annotated[
GetConnection,
GetDagRunState,
GetPrevSuccessfulDagRun,
+ GetTaskRescheduleStartDate,
GetVariable,
GetXCom,
GetXComCount,
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 2bc88563e1a..33ac30a1f6f 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -75,6 +75,7 @@ from airflow.sdk.execution_time.comms import (
GetConnection,
GetDagRunState,
GetPrevSuccessfulDagRun,
+ GetTaskRescheduleStartDate,
GetVariable,
GetXCom,
GetXComCount,
@@ -962,6 +963,9 @@ class ActivitySubprocess(WatchedSubprocess):
elif isinstance(msg, GetDagRunState):
dr_resp = self.client.dag_runs.get_state(msg.dag_id, msg.run_id)
resp =
DagRunStateResult.from_api_response(dr_resp).model_dump_json().encode()
+ elif isinstance(msg, GetTaskRescheduleStartDate):
+ tr_resp =
self.client.task_instances.get_reschedule_start_date(msg.ti_id, msg.try_number)
+ resp = tr_resp.model_dump_json().encode()
else:
log.error("Unhandled request", msg=msg)
return
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 31d44f1116f..1feaa494034 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -58,6 +58,7 @@ from airflow.sdk.execution_time.comms import (
DeferTask,
ErrorResponse,
GetDagRunState,
+ GetTaskRescheduleStartDate,
OKResponse,
RescheduleTask,
RetryTask,
@@ -66,6 +67,7 @@ from airflow.sdk.execution_time.comms import (
SkipDownstreamTasks,
StartupDetails,
SucceedTask,
+ TaskRescheduleStartDate,
TaskState,
ToSupervisor,
ToTask,
@@ -361,6 +363,33 @@ class RuntimeTaskInstance(TaskInstance):
# TODO: Implement this method
return None
+ def get_first_reschedule_date(self, context: Context) -> datetime | None:
+ """Get the first reschedule date for the task instance."""
+ if context.get("task_reschedule_count", 0) == 0:
+ # If the task has not been rescheduled, there is no need to ask
the supervisor
+ return None
+
+ max_tries: int = self.max_tries
+ retries: int = self.task.retries or 0
+ first_try_number = max_tries - retries + 1
+
+ log = structlog.get_logger(logger_name="task")
+
+ log.debug("Requesting first reschedule date from supervisor")
+
+ SUPERVISOR_COMMS.send_request(
+ log=log, msg=GetTaskRescheduleStartDate(ti_id=self.id,
try_number=first_try_number)
+ )
+ response = SUPERVISOR_COMMS.get_message()
+
+ if TYPE_CHECKING:
+ assert isinstance(response, TaskRescheduleStartDate)
+
+ start_date = response.start_date
+ log.debug("First reschedule date from supervisor: %s", start_date)
+
+ return start_date
+
def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length:
int | None = None) -> None:
# Private function, as we don't want to expose the ability to manually set
`mapped_length` to SDK
diff --git a/task-sdk/src/airflow/sdk/types.py
b/task-sdk/src/airflow/sdk/types.py
index 5c1c4b88c88..6c3c40f2ab9 100644
--- a/task-sdk/src/airflow/sdk/types.py
+++ b/task-sdk/src/airflow/sdk/types.py
@@ -83,6 +83,8 @@ class RuntimeTaskInstanceProtocol(Protocol):
def get_template_context(self) -> Context: ...
+ def get_first_reschedule_date(self, first_try_number) -> datetime | None:
...
+
class OutletEventAccessorProtocol(Protocol, attrs.AttrsInstance):
"""Protocol for managing access to a specific outlet event accessor."""
diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py
index ee176c09c41..d1b6f29a038 100644
--- a/task-sdk/tests/conftest.py
+++ b/task-sdk/tests/conftest.py
@@ -364,6 +364,7 @@ def create_runtime_ti(mocked_parse, make_ti_context):
try_number: int = 1,
map_index: int | None = -1,
upstream_map_indexes: dict[str, int] | None = None,
+ task_reschedule_count: int = 0,
ti_id=None,
conf=None,
) -> RuntimeTaskInstance:
@@ -382,6 +383,7 @@ def create_runtime_ti(mocked_parse, make_ti_context):
start_date=start_date,
run_type=run_type,
conf=conf,
+ task_reschedule_count=task_reschedule_count,
)
if upstream_map_indexes is not None:
diff --git a/task-sdk/tests/task_sdk/api/test_client.py
b/task-sdk/tests/task_sdk/api/test_client.py
index b7f48d0d3a5..5549178a794 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -35,7 +35,13 @@ from airflow.sdk.api.datamodels._generated import (
XComResponse,
)
from airflow.sdk.exceptions import ErrorType
-from airflow.sdk.execution_time.comms import DeferTask, ErrorResponse,
OKResponse, RescheduleTask
+from airflow.sdk.execution_time.comms import (
+ DeferTask,
+ ErrorResponse,
+ OKResponse,
+ RescheduleTask,
+ TaskRescheduleStartDate,
+)
from airflow.utils import timezone
from airflow.utils.state import TerminalTIState
@@ -813,3 +819,25 @@ class TestDagRunOperations:
result = client.dag_runs.get_state(dag_id="test_state",
run_id="test_run_id")
assert result == DagRunStateResponse(state=DagRunState.RUNNING)
+
+
+class TestTaskRescheduleOperations:
+ def test_get_start_date(self):
+ """Test that the client can get the start date of a task reschedule"""
+ ti_id = uuid6.uuid7()
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == f"/task-reschedules/{ti_id}/start_date":
+ assert request.url.params.get("try_number") == "1"
+
+ return httpx.Response(
+ status_code=200,
+ json="2024-01-01T00:00:00Z",
+ )
+ return httpx.Response(status_code=422)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_instances.get_reschedule_start_date(id=ti_id,
try_number=1)
+
+ assert isinstance(result, TaskRescheduleStartDate)
+ assert result.start_date == "2024-01-01T00:00:00Z"
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 5d06fe7992b..9751932b8e8 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -69,6 +69,7 @@ from airflow.sdk.execution_time.comms import (
GetConnection,
GetDagRunState,
GetPrevSuccessfulDagRun,
+ GetTaskRescheduleStartDate,
GetVariable,
GetXCom,
OKResponse,
@@ -79,6 +80,7 @@ from airflow.sdk.execution_time.comms import (
SetRenderedFields,
SetXCom,
SucceedTask,
+ TaskRescheduleStartDate,
TaskState,
TriggerDagRun,
VariableResult,
@@ -1344,6 +1346,15 @@ class TestHandleRequest:
DagRunStateResult(state=DagRunState.RUNNING),
id="get_dag_run_state",
),
+ pytest.param(
+ GetTaskRescheduleStartDate(ti_id=TI_ID),
+
b'{"start_date":"2024-10-31T12:00:00Z","type":"TaskRescheduleStartDate"}\n',
+ "task_instances.get_reschedule_start_date",
+ (TI_ID, 1),
+ {},
+
TaskRescheduleStartDate(start_date=timezone.parse("2024-10-31T12:00:00Z")),
+ id="get_task_reschedule_start_date",
+ ),
],
)
def test_handle_requests(
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 8f59eddbd11..b5a69538a56 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -77,6 +77,7 @@ from airflow.sdk.execution_time.comms import (
SkipDownstreamTasks,
StartupDetails,
SucceedTask,
+ TaskRescheduleStartDate,
TaskState,
TriggerDagRun,
VariableResult,
@@ -1417,6 +1418,33 @@ class TestRuntimeTaskInstance:
log=mock.ANY,
)
+ @pytest.mark.parametrize(
+ ["task_reschedule_count", "expected_date"],
+ [
+ (
+ 0,
+ None,
+ ),
+ (
+ 1,
+ timezone.datetime(2025, 1, 1),
+ ),
+ ],
+ )
+ def test_get_first_reschedule_date(
+ self, create_runtime_ti, mock_supervisor_comms, task_reschedule_count,
expected_date
+ ):
+ """Test that the first reschedule date is fetched from the
Supervisor."""
+ task = BaseOperator(task_id="hello")
+ runtime_ti = create_runtime_ti(task=task,
task_reschedule_count=task_reschedule_count)
+
+ mock_supervisor_comms.get_message.return_value =
TaskRescheduleStartDate(
+ start_date=timezone.datetime(2025, 1, 1)
+ )
+
+ context = runtime_ti.get_template_context()
+ assert runtime_ti.get_first_reschedule_date(context=context) ==
expected_date
+
class TestXComAfterTaskExecution:
@pytest.mark.parametrize(
diff --git a/tests/sdk/__init__.py b/tests/sdk/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/tests/sdk/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/sdk/api/__init__.py b/tests/sdk/api/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/tests/sdk/api/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/sdk/execution_time/__init__.py
b/tests/sdk/execution_time/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/tests/sdk/execution_time/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.