This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch tasksdk-call-listeners
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit d3af041f38ea12f4e2c9b2f0bd98db4d1e28b648
Author: Maciej Obuchowski <obuchowski.mac...@gmail.com>
AuthorDate: Wed Jan 8 13:24:44 2025 +0100

    add basic support for task instance listeners in TaskSDK
    
    Signed-off-by: Maciej Obuchowski <obuchowski.mac...@gmail.com>
---
 airflow/api_fastapi/execution_api/app.py           |   3 +-
 .../api_fastapi/execution_api/datamodels/dagrun.py |  35 +++++
 .../execution_api/datamodels/taskinstance.py       |   5 +
 .../execution_api/routes/task_instances.py         |  25 +++-
 airflow/executors/workloads.py                     |  14 ++
 .../src/airflow/sdk/api/datamodels/_generated.py   |   3 +
 .../src/airflow/sdk/execution_time/supervisor.py   |   1 +
 .../src/airflow/sdk/execution_time/task_runner.py  |  26 ++++
 task_sdk/tests/conftest.py                         |   9 ++
 task_sdk/tests/execution_time/test_supervisor.py   |  38 ++---
 task_sdk/tests/execution_time/test_task_runner.py  | 159 ++++++++++++++++++++-
 .../execution_api/routes/test_task_instances.py    |   2 +
 tests/callbacks/test_callback_requests.py          |   1 +
 tests/executors/test_local_executor.py             |   3 +
 tests/jobs/test_scheduler_job.py                   |   2 +
 tests/utils/test_cli_util.py                       |   2 +-
 16 files changed, 300 insertions(+), 28 deletions(-)

diff --git a/airflow/api_fastapi/execution_api/app.py 
b/airflow/api_fastapi/execution_api/app.py
index 61283dc2cf8..702ee9ad6a4 100644
--- a/airflow/api_fastapi/execution_api/app.py
+++ b/airflow/api_fastapi/execution_api/app.py
@@ -77,8 +77,9 @@ def create_task_execution_api_app(app: FastAPI) -> FastAPI:
 
 def get_extra_schemas() -> dict[str, dict]:
     """Get all the extra schemas that are not part of the main FastAPI app."""
-    from airflow.api_fastapi.execution_api.datamodels import taskinstance
+    from airflow.api_fastapi.execution_api.datamodels import dagrun, 
taskinstance
 
     return {
         "TaskInstance": taskinstance.TaskInstance.model_json_schema(),
+        "DagRun": dagrun.DagRun.model_json_schema(),
     }
diff --git a/airflow/api_fastapi/execution_api/datamodels/dagrun.py 
b/airflow/api_fastapi/execution_api/datamodels/dagrun.py
new file mode 100644
index 00000000000..f9f99c1c43b
--- /dev/null
+++ b/airflow/api_fastapi/execution_api/datamodels/dagrun.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This model is not used in the API, but it is included in generated OpenAPI schema
+# for use in the client SDKs.
+from __future__ import annotations
+
+from airflow.api_fastapi.common.types import UtcDateTime
+from airflow.api_fastapi.core_api.base import BaseModel
+
+
+class DagRun(BaseModel):
+    """Schema for DagRun model with minimal required fields needed for OpenLineage for now."""
+
+    id: int
+    dag_id: str
+    run_id: str
+    logical_date: UtcDateTime
+    data_interval_start: UtcDateTime | None
+    data_interval_end: UtcDateTime | None
+    clear_number: int
diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py 
b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index 563b32a2693..0d09f5ed551 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -165,6 +165,7 @@ class TaskInstance(BaseModel):
     try_number: int
     map_index: int = -1
     hostname: str | None = None
+    start_date: UtcDateTime
 
 
 class DagRun(BaseModel):
@@ -181,6 +182,7 @@ class DagRun(BaseModel):
     data_interval_end: UtcDateTime | None
     start_date: UtcDateTime
     end_date: UtcDateTime | None
+    clear_number: int
     run_type: DagRunType
     conf: Annotated[dict[str, Any], Field(default_factory=dict)]
 
@@ -191,6 +193,9 @@ class TIRunContext(BaseModel):
     dag_run: DagRun
     """DAG run information for the task instance."""
 
