This is an automated email from the ASF dual-hosted git repository. mobuchowski pushed a commit to branch tasksdk-call-listeners in repository https://gitbox.apache.org/repos/asf/airflow.git
commit d3af041f38ea12f4e2c9b2f0bd98db4d1e28b648 Author: Maciej Obuchowski <obuchowski.mac...@gmail.com> AuthorDate: Wed Jan 8 13:24:44 2025 +0100 add basic support for task instance listeners in TaskSDK Signed-off-by: Maciej Obuchowski <obuchowski.mac...@gmail.com> --- airflow/api_fastapi/execution_api/app.py | 3 +- .../api_fastapi/execution_api/datamodels/dagrun.py | 35 +++++ .../execution_api/datamodels/taskinstance.py | 5 + .../execution_api/routes/task_instances.py | 25 +++- airflow/executors/workloads.py | 14 ++ .../src/airflow/sdk/api/datamodels/_generated.py | 3 + .../src/airflow/sdk/execution_time/supervisor.py | 1 + .../src/airflow/sdk/execution_time/task_runner.py | 26 ++++ task_sdk/tests/conftest.py | 9 ++ task_sdk/tests/execution_time/test_supervisor.py | 38 ++--- task_sdk/tests/execution_time/test_task_runner.py | 159 ++++++++++++++++++++- .../execution_api/routes/test_task_instances.py | 2 + tests/callbacks/test_callback_requests.py | 1 + tests/executors/test_local_executor.py | 3 + tests/jobs/test_scheduler_job.py | 2 + tests/utils/test_cli_util.py | 2 +- 16 files changed, 300 insertions(+), 28 deletions(-) diff --git a/airflow/api_fastapi/execution_api/app.py b/airflow/api_fastapi/execution_api/app.py index 61283dc2cf8..702ee9ad6a4 100644 --- a/airflow/api_fastapi/execution_api/app.py +++ b/airflow/api_fastapi/execution_api/app.py @@ -77,8 +77,9 @@ def create_task_execution_api_app(app: FastAPI) -> FastAPI: def get_extra_schemas() -> dict[str, dict]: """Get all the extra schemas that are not part of the main FastAPI app.""" - from airflow.api_fastapi.execution_api.datamodels import taskinstance + from airflow.api_fastapi.execution_api.datamodels import dagrun, taskinstance return { "TaskInstance": taskinstance.TaskInstance.model_json_schema(), + "DagRun": dagrun.DagRun.model_json_schema(), } diff --git a/airflow/api_fastapi/execution_api/datamodels/dagrun.py b/airflow/api_fastapi/execution_api/datamodels/dagrun.py new file mode 100644 index 
00000000000..f9f99c1c43b --- /dev/null +++ b/airflow/api_fastapi/execution_api/datamodels/dagrun.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This model is not used in the API, but it is included in generated OpenAPI schema +# for use in the client SDKs. 
+from __future__ import annotations + +from airflow.api_fastapi.common.types import UtcDateTime +from airflow.api_fastapi.core_api.base import BaseModel + + +class DagRun(BaseModel): + """Schema for DagRun model with minimal required fields needed for OL for now.""" + + id: int + dag_id: str + run_id: str + logical_date: UtcDateTime + data_interval_start: UtcDateTime + data_interval_end: UtcDateTime + clear_number: int diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 563b32a2693..0d09f5ed551 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -165,6 +165,7 @@ class TaskInstance(BaseModel): try_number: int map_index: int = -1 hostname: str | None = None + start_date: UtcDateTime class DagRun(BaseModel): @@ -181,6 +182,7 @@ class DagRun(BaseModel): data_interval_end: UtcDateTime | None start_date: UtcDateTime end_date: UtcDateTime | None + clear_number: int run_type: DagRunType conf: Annotated[dict[str, Any], Field(default_factory=dict)] @@ -191,6 +193,9 @@ class TIRunContext(BaseModel): dag_run: DagRun """DAG run information for the task instance.""" + task_reschedule_count: Annotated[int, Field(default=0)] + """How many times the task has been rescheduled during its current try.""" + max_tries: int """Maximum number of tries for the task instance (from DB).""" diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index ba6ea0c14b6..e543e9ce46f 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -23,7 +23,7 @@ from uuid import UUID from fastapi import Body, HTTPException, status from pydantic import JsonValue -from sqlalchemy import update +from sqlalchemy import func, update from sqlalchemy.exc import 
NoResultFound, SQLAlchemyError
from sqlalchemy.sql import select @@ -75,14 +75,12 @@ def ti_run( ti_id_str = str(task_instance_id) old = ( - select(TI.state, TI.dag_id, TI.run_id, TI.task_id, TI.map_index, TI.next_method, TI.max_tries) + select(TI.state, TI.dag_id, TI.run_id, TI.task_id, TI.map_index, TI.next_method, TI.try_number, TI.max_tries) .where(TI.id == ti_id_str) .with_for_update() ) try: - (previous_state, dag_id, run_id, task_id, map_index, next_method, max_tries) = session.execute( - old - ).one() + (previous_state, dag_id, run_id, task_id, map_index, next_method, try_number, max_tries) = session.execute(old).one() except NoResultFound: log.error("Task Instance %s not found", ti_id_str) raise HTTPException( @@ -142,6 +140,7 @@ def ti_run( DR.data_interval_end, DR.start_date, DR.end_date, + DR.clear_number, DR.run_type, DR.conf, DR.logical_date, @@ -165,8 +164,24 @@ def ti_run( session=session, ) + task_reschedule_count = ( + session.query( + func.count(TaskReschedule.id) # or any other primary key column + ) + .filter( + TaskReschedule.dag_id == dag_id, + TaskReschedule.task_id == task_id, + TaskReschedule.run_id == run_id, + TaskReschedule.map_index == map_index, + TaskReschedule.try_number == try_number, + ) + .scalar() + or 0 + ) + return TIRunContext( dag_run=DagRun.model_validate(dr, from_attributes=True), + task_reschedule_count=task_reschedule_count, max_tries=max_tries, # TODO: Add variables and connections that are needed (and has perms) for the task variables=[], diff --git a/airflow/executors/workloads.py b/airflow/executors/workloads.py index 9a5e425ef88..bc3019547b2 100644 --- a/airflow/executors/workloads.py +++ b/airflow/executors/workloads.py @@ -18,6 +18,7 @@ from __future__ import annotations import os import uuid +from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING, Literal, Union @@ -56,6 +57,7 @@ class TaskInstance(BaseModel): run_id: str try_number: int map_index: int = -1 + start_date:
datetime pool_slots: int queue: str @@ -75,6 +77,15 @@ class TaskInstance(BaseModel): ) +class DagRun(BaseModel): + id: int + dag_id: str + run_id: str + logical_date: datetime + data_interval_start: datetime + data_interval_end: datetime + + class ExecuteTask(BaseActivity): """Execute the given Task.""" @@ -96,6 +107,9 @@ class ExecuteTask(BaseActivity): from airflow.utils.helpers import log_filename_template_renderer + if not ti.start_date: + ti.start_date = datetime.now(tz=timezone.utc) + ser_ti = TaskInstance.model_validate(ti, from_attributes=True) bundle_info = BundleInfo.model_construct( name=ti.dag_model.bundle_name, diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index a8b478d07f0..a3a47821d6c 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -183,6 +183,7 @@ class TaskInstance(BaseModel): dag_id: Annotated[str, Field(title="Dag Id")] run_id: Annotated[str, Field(title="Run Id")] try_number: Annotated[int, Field(title="Try Number")] + start_date: Annotated[datetime, Field(title="Start Date")] map_index: Annotated[int, Field(title="Map Index")] = -1 hostname: Annotated[str | None, Field(title="Hostname")] = None @@ -199,6 +200,7 @@ class DagRun(BaseModel): data_interval_end: Annotated[datetime | None, Field(title="Data Interval End")] = None start_date: Annotated[datetime, Field(title="Start Date")] end_date: Annotated[datetime | None, Field(title="End Date")] = None + clear_number: Annotated[int, Field(title="Clear Number")] = 0 run_type: DagRunType conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None @@ -213,6 +215,7 @@ class TIRunContext(BaseModel): """ dag_run: DagRun + task_reschedule_count: Annotated[int, Field(title="Task Reschedule Count")] = 0 max_tries: Annotated[int, Field(title="Max Tries")] variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None connections:
Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 32895d36524..3a10d101e95 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -896,6 +896,7 @@ def supervise( Run a single task execution to completion. :param ti: The task instance to run. + :param dr: Current DagRun of the task instance. :param dag_path: The file path to the DAG. :param token: Authentication token for the API client. :param server: Base URL of the API server. diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 186faac878a..b0f81196c02 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -31,6 +31,7 @@ import attrs import structlog from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter from airflow.dag_processing.bundles.manager import DagBundlesManager +from airflow.listeners.listener import get_listener_manager from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState, TIRunContext from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager @@ -53,6 +54,7 @@ from airflow.sdk.execution_time.context import ( VariableAccessor, set_current_context, ) from airflow.utils.net import get_hostname +from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: @@ -132,6 +134,7 @@ class RuntimeTaskInstance(TaskInstance): "ts": ts, "ts_nodash": ts_nodash, "ts_nodash_with_tz": ts_nodash_with_tz, + "task_reschedule_count": self._ti_context_from_server.task_reschedule_count, } context.update(context_from_server) @@ -458,6 +461,9 @@ def run(ti: RuntimeTaskInstance, log: Logger): try: # TODO: pre execute etc.
# TODO: Get a real context object + get_listener_manager().hook.on_task_instance_running( + previous_state=TaskInstanceState.QUEUED, task_instance=ti + ) ti.hostname = get_hostname() ti.task = ti.task.prepare_for_execution() context = ti.get_template_context() @@ -474,6 +480,9 @@ def run(ti: RuntimeTaskInstance, log: Logger): # - Pre Execute # etc msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc)) + get_listener_manager().hook.on_task_instance_success( + previous_state=TaskInstanceState.RUNNING, task_instance=ti + ) except TaskDeferred as defer: # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id? log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id) @@ -508,6 +517,11 @@ def run(ti: RuntimeTaskInstance, log: Logger): state=TerminalTIState.FAIL_WITHOUT_RETRY, end_date=datetime.now(tz=timezone.utc), ) + + get_listener_manager().hook.on_task_instance_failed( + previous_state=TaskInstanceState.RUNNING, task_instance=ti + ) + # TODO: Run task failure callbacks here except (AirflowTaskTimeout, AirflowException): # We should allow retries if the task has defined it. @@ -516,6 +530,9 @@ def run(ti: RuntimeTaskInstance, log: Logger): state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc), ) + get_listener_manager().hook.on_task_instance_failed( + previous_state=TaskInstanceState.RUNNING, task_instance=ti + ) # TODO: Run task failure callbacks here except AirflowTaskTerminated: # External state updates are already handled with `ti_heartbeat` and will be @@ -526,6 +543,9 @@ def run(ti: RuntimeTaskInstance, log: Logger): state=TerminalTIState.FAIL_WITHOUT_RETRY, end_date=datetime.now(tz=timezone.utc), ) + get_listener_manager().hook.on_task_instance_failed( + previous_state=TaskInstanceState.RUNNING, task_instance=ti + ) # TODO: Run task failure callbacks here except SystemExit: # SystemExit needs to be retried if they are eligible. 
@@ -534,10 +554,16 @@ def run(ti: RuntimeTaskInstance, log: Logger): state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc), ) + get_listener_manager().hook.on_task_instance_failed( + previous_state=TaskInstanceState.RUNNING, task_instance=ti + ) # TODO: Run task failure callbacks here except BaseException: log.exception("Task failed with exception") # TODO: Run task failure callbacks here + get_listener_manager().hook.on_task_instance_failed( + previous_state=TaskInstanceState.RUNNING, task_instance=ti + ) msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc)) if msg: SUPERVISOR_COMMS.send_request(msg=msg, log=log) diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py index 50429d91b01..49cee7cac4b 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/tests/conftest.py @@ -142,8 +142,10 @@ class MakeTIContextCallable(Protocol): logical_date: str | datetime = ..., data_interval_start: str | datetime = ..., data_interval_end: str | datetime = ..., + clear_number: int = ..., start_date: str | datetime = ..., run_type: str = ..., + task_reschedule_count: int = ..., ) -> TIRunContext: ... @@ -157,6 +159,7 @@ class MakeTIContextDictCallable(Protocol): data_interval_end: str | datetime = ..., start_date: str | datetime = ..., run_type: str = ..., + task_reschedule_count: int = ..., ) -> dict[str, Any]: ... 
@@ -171,8 +174,10 @@ def make_ti_context() -> MakeTIContextCallable: logical_date: str | datetime = "2024-12-01T01:00:00Z", data_interval_start: str | datetime = "2024-12-01T00:00:00Z", data_interval_end: str | datetime = "2024-12-01T01:00:00Z", + clear_number: int = 0, start_date: str | datetime = "2024-12-01T01:00:00Z", run_type: str = "manual", + task_reschedule_count: int = 0, ) -> TIRunContext: return TIRunContext( dag_run=DagRun( @@ -181,9 +186,11 @@ def make_ti_context() -> MakeTIContextCallable: logical_date=logical_date, # type: ignore data_interval_start=data_interval_start, # type: ignore data_interval_end=data_interval_end, # type: ignore + clear_number=clear_number, # type: ignore start_date=start_date, # type: ignore run_type=run_type, # type: ignore ), + task_reschedule_count=task_reschedule_count, max_tries=0, ) @@ -202,6 +209,7 @@ def make_ti_context_dict(make_ti_context: MakeTIContextCallable) -> MakeTIContex data_interval_end: str | datetime = "2024-12-01T01:00:00Z", start_date: str | datetime = "2024-12-01T00:00:00Z", run_type: str = "manual", + task_reschedule_count: int = 0, ) -> dict[str, Any]: context = make_ti_context( dag_id=dag_id, @@ -211,6 +219,7 @@ def make_ti_context_dict(make_ti_context: MakeTIContextCallable) -> MakeTIContex data_interval_end=data_interval_end, start_date=start_date, run_type=run_type, + task_reschedule_count=task_reschedule_count, ) return context.model_dump(exclude_unset=True, mode="json") diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 12c3455ccfe..40b4445e967 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -125,6 +125,7 @@ class TestWatchedSubprocess: dag_id="c", run_id="d", try_number=1, + start_date=instant, ), client=MagicMock(spec=sdk_client.Client), target=subprocess_main, @@ -193,6 +194,7 @@ class TestWatchedSubprocess: dag_id="c", run_id="d", try_number=1, + 
start_date=tz.utcnow(), ), client=MagicMock(spec=sdk_client.Client), target=subprocess_main, @@ -212,11 +214,7 @@ class TestWatchedSubprocess: dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, what=TaskInstance( - id=uuid7(), - task_id="b", - dag_id="c", - run_id="d", - try_number=1, + id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, start_date=tz.utcnow() ), client=MagicMock(spec=sdk_client.Client), target=subprocess_main, @@ -249,11 +247,7 @@ class TestWatchedSubprocess: dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, what=TaskInstance( - id=ti_id, - task_id="b", - dag_id="c", - run_id="d", - try_number=1, + id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, start_date=tz.utcnow() ), client=sdk_client.Client(base_url="", dry_run=True, token=""), target=subprocess_main, @@ -276,6 +270,7 @@ class TestWatchedSubprocess: dag_id="super_basic_run", run_id="c", try_number=1, + start_date=instant, ) bundle_info = BundleInfo.model_construct(name="my-bundle", version=None) with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): @@ -307,9 +302,15 @@ class TestWatchedSubprocess: This includes ensuring the task starts and executes successfully, and that the task is deferred (via the API client) with the expected parameters.
""" + instant = tz.datetime(2024, 11, 7, 12, 34, 56, 0) ti = TaskInstance( - id=uuid7(), task_id="async", dag_id="super_basic_deferred_run", run_id="d", try_number=1 + id=uuid7(), + task_id="async", + dag_id="super_basic_deferred_run", + run_id="d", + try_number=1, + start_date=instant, ) # Create a mock client to assert calls to the client @@ -317,7 +318,6 @@ class TestWatchedSubprocess: mock_client = mocker.Mock(spec=sdk_client.Client) mock_client.task_instances.start.return_value = make_ti_context() - instant = tz.datetime(2024, 11, 7, 12, 34, 56, 0) time_machine.move_to(instant, tick=False) bundle_info = BundleInfo.model_construct(name="my-bundle", version=None) @@ -357,7 +357,9 @@ class TestWatchedSubprocess: def test_supervisor_handles_already_running_task(self): """Test that Supervisor prevents starting a Task Instance that is already running.""" - ti = TaskInstance(id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1) + ti = TaskInstance( + id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, start_date=tz.utcnow() + ) # Mock API Server response indicating the TI is already running # The API Server would return a 409 Conflict status code if the TI is not @@ -433,7 +435,9 @@ class TestWatchedSubprocess: proc = ActivitySubprocess.start( dag_rel_path=os.devnull, - what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1), + what=TaskInstance( + id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, start_date=tz.utcnow() + ), client=make_client(transport=httpx.MockTransport(handle_request)), target=subprocess_main, bundle_info=FAKE_BUNDLE, @@ -704,9 +708,11 @@ class TestWatchedSubprocessKill: ti_id = uuid7() proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1), + what=TaskInstance( + id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, start_date=tz.utcnow() + ),
client=MagicMock(spec=sdk_client.Client), target=subprocess_main, ) diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index f716317ad24..680032c5889 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -36,6 +36,8 @@ from airflow.exceptions import ( AirflowSkipException, AirflowTaskTerminated, ) +from airflow.listeners import hookimpl +from airflow.listeners.listener import get_listener_manager from airflow.sdk import DAG, BaseOperator, Connection, get_current_context from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState from airflow.sdk.definitions.variable import Variable @@ -62,6 +64,7 @@ from airflow.sdk.execution_time.task_runner import ( startup, ) from airflow.utils import timezone +from airflow.utils.state import TaskInstanceState FAKE_BUNDLE = BundleInfo.model_construct(name="anything", version="any") @@ -74,6 +77,42 @@ def get_inline_dag(dag_id: str, task: BaseOperator) -> DAG: return dag +@pytest.fixture +def mocked_parse(spy_agency): + """ + Fixture to set up an inline DAG and use it in a stubbed `parse` function. Use this fixture if you + want to isolate and test `parse` or `run` logic without having to define a DAG file. + + This fixture returns a helper function `set_dag` that: + 1. Creates an in line DAG with the given `dag_id` and `task` (limited to one task) + 2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task. + 3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`. 
+ + After adding the fixture in your test function signature, you can use it like this :: + + mocked_parse( + StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1), + file="", + requests_fd=0, + ), + "example_dag_id", + CustomOperator(task_id="hello"), + ) + """ + + def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance: + dag = get_inline_dag(dag_id, task) + t = dag.task_dict[task.task_id] + ti = RuntimeTaskInstance.model_construct( + **what.ti.model_dump(exclude_unset=True), task=t, _ti_context_from_server=what.ti_context + ) + spy_agency.spy_on(parse, call_fake=lambda _: ti) + return ti + + return set_dag + + class CustomOperator(BaseOperator): def execute(self, context): task_id = context["task_instance"].task_id @@ -91,7 +130,8 @@ class TestCommsDecoder: w.makefile("wb").write( b'{"type":"StartupDetails", "ti": {' - b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", "try_number": 1, "run_id": "b", "dag_id": "c" }, ' + b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", "try_number": 1, "run_id": "b", ' + b'"dag_id": "c", "start_date": "2024-12-01T01:00:00Z" }, ' b'"ti_context":{"dag_run":{"dag_id":"c","run_id":"b","logical_date":"2024-12-01T01:00:00Z",' b'"data_interval_start":"2024-12-01T00:00:00Z","data_interval_end":"2024-12-01T01:00:00Z",' b'"start_date":"2024-12-01T01:00:00Z","end_date":null,"run_type":"manual","conf":null},' @@ -110,6 +150,7 @@ class TestCommsDecoder: assert msg.ti.dag_id == "c" assert msg.dag_rel_path == "/dev/null" assert msg.bundle_info == BundleInfo.model_construct(name="any-name", version="any-version") + assert msg.ti.start_date == timezone.datetime(2024, 12, 1, 1, 0) # Since this was a StartupDetails message, the decoder should open the other socket assert decoder.request_socket is not None @@ -120,7 +161,14 @@ class TestCommsDecoder: def test_parse(test_dags_dir: Path, make_ti_context): """Test that checks 
parsing of a basic dag with an un-mocked parse.""" what = StartupDetails( - ti=TaskInstance(id=uuid7(), task_id="a", dag_id="super_basic", run_id="c", try_number=1), + ti=TaskInstance( + id=uuid7(), + task_id="a", + dag_id="super_basic", + run_id="c", + try_number=1, + start_date=timezone.utcnow(), + ), dag_rel_path="super_basic.py", bundle_info=BundleInfo.model_construct(name="my-bundle", version=None), requests_fd=0, @@ -346,7 +394,12 @@ def test_startup_basic_templated_dag(mocked_parse, make_ti_context, mock_supervi what = StartupDetails( ti=TaskInstance( - id=uuid7(), task_id="templated_task", dag_id="basic_templated_dag", run_id="c", try_number=1 + id=uuid7(), + task_id="templated_task", + dag_id="basic_templated_dag", + run_id="c", + try_number=1, + start_date=timezone.datetime(2024, 12, 3, 10, 0), ), bundle_info=FAKE_BUNDLE, dag_rel_path="", @@ -414,9 +467,17 @@ def test_startup_and_run_dag_with_rtif( print(key, getattr(self, key)) task = CustomOperator(task_id="templated_task") + instant = timezone.datetime(2024, 12, 3, 10, 0) what = StartupDetails( - ti=TaskInstance(id=uuid7(), task_id="templated_task", dag_id="basic_dag", run_id="c", try_number=1), + ti=TaskInstance( + id=uuid7(), + task_id="templated_task", + dag_id="basic_dag", + run_id="c", + try_number=1, + start_date=instant, + ), dag_rel_path="", bundle_info=FAKE_BUNDLE, requests_fd=0, @@ -424,7 +485,6 @@ def test_startup_and_run_dag_with_rtif( ) ti = mocked_parse(what, "basic_dag", task) - instant = timezone.datetime(2024, 12, 3, 10, 0) time_machine.move_to(instant, tick=False) mock_supervisor_comms.get_message.return_value = what @@ -654,6 +714,7 @@ class TestRuntimeTaskInstance: "data_interval_end": timezone.datetime(2024, 12, 1, 1, 0, 0), "data_interval_start": timezone.datetime(2024, 12, 1, 0, 0, 0), "logical_date": timezone.datetime(2024, 12, 1, 1, 0, 0), + "task_reschedule_count": 0, "ds": "2024-12-01", "ds_nodash": "20241201", "expanded_ti_count": None, @@ -950,3 +1011,91 @@ class 
TestXComAfterTaskExecution: assert str(exc_info.value) == ( f"Returned dictionary keys must be strings when using multiple_outputs, found 2 ({int}) instead" ) + + +class TestTaskRunnerCallsListeners: + class CustomListener: + def __init__(self): + self.state = [] + + @hookimpl + def on_task_instance_running(self, previous_state, task_instance): + self.state.append(TaskInstanceState.RUNNING) + + @hookimpl + def on_task_instance_success(self, previous_state, task_instance): + self.state.append(TaskInstanceState.SUCCESS) + + @hookimpl + def on_task_instance_failed(self, previous_state, task_instance): + self.state.append(TaskInstanceState.FAILED) + + @pytest.fixture(autouse=True) + def clean_listener_manager(self): + lm = get_listener_manager() + lm.clear() + yield + lm = get_listener_manager() + lm.clear() + + def test_task_runner_calls_listeners_success(self, mocked_parse, mock_supervisor_comms): + listener = self.CustomListener() + get_listener_manager().add_listener(listener) + + class CustomOperator(BaseOperator): + def execute(self, context): + self.value = "something" + + task = CustomOperator( + task_id="test_task_runner_calls_listeners", do_xcom_push=True, multiple_outputs=True + ) + dag = get_inline_dag(dag_id="test_dag", task=task) + ti = TaskInstance( + id=uuid7(), + task_id=task.task_id, + dag_id=dag.dag_id, + run_id="test_run", + try_number=1, + start_date=timezone.utcnow(), + ) + + runtime_ti = RuntimeTaskInstance.model_construct(**ti.model_dump(exclude_unset=True), task=task) + + run(runtime_ti, log=mock.MagicMock()) + + assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS] + + @pytest.mark.parametrize( + "exception", + [ + ValueError("oops"), + SystemExit("oops"), + AirflowException("oops"), + ], + ) + def test_task_runner_calls_listeners_failed(self, mocked_parse, mock_supervisor_comms, exception): + listener = self.CustomListener() + get_listener_manager().add_listener(listener) + + class CustomOperator(BaseOperator): + def 
execute(self, context): + raise exception + + task = CustomOperator( + task_id="test_task_runner_calls_listeners_failed", do_xcom_push=True, multiple_outputs=True + ) + dag = get_inline_dag(dag_id="test_dag", task=task) + ti = TaskInstance( + id=uuid7(), + task_id=task.task_id, + dag_id=dag.dag_id, + run_id="test_run", + try_number=1, + start_date=timezone.utcnow(), + ) + + runtime_ti = RuntimeTaskInstance.model_construct(**ti.model_dump(exclude_unset=True), task=task) + + run(runtime_ti, log=mock.MagicMock()) + + assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED] diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index e6da0f3a192..bcb30c4b986 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -80,6 +80,7 @@ class TestTIRunState: "dag_run": { "dag_id": "dag", "run_id": "test", + "clear_number": 0, "logical_date": instant_str, "data_interval_start": instant.subtract(days=1).to_iso8601_string(), "data_interval_end": instant_str, @@ -88,6 +89,7 @@ class TestTIRunState: "run_type": "manual", "conf": {}, }, + "task_reschedule_count": 0, "max_tries": 0, "variables": [], "connections": [], diff --git a/tests/callbacks/test_callback_requests.py b/tests/callbacks/test_callback_requests.py index 68362d36239..b07b4e742b5 100644 --- a/tests/callbacks/test_callback_requests.py +++ b/tests/callbacks/test_callback_requests.py @@ -64,6 +64,7 @@ class TestCallbackRequest: run_id="fake_run", state=State.RUNNING, ) + ti.start_date = timezone.utcnow() input = TaskCallbackRequest( full_filepath="filepath", diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py index 673457509ca..ec72863b20e 100644 --- a/tests/executors/test_local_executor.py +++ b/tests/executors/test_local_executor.py @@ -27,6 +27,7 @@ from uuid6 import uuid7 from 
airflow.executors import workloads from airflow.executors.local_executor import LocalExecutor +from airflow.utils import timezone from airflow.utils.state import State pytestmark = pytest.mark.db_test @@ -63,6 +64,8 @@ class TestLocalExecutor: pool_slots=1, queue="default", priority_weight=1, + map_index=-1, + start_date=timezone.utcnow(), ) for i in range(self.TEST_SUCCESS_COMMANDS) ] diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 0ea27f97bb8..06000886e89 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -5667,6 +5667,7 @@ class TestSchedulerJob: ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING) ti.last_heartbeat_at = timezone.utcnow() - timedelta(minutes=6) + ti.start_date = timezone.utcnow() - timedelta(minutes=10) ti.queued_by_job_id = 999 session.add(ti) @@ -5782,6 +5783,7 @@ class TestSchedulerJob: ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING) ti.last_heartbeat_at = timezone.utcnow() - timedelta(minutes=6) + ti.start_date = timezone.utcnow() - timedelta(minutes=10) # TODO: If there was an actual Relationship between TI and Job # we wouldn't need this extra commit diff --git a/tests/utils/test_cli_util.py b/tests/utils/test_cli_util.py index 8632bf2726d..38d46a3275b 100644 --- a/tests/utils/test_cli_util.py +++ b/tests/utils/test_cli_util.py @@ -81,7 +81,7 @@ class TestCliUtil: assert os.path.join(settings.DAGS_FOLDER, "abc") == cli.process_subdir("DAGS_FOLDER/abc") def test_get_dags(self): - dags = cli.get_dags(None, "example_bash_operator") + dags = cli.get_dags(None, "test_example_bash_operator") assert len(dags) == 1 with pytest.raises(AirflowException):