+    task_reschedule_count: Annotated[int, Field(default=0)]
+    """How many times the task has been rescheduled."""
+
     max_tries: int
     """Maximum number of tries for the task instance (from DB)."""
 
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow/api_fastapi/execution_api/routes/task_instances.py
index ba6ea0c14b6..e543e9ce46f 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -23,7 +23,7 @@ from uuid import UUID
 
 from fastapi import Body, HTTPException, status
 from pydantic import JsonValue
-from sqlalchemy import update
+from sqlalchemy import func, update
 from sqlalchemy.exc import NoResultFound, SQLAlchemyError
 from sqlalchemy.sql import select
 
@@ -75,14 +75,12 @@ def ti_run(
     ti_id_str = str(task_instance_id)
 
     old = (
-        select(TI.state, TI.dag_id, TI.run_id, TI.task_id, TI.map_index, 
TI.next_method, TI.max_tries)
+        select(TI.state, TI.dag_id, TI.run_id, TI.task_id, TI.map_index, 
TI.next_method, TI.try_number, TI.max_tries)
         .where(TI.id == ti_id_str)
         .with_for_update()
     )
     try:
-        (previous_state, dag_id, run_id, task_id, map_index, next_method, 
max_tries) = session.execute(
-            old
-        ).one()
+        (previous_state, dag_id, run_id, task_id, map_index, next_method, 
try_number, max_tries) = session.execute(old).one()
     except NoResultFound:
         log.error("Task Instance %s not found", ti_id_str)
         raise HTTPException(
@@ -142,6 +140,7 @@ def ti_run(
                 DR.data_interval_end,
                 DR.start_date,
                 DR.end_date,
+                DR.clear_number,
                 DR.run_type,
                 DR.conf,
                 DR.logical_date,
@@ -165,8 +164,24 @@ def ti_run(
                 session=session,
             )
 
+        # Count the reschedules recorded for this task instance during its current try.
+        task_reschedule_count = (
+            session.scalar(
+                select(func.count(TaskReschedule.id))
+                .where(
+                    TaskReschedule.dag_id == dag_id,
+                    TaskReschedule.task_id == task_id,
+                    TaskReschedule.run_id == run_id,
+                    TaskReschedule.map_index == map_index,
+                    TaskReschedule.try_number == try_number,
+                )
+            )
+            or 0
+        )
+
         return TIRunContext(
             dag_run=DagRun.model_validate(dr, from_attributes=True),
+            task_reschedule_count=task_reschedule_count,
             max_tries=max_tries,
             # TODO: Add variables and connections that are needed (and has 
perms) for the task
             variables=[],
diff --git a/airflow/executors/workloads.py b/airflow/executors/workloads.py
index 9a5e425ef88..bc3019547b2 100644
--- a/airflow/executors/workloads.py
+++ b/airflow/executors/workloads.py
@@ -18,6 +18,7 @@ from __future__ import annotations
 
 import os
 import uuid
+from datetime import datetime
 from pathlib import Path
 from typing import TYPE_CHECKING, Literal, Union
 
@@ -56,6 +57,7 @@ class TaskInstance(BaseModel):
     run_id: str
     try_number: int
     map_index: int = -1
+    start_date: datetime
 
     pool_slots: int
     queue: str
@@ -75,6 +77,15 @@ class TaskInstance(BaseModel):
         )
 
 
+class DagRun(BaseModel):
+    id: int
+    dag_id: str
+    run_id: str
+    logical_date: datetime
+    data_interval_start: datetime
+    data_interval_end: datetime
+
+
 class ExecuteTask(BaseActivity):
     """Execute the given Task."""
 
@@ -96,6 +107,9 @@ class ExecuteTask(BaseActivity):
 
         from airflow.utils.helpers import log_filename_template_renderer
 
+        if not ti.start_date:
+            ti.start_date = datetime.now()
+
         ser_ti = TaskInstance.model_validate(ti, from_attributes=True)
         bundle_info = BundleInfo.model_construct(
             name=ti.dag_model.bundle_name,
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py 
b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index a8b478d07f0..a3a47821d6c 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -183,6 +183,7 @@ class TaskInstance(BaseModel):
     dag_id: Annotated[str, Field(title="Dag Id")]
     run_id: Annotated[str, Field(title="Run Id")]
     try_number: Annotated[int, Field(title="Try Number")]
+    start_date: Annotated[datetime, Field(title="Start Date")]
     map_index: Annotated[int, Field(title="Map Index")] = -1
     hostname: Annotated[str | None, Field(title="Hostname")] = None
 
@@ -199,6 +200,7 @@ class DagRun(BaseModel):
     data_interval_end: Annotated[datetime | None, Field(title="Data Interval 
End")] = None
     start_date: Annotated[datetime, Field(title="Start Date")]
     end_date: Annotated[datetime | None, Field(title="End Date")] = None
+    clear_number: Annotated[int, Field(title="Clear Number")] = 0
     run_type: DagRunType
     conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None
 
@@ -213,6 +215,7 @@ class TIRunContext(BaseModel):
     """
 
     dag_run: DagRun
+    task_reschedule_count: Annotated[int, Field(title="Task Reschedule 
Count")] = 0
     max_tries: Annotated[int, Field(title="Max Tries")]
     variables: Annotated[list[VariableResponse] | None, 
Field(title="Variables")] = None
     connections: Annotated[list[ConnectionResponse] | None, 
Field(title="Connections")] = None
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 32895d36524..3a10d101e95 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -896,6 +896,7 @@ def supervise(
     Run a single task execution to completion.
 
     :param ti: The task instance to run.
+    :param dr: Current DagRun of the task instance.
     :param dag_path: The file path to the DAG.
     :param token: Authentication token for the API client.
     :param server: Base URL of the API server.
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 186faac878a..b0f81196c02 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -31,6 +31,7 @@ import attrs
 import structlog
 from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
 
+from airflow.listeners.listener import get_listener_manager
 from airflow.dag_processing.bundles.manager import DagBundlesManager
 from airflow.sdk.api.datamodels._generated import TaskInstance, 
TerminalTIState, TIRunContext
 from airflow.sdk.definitions._internal.dag_parsing_context import 
_airflow_parsing_context_manager
@@ -53,6 +54,7 @@ from airflow.sdk.execution_time.context import (
     VariableAccessor,
     set_current_context,
 )
+from airflow.utils.state import TaskInstanceState
 from airflow.utils.net import get_hostname
 
 if TYPE_CHECKING:
@@ -132,6 +134,7 @@ class RuntimeTaskInstance(TaskInstance):
                 "ts": ts,
                 "ts_nodash": ts_nodash,
                 "ts_nodash_with_tz": ts_nodash_with_tz,
+                "task_reschedule_count": 
self._ti_context_from_server.task_reschedule_count,
             }
             context.update(context_from_server)
 
@@ -458,6 +461,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
     try:
         # TODO: pre execute etc.
         # TODO: Get a real context object
         ti.hostname = get_hostname()
         ti.task = ti.task.prepare_for_execution()
+        get_listener_manager().hook.on_task_instance_running(
+            previous_state=TaskInstanceState.QUEUED, task_instance=ti
+        )
         context = ti.get_template_context()
@@ -474,6 +480,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
         #   - Pre Execute
         #   etc
         msg = TaskState(state=TerminalTIState.SUCCESS, 
end_date=datetime.now(tz=timezone.utc))
+        get_listener_manager().hook.on_task_instance_success(
+            previous_state=TaskInstanceState.RUNNING, task_instance=ti
+        )
     except TaskDeferred as defer:
         # TODO: Should we use structlog.bind_contextvars here for dag_id, 
task_id & run_id?
         log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, 
task_id=ti.task_id, run_id=ti.run_id)
@@ -508,6 +517,11 @@ def run(ti: RuntimeTaskInstance, log: Logger):
             state=TerminalTIState.FAIL_WITHOUT_RETRY,
             end_date=datetime.now(tz=timezone.utc),
         )
+
+        get_listener_manager().hook.on_task_instance_failed(
+            previous_state=TaskInstanceState.RUNNING, task_instance=ti
+        )
+
         # TODO: Run task failure callbacks here
     except (AirflowTaskTimeout, AirflowException):
         # We should allow retries if the task has defined it.
@@ -516,6 +530,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
             state=TerminalTIState.FAILED,
             end_date=datetime.now(tz=timezone.utc),
         )
+        get_listener_manager().hook.on_task_instance_failed(
+            previous_state=TaskInstanceState.RUNNING, task_instance=ti
+        )
         # TODO: Run task failure callbacks here
     except AirflowTaskTerminated:
         # External state updates are already handled with `ti_heartbeat` and 
will be
@@ -526,6 +543,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
             state=TerminalTIState.FAIL_WITHOUT_RETRY,
             end_date=datetime.now(tz=timezone.utc),
         )
+        get_listener_manager().hook.on_task_instance_failed(
+            previous_state=TaskInstanceState.RUNNING, task_instance=ti
+        )
         # TODO: Run task failure callbacks here
     except SystemExit:
         # SystemExit needs to be retried if they are eligible.
@@ -534,10 +554,16 @@ def run(ti: RuntimeTaskInstance, log: Logger):
             state=TerminalTIState.FAILED,
             end_date=datetime.now(tz=timezone.utc),
         )
+        get_listener_manager().hook.on_task_instance_failed(
+            previous_state=TaskInstanceState.RUNNING, task_instance=ti
+        )
         # TODO: Run task failure callbacks here
     except BaseException:
         log.exception("Task failed with exception")
         # TODO: Run task failure callbacks here
         msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc))
+        get_listener_manager().hook.on_task_instance_failed(
+            previous_state=TaskInstanceState.RUNNING, task_instance=ti
+        )
     if msg:
         SUPERVISOR_COMMS.send_request(msg=msg, log=log)
diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py
index 50429d91b01..49cee7cac4b 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/conftest.py
@@ -142,8 +142,10 @@ class MakeTIContextCallable(Protocol):
         logical_date: str | datetime = ...,
         data_interval_start: str | datetime = ...,
         data_interval_end: str | datetime = ...,
+        clear_number: int = ...,
         start_date: str | datetime = ...,
         run_type: str = ...,
+        task_reschedule_count: int = ...,
     ) -> TIRunContext: ...
 
 
@@ -157,6 +159,7 @@ class MakeTIContextDictCallable(Protocol):
         data_interval_end: str | datetime = ...,
         start_date: str | datetime = ...,
         run_type: str = ...,
+        task_reschedule_count: int = ...,
     ) -> dict[str, Any]: ...
 
 
@@ -171,8 +174,10 @@ def make_ti_context() -> MakeTIContextCallable:
         logical_date: str | datetime = "2024-12-01T01:00:00Z",
         data_interval_start: str | datetime = "2024-12-01T00:00:00Z",
         data_interval_end: str | datetime = "2024-12-01T01:00:00Z",
+        clear_number: int = 0,
         start_date: str | datetime = "2024-12-01T01:00:00Z",
         run_type: str = "manual",
+        task_reschedule_count: int = 0,
     ) -> TIRunContext:
         return TIRunContext(
             dag_run=DagRun(
@@ -181,9 +186,11 @@ def make_ti_context() -> MakeTIContextCallable:
                 logical_date=logical_date,  # type: ignore
                 data_interval_start=data_interval_start,  # type: ignore
                 data_interval_end=data_interval_end,  # type: ignore
+                clear_number=clear_number,  # type: ignore
                 start_date=start_date,  # type: ignore
                 run_type=run_type,  # type: ignore
             ),
+            task_reschedule_count=task_reschedule_count,
             max_tries=0,
         )
 
@@ -202,6 +209,7 @@ def make_ti_context_dict(make_ti_context: 
MakeTIContextCallable) -> MakeTIContex
         data_interval_end: str | datetime = "2024-12-01T01:00:00Z",
         start_date: str | datetime = "2024-12-01T00:00:00Z",
         run_type: str = "manual",
+        task_reschedule_count: int = 0,
     ) -> dict[str, Any]:
         context = make_ti_context(
             dag_id=dag_id,
@@ -211,6 +219,7 @@ def make_ti_context_dict(make_ti_context: 
MakeTIContextCallable) -> MakeTIContex
             data_interval_end=data_interval_end,
             start_date=start_date,
             run_type=run_type,
+            task_reschedule_count=task_reschedule_count,
         )
         return context.model_dump(exclude_unset=True, mode="json")
 
diff --git a/task_sdk/tests/execution_time/test_supervisor.py 
b/task_sdk/tests/execution_time/test_supervisor.py
index 12c3455ccfe..40b4445e967 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -125,6 +125,7 @@ class TestWatchedSubprocess:
                 dag_id="c",
                 run_id="d",
                 try_number=1,
+                start_date=instant,
             ),
             client=MagicMock(spec=sdk_client.Client),
             target=subprocess_main,
@@ -193,6 +194,7 @@ class TestWatchedSubprocess:
                 dag_id="c",
                 run_id="d",
                 try_number=1,
+                start_date=tz.utcnow(),
             ),
             client=MagicMock(spec=sdk_client.Client),
             target=subprocess_main,
@@ -212,11 +214,7 @@ class TestWatchedSubprocess:
             dag_rel_path=os.devnull,
             bundle_info=FAKE_BUNDLE,
             what=TaskInstance(
-                id=uuid7(),
-                task_id="b",
-                dag_id="c",
-                run_id="d",
-                try_number=1,
+                id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, 
start_date=tz.utcnow()
             ),
             client=MagicMock(spec=sdk_client.Client),
             target=subprocess_main,
@@ -249,11 +247,7 @@ class TestWatchedSubprocess:
             dag_rel_path=os.devnull,
             bundle_info=FAKE_BUNDLE,
             what=TaskInstance(
-                id=ti_id,
-                task_id="b",
-                dag_id="c",
-                run_id="d",
-                try_number=1,
+                id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, 
start_date=timezone.utcnow()
             ),
             client=sdk_client.Client(base_url="", dry_run=True, token=""),
             target=subprocess_main,
@@ -276,6 +270,7 @@ class TestWatchedSubprocess:
             dag_id="super_basic_run",
             run_id="c",
             try_number=1,
+            start_date=instant,
         )
         bundle_info = BundleInfo.model_construct(name="my-bundle", 
version=None)
         with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, 
bundle_info.name)):
@@ -307,9 +302,15 @@ class TestWatchedSubprocess:
         This includes ensuring the task starts and executes successfully, and 
that the task is deferred (via
         the API client) with the expected parameters.
         """
+        instant = tz.datetime(2024, 11, 7, 12, 34, 56, 0)
 
         ti = TaskInstance(
-            id=uuid7(), task_id="async", dag_id="super_basic_deferred_run", 
run_id="d", try_number=1
+            id=uuid7(),
+            task_id="async",
+            dag_id="super_basic_deferred_run",
+            run_id="d",
+            try_number=1,
+            start_date=instant,
         )
 
         # Create a mock client to assert calls to the client
@@ -317,7 +318,6 @@ class TestWatchedSubprocess:
         mock_client = mocker.Mock(spec=sdk_client.Client)
         mock_client.task_instances.start.return_value = make_ti_context()
 
-        instant = tz.datetime(2024, 11, 7, 12, 34, 56, 0)
         time_machine.move_to(instant, tick=False)
 
         bundle_info = BundleInfo.model_construct(name="my-bundle", 
version=None)
@@ -357,7 +357,9 @@ class TestWatchedSubprocess:
 
     def test_supervisor_handles_already_running_task(self):
         """Test that Supervisor prevents starting a Task Instance that is 
already running."""
-        ti = TaskInstance(id=uuid7(), task_id="b", dag_id="c", run_id="d", 
try_number=1)
+        ti = TaskInstance(
+            id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, 
start_date=tz.utcnow()
+        )
 
         # Mock API Server response indicating the TI is already running
         # The API Server would return a 409 Conflict status code if the TI is 
not
@@ -433,7 +435,9 @@ class TestWatchedSubprocess:
 
         proc = ActivitySubprocess.start(
             dag_rel_path=os.devnull,
-            what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", 
try_number=1),
+            what=TaskInstance(
+                id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, 
start_date=tz.utcnow()
+            ),
             client=make_client(transport=httpx.MockTransport(handle_request)),
             target=subprocess_main,
             bundle_info=FAKE_BUNDLE,
@@ -704,9 +708,11 @@ class TestWatchedSubprocessKill:
         ti_id = uuid7()
 
         proc = ActivitySubprocess.start(
-            dag_rel_path=os.devnull,
+            path=os.devnull,
             bundle_info=FAKE_BUNDLE,
-            what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", 
try_number=1),
+            what=TaskInstance(
+                id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, 
start_date=tz.utcnow()
+            ),
             client=MagicMock(spec=sdk_client.Client),
             target=subprocess_main,
         )
diff --git a/task_sdk/tests/execution_time/test_task_runner.py 
b/task_sdk/tests/execution_time/test_task_runner.py
index f716317ad24..680032c5889 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -36,6 +36,8 @@ from airflow.exceptions import (
     AirflowSkipException,
     AirflowTaskTerminated,
 )
+from airflow.listeners import hookimpl
+from airflow.listeners.listener import get_listener_manager
 from airflow.sdk import DAG, BaseOperator, Connection, get_current_context
 from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
 from airflow.sdk.definitions.variable import Variable
@@ -62,6 +64,7 @@ from airflow.sdk.execution_time.task_runner import (
     startup,
 )
 from airflow.utils import timezone
+from airflow.utils.state import TaskInstanceState
 
 FAKE_BUNDLE = BundleInfo.model_construct(name="anything", version="any")
 
@@ -74,6 +77,42 @@ def get_inline_dag(dag_id: str, task: BaseOperator) -> DAG:
     return dag
 
 
+@pytest.fixture
+def mocked_parse(spy_agency):
+    """
+    Fixture to set up an inline DAG and use it in a stubbed `parse` function. 
Use this fixture if you
+    want to isolate and test `parse` or `run` logic without having to define a 
DAG file.
+
+    This fixture returns a helper function `set_dag` that:
+    1. Creates an in line DAG with the given `dag_id` and `task` (limited to 
one task)
+    2. Constructs a `RuntimeTaskInstance` based on the provided 
`StartupDetails` and task.
+    3. Stubs the `parse` function using `spy_agency`, to return the mocked 
`RuntimeTaskInstance`.
+
+    After adding the fixture in your test function signature, you can use it 
like this ::
+
+            mocked_parse(
+                StartupDetails(
+                    ti=TaskInstance(id=uuid7(), task_id="hello", 
dag_id="super_basic_run", run_id="c", try_number=1),
+                    file="",
+                    requests_fd=0,
+                ),
+                "example_dag_id",
+                CustomOperator(task_id="hello"),
+            )
+    """
+
+    def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> 
RuntimeTaskInstance:
+        dag = get_inline_dag(dag_id, task)
+        t = dag.task_dict[task.task_id]
+        ti = RuntimeTaskInstance.model_construct(
+            **what.ti.model_dump(exclude_unset=True), task=t, 
_ti_context_from_server=what.ti_context
+        )
+        spy_agency.spy_on(parse, call_fake=lambda _: ti)
+        return ti
+
+    return set_dag
+
+
 class CustomOperator(BaseOperator):
     def execute(self, context):
         task_id = context["task_instance"].task_id
@@ -91,7 +130,8 @@ class TestCommsDecoder:
 
         w.makefile("wb").write(
             b'{"type":"StartupDetails", "ti": {'
-            b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", 
"try_number": 1, "run_id": "b", "dag_id": "c" }, '
+            b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", 
"try_number": 1, "run_id": "b", '
+            b'"dag_id": "c", "start_date": "2024-12-01T01:00:00Z" }, '
             
b'"ti_context":{"dag_run":{"dag_id":"c","run_id":"b","logical_date":"2024-12-01T01:00:00Z",'
             
b'"data_interval_start":"2024-12-01T00:00:00Z","data_interval_end":"2024-12-01T01:00:00Z",'
             
b'"start_date":"2024-12-01T01:00:00Z","end_date":null,"run_type":"manual","conf":null},'
@@ -110,6 +150,7 @@ class TestCommsDecoder:
         assert msg.ti.dag_id == "c"
         assert msg.dag_rel_path == "/dev/null"
         assert msg.bundle_info == BundleInfo.model_construct(name="any-name", 
version="any-version")
+        assert msg.ti.start_date == timezone.datetime(2024, 12, 1, 1, 0)
 
         # Since this was a StartupDetails message, the decoder should open the 
other socket
         assert decoder.request_socket is not None
@@ -120,7 +161,14 @@ class TestCommsDecoder:
 def test_parse(test_dags_dir: Path, make_ti_context):
     """Test that checks parsing of a basic dag with an un-mocked parse."""
     what = StartupDetails(
-        ti=TaskInstance(id=uuid7(), task_id="a", dag_id="super_basic", 
run_id="c", try_number=1),
+        ti=TaskInstance(
+            id=uuid7(),
+            task_id="a",
+            dag_id="super_basic",
+            run_id="c",
+            try_number=1,
+            start_date=timezone.utcnow(),
+        ),
         dag_rel_path="super_basic.py",
         bundle_info=BundleInfo.model_construct(name="my-bundle", version=None),
         requests_fd=0,
@@ -346,7 +394,12 @@ def test_startup_basic_templated_dag(mocked_parse, 
make_ti_context, mock_supervi
 
     what = StartupDetails(
         ti=TaskInstance(
-            id=uuid7(), task_id="templated_task", 
dag_id="basic_templated_dag", run_id="c", try_number=1
+            id=uuid7(),
+            task_id="templated_task",
+            dag_id="basic_templated_dag",
+            run_id="c",
+            try_number=1,
+            start_date=timezone.datetime(2024, 12, 3, 10, 0),
         ),
         bundle_info=FAKE_BUNDLE,
         dag_rel_path="",
@@ -414,9 +467,17 @@ def test_startup_and_run_dag_with_rtif(
                 print(key, getattr(self, key))
 
     task = CustomOperator(task_id="templated_task")
+    instant = timezone.datetime(2024, 12, 3, 10, 0)
 
     what = StartupDetails(
-        ti=TaskInstance(id=uuid7(), task_id="templated_task", 
dag_id="basic_dag", run_id="c", try_number=1),
+        ti=TaskInstance(
+            id=uuid7(),
+            task_id="templated_task",
+            dag_id="basic_dag",
+            run_id="c",
+            try_number=1,
+            start_date=instant,
+        ),
         dag_rel_path="",
         bundle_info=FAKE_BUNDLE,
         requests_fd=0,
@@ -424,7 +485,6 @@ def test_startup_and_run_dag_with_rtif(
     )
     ti = mocked_parse(what, "basic_dag", task)
 
-    instant = timezone.datetime(2024, 12, 3, 10, 0)
     time_machine.move_to(instant, tick=False)
 
     mock_supervisor_comms.get_message.return_value = what
@@ -654,6 +714,7 @@ class TestRuntimeTaskInstance:
             "data_interval_end": timezone.datetime(2024, 12, 1, 1, 0, 0),
             "data_interval_start": timezone.datetime(2024, 12, 1, 0, 0, 0),
             "logical_date": timezone.datetime(2024, 12, 1, 1, 0, 0),
+            "task_reschedule_count": 0,
             "ds": "2024-12-01",
             "ds_nodash": "20241201",
             "expanded_ti_count": None,
@@ -950,3 +1011,91 @@ class TestXComAfterTaskExecution:
         assert str(exc_info.value) == (
             f"Returned dictionary keys must be strings when using 
multiple_outputs, found 2 ({int}) instead"
         )
+
+
+class TestTaskRunnerCallsListeners:
+    class CustomListener:
+        def __init__(self):
+            self.state = []
+
+        @hookimpl
+        def on_task_instance_running(self, previous_state, task_instance):
+            self.state.append(TaskInstanceState.RUNNING)
+
+        @hookimpl
+        def on_task_instance_success(self, previous_state, task_instance):
+            self.state.append(TaskInstanceState.SUCCESS)
+
+        @hookimpl
+        def on_task_instance_failed(self, previous_state, task_instance):
+            self.state.append(TaskInstanceState.FAILED)
+
+    @pytest.fixture(autouse=True)
+    def clean_listener_manager(self):
+        lm = get_listener_manager()
+        lm.clear()
+        yield
+        lm = get_listener_manager()
+        lm.clear()
+
+    def test_task_runner_calls_listeners_success(self, mocked_parse, 
mock_supervisor_comms):
+        listener = self.CustomListener()
+        get_listener_manager().add_listener(listener)
+
+        class CustomOperator(BaseOperator):
+            def execute(self, context):
+                self.value = "something"
+
+        task = CustomOperator(
+            task_id="test_task_runner_calls_listeners", do_xcom_push=True, 
multiple_outputs=True
+        )
+        dag = get_inline_dag(dag_id="test_dag", task=task)
+        ti = TaskInstance(
+            id=uuid7(),
+            task_id=task.task_id,
+            dag_id=dag.dag_id,
+            run_id="test_run",
+            try_number=1,
+            start_date=timezone.utcnow(),
+        )
+
+        runtime_ti = 
RuntimeTaskInstance.model_construct(**ti.model_dump(exclude_unset=True), 
task=task)
+
+        run(runtime_ti, log=mock.MagicMock())
+
+        assert listener.state == [TaskInstanceState.RUNNING, 
TaskInstanceState.SUCCESS]
+
+    @pytest.mark.parametrize(
+        "exception",
+        [
+            ValueError("oops"),
+            SystemExit("oops"),
+            AirflowException("oops"),
+        ],
+    )
+    def test_task_runner_calls_listeners_failed(self, mocked_parse, 
mock_supervisor_comms, exception):
+        listener = self.CustomListener()
+        get_listener_manager().add_listener(listener)
+
+        class CustomOperator(BaseOperator):
+            def execute(self, context):
+                raise exception
+
+        task = CustomOperator(
+            task_id="test_task_runner_calls_listeners_failed", 
do_xcom_push=True, multiple_outputs=True
+        )
+        dag = get_inline_dag(dag_id="test_dag", task=task)
+        ti = TaskInstance(
+            id=uuid7(),
+            task_id=task.task_id,
+            dag_id=dag.dag_id,
+            run_id="test_run",
+            try_number=1,
+            start_date=timezone.utcnow(),
+        )
+
+        runtime_ti = 
RuntimeTaskInstance.model_construct(**ti.model_dump(exclude_unset=True), 
task=task)
+
+        run(runtime_ti, log=mock.MagicMock())
+
+        assert listener.state == [TaskInstanceState.RUNNING, 
TaskInstanceState.FAILED]
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py 
b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index e6da0f3a192..bcb30c4b986 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -80,6 +80,7 @@ class TestTIRunState:
             "dag_run": {
                 "dag_id": "dag",
                 "run_id": "test",
+                "clear_number": 0,
                 "logical_date": instant_str,
                 "data_interval_start": 
instant.subtract(days=1).to_iso8601_string(),
                 "data_interval_end": instant_str,
@@ -88,6 +89,7 @@ class TestTIRunState:
                 "run_type": "manual",
                 "conf": {},
             },
+            "task_reschedule_count": 0,
             "max_tries": 0,
             "variables": [],
             "connections": [],
diff --git a/tests/callbacks/test_callback_requests.py 
b/tests/callbacks/test_callback_requests.py
index 68362d36239..b07b4e742b5 100644
--- a/tests/callbacks/test_callback_requests.py
+++ b/tests/callbacks/test_callback_requests.py
@@ -64,6 +64,7 @@ class TestCallbackRequest:
                 run_id="fake_run",
                 state=State.RUNNING,
             )
+            ti.start_date = timezone.utcnow()
 
             input = TaskCallbackRequest(
                 full_filepath="filepath",
diff --git a/tests/executors/test_local_executor.py 
b/tests/executors/test_local_executor.py
index 673457509ca..ec72863b20e 100644
--- a/tests/executors/test_local_executor.py
+++ b/tests/executors/test_local_executor.py
@@ -27,6 +27,7 @@ from uuid6 import uuid7
 
 from airflow.executors import workloads
 from airflow.executors.local_executor import LocalExecutor
+from airflow.utils import timezone
 from airflow.utils.state import State
 
 pytestmark = pytest.mark.db_test
@@ -63,6 +64,8 @@ class TestLocalExecutor:
                 pool_slots=1,
                 queue="default",
                 priority_weight=1,
+                map_index=-1,
+                start_date=timezone.utcnow(),
             )
             for i in range(self.TEST_SUCCESS_COMMANDS)
         ]
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 0ea27f97bb8..06000886e89 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -5667,6 +5667,7 @@ class TestSchedulerJob:
                 ti = TaskInstance(task, run_id=dag_run.run_id, 
state=State.RUNNING)
 
                 ti.last_heartbeat_at = timezone.utcnow() - timedelta(minutes=6)
+                ti.start_date = timezone.utcnow() - timedelta(minutes=10)
                 ti.queued_by_job_id = 999
 
                 session.add(ti)
@@ -5782,6 +5783,7 @@ class TestSchedulerJob:
 
             ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING)
             ti.last_heartbeat_at = timezone.utcnow() - timedelta(minutes=6)
+            ti.start_date = timezone.utcnow() - timedelta(minutes=10)
 
             # TODO: If there was an actual Relationship between TI and Job
             # we wouldn't need this extra commit
diff --git a/tests/utils/test_cli_util.py b/tests/utils/test_cli_util.py
index 8632bf2726d..38d46a3275b 100644
--- a/tests/utils/test_cli_util.py
+++ b/tests/utils/test_cli_util.py
@@ -81,7 +81,7 @@ class TestCliUtil:
         assert os.path.join(settings.DAGS_FOLDER, "abc") == 
cli.process_subdir("DAGS_FOLDER/abc")
 
     def test_get_dags(self):
-        dags = cli.get_dags(None, "example_bash_operator")
+        dags = cli.get_dags(None, "test_example_bash_operator")
         assert len(dags) == 1
 
         with pytest.raises(AirflowException):


Reply via email to