This is an automated email from the ASF dual-hosted git repository. mobuchowski pushed a commit to branch tasksdk-call-listeners in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 0d95771ae5b8cacf15de7b3c0614fb1c8119bc6f Author: Maciej Obuchowski <obuchowski.mac...@gmail.com> AuthorDate: Wed Jan 8 13:20:23 2025 +0100 change listener API Signed-off-by: Maciej Obuchowski <obuchowski.mac...@gmail.com> --- airflow/example_dags/plugins/event_listener.py | 44 +++++++++----------------- airflow/listeners/listener.py | 8 ++++- airflow/listeners/spec/taskinstance.py | 15 +++------ airflow/models/taskinstance.py | 8 ++--- tests/listeners/class_listener.py | 8 ++--- tests/listeners/empty_listener.py | 2 +- tests/listeners/file_write_listener.py | 8 ++--- tests/listeners/full_listener.py | 6 ++-- tests/listeners/partial_listener.py | 2 +- tests/listeners/slow_listener.py | 2 +- tests/listeners/throwing_listener.py | 2 +- tests/listeners/very_slow_listener.py | 2 +- tests/listeners/xcom_listener.py | 4 +-- tests/plugins/test_plugins_manager.py | 29 +++++++++-------- 14 files changed, 62 insertions(+), 78 deletions(-) diff --git a/airflow/example_dags/plugins/event_listener.py b/airflow/example_dags/plugins/event_listener.py index 6d9fe2ff117..b0001b0bc7e 100644 --- a/airflow/example_dags/plugins/event_listener.py +++ b/airflow/example_dags/plugins/event_listener.py @@ -23,13 +23,13 @@ from airflow.listeners import hookimpl if TYPE_CHECKING: from airflow.models.dagrun import DagRun - from airflow.models.taskinstance import TaskInstance + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance from airflow.utils.state import TaskInstanceState # [START howto_listen_ti_running_task] @hookimpl -def on_task_instance_running(previous_state: TaskInstanceState, task_instance: TaskInstance, session): +def on_task_instance_running(previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance): """ This method is called when task state changes to RUNNING. Through callback, parameters like previous_task_state, task_instance object can be accessed. @@ -39,14 +39,11 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T print("Task instance is in running state") print(" Previous state of the Task instance:", previous_state) - state: TaskInstanceState = task_instance.state name: str = task_instance.task_id - start_date = task_instance.start_date - dagrun = task_instance.dag_run - dagrun_status = dagrun.state + context = task_instance.get_template_context() - task = task_instance.task + task = context["task"] if TYPE_CHECKING: assert task @@ -55,8 +52,8 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T dag_name = None if dag: dag_name = dag.dag_id - print(f"Current task name:{name} state:{state} start_date:{start_date}") - print(f"Dag name:{dag_name} and current dag run status:{dagrun_status}") + print(f"Current task name:{name}") + print(f"Dag name:{dag_name}") # [END howto_listen_ti_running_task] @@ -64,7 +61,7 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T # [START howto_listen_ti_success_task] @hookimpl -def on_task_instance_success(previous_state: TaskInstanceState, task_instance: TaskInstance, session): +def on_task_instance_success(previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance): """ This method is called when task state changes to SUCCESS. Through callback, parameters like previous_task_state, task_instance object can be accessed. @@ -74,14 +71,10 @@ def on_task_instance_success(previous_state: TaskInstanceState, task_instance: T print("Task instance in success state") print(" Previous state of the Task instance:", previous_state) - dag_id = task_instance.dag_id - hostname = task_instance.hostname - operator = task_instance.operator + context = task_instance.get_template_context() + operator = context["task"] - dagrun = task_instance.dag_run - queued_at = dagrun.queued_at - print(f"Dag name:{dag_id} queued_at:{queued_at}") - print(f"Task hostname:{hostname} operator:{operator}") + print(f"Task operator:{operator}") # [END howto_listen_ti_success_task] @@ -90,7 +83,7 @@ def on_task_instance_success(previous_state: TaskInstanceState, task_instance: T # [START howto_listen_ti_failure_task] @hookimpl def on_task_instance_failed( - previous_state: TaskInstanceState, task_instance: TaskInstance, error: None | str | BaseException, session + previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance, error: None | str | BaseException ): """ This method is called when task state changes to FAILED. @@ -100,21 +93,14 @@ def on_task_instance_failed( """ print("Task instance in failure state") - start_date = task_instance.start_date - end_date = task_instance.end_date - duration = task_instance.duration - - dagrun = task_instance.dag_run - - task = task_instance.task + context = task_instance.get_template_context() + task = context["task"] if TYPE_CHECKING: assert task - dag = task.dag - - print(f"Task start:{start_date} end:{end_date} duration:{duration}") - print(f"Task:{task} dag:{dag} dagrun:{dagrun}") + print("Task start") + print(f"Task:{task}") if error: print(f"Failure caused by {error}") diff --git a/airflow/listeners/listener.py b/airflow/listeners/listener.py index 5e8fba55d43..11918527ef2 100644 --- a/airflow/listeners/listener.py +++ b/airflow/listeners/listener.py @@ -46,7 +46,13 @@ class ListenerManager: """Manage listener registration and provides hook property for calling them.""" def __init__(self): - from airflow.listeners.spec import asset, dagrun, importerrors, lifecycle, taskinstance + from airflow.listeners.spec import ( + asset, + dagrun, + importerrors, + lifecycle, + taskinstance, + ) self.pm = pluggy.PluginManager("airflow") self.pm.add_hookcall_monitoring(_before_hookcall, _after_hookcall) diff --git a/airflow/listeners/spec/taskinstance.py b/airflow/listeners/spec/taskinstance.py index f012de0aac8..d66d6c83ce3 100644 --- a/airflow/listeners/spec/taskinstance.py +++ b/airflow/listeners/spec/taskinstance.py @@ -22,33 +22,26 @@ from typing import TYPE_CHECKING from pluggy import HookspecMarker if TYPE_CHECKING: - from sqlalchemy.orm.session import Session - - from airflow.models.taskinstance import TaskInstance + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance from airflow.utils.state import TaskInstanceState hookspec = HookspecMarker("airflow") @hookspec -def on_task_instance_running( - previous_state: TaskInstanceState | None, task_instance: TaskInstance, session: Session | None -): +def on_task_instance_running(previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance): """Execute when task state changes to RUNNING. previous_state can be None.""" @hookspec -def on_task_instance_success( - previous_state: TaskInstanceState | None, task_instance: TaskInstance, session: Session | None -): +def on_task_instance_success(previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance): """Execute when task state changes to SUCCESS. previous_state can be None.""" @hookspec def on_task_instance_failed( previous_state: TaskInstanceState | None, - task_instance: TaskInstance, + task_instance: RuntimeTaskInstance, error: None | str | BaseException, - session: Session | None, ): """Execute when task state changes to FAIL. previous_state can be None.""" diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 2f3fa4e8fb4..5653f99f166 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -356,12 +356,12 @@ def _run_raw_task( if not test_mode: _add_log(event=ti.state, task_instance=ti, session=session) if ti.state == TaskInstanceState.SUCCESS: - ti._register_asset_changes(events=context["outlet_events"], session=session) + ti._register_asset_changes(events=context["outlet_events"]) TaskInstance.save_to_db(ti=ti, session=session) if ti.state == TaskInstanceState.SUCCESS: get_listener_manager().hook.on_task_instance_success( - previous_state=TaskInstanceState.RUNNING, task_instance=ti, session=session + previous_state=TaskInstanceState.RUNNING, task_instance=ti ) return None @@ -2907,7 +2907,7 @@ class TaskInstance(Base, LoggingMixin): # Run on_task_instance_running event get_listener_manager().hook.on_task_instance_running( - previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session + previous_state=TaskInstanceState.QUEUED, task_instance=self ) def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None: @@ -3149,7 +3149,7 @@ class TaskInstance(Base, LoggingMixin): callbacks = task.on_retry_callback if task else None get_listener_manager().hook.on_task_instance_failed( - previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error, session=session + previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error ) return { diff --git a/tests/listeners/class_listener.py b/tests/listeners/class_listener.py index b39f7278546..de235abbd40 100644 --- a/tests/listeners/class_listener.py +++ b/tests/listeners/class_listener.py @@ -85,17 +85,15 @@ elif AIRFLOW_V_2_10_PLUS: self.state.append(DagRunState.SUCCESS) @hookimpl - def on_task_instance_running(self, previous_state, task_instance, session): + def on_task_instance_running(self, previous_state, task_instance): self.state.append(TaskInstanceState.RUNNING) @hookimpl - def on_task_instance_success(self, previous_state, task_instance, session): + def on_task_instance_success(self, previous_state, task_instance): self.state.append(TaskInstanceState.SUCCESS) @hookimpl - def on_task_instance_failed( - self, previous_state, task_instance, error: None | str | BaseException, session - ): + def on_task_instance_failed(self, previous_state, task_instance, error: None | str | BaseException): self.state.append(TaskInstanceState.FAILED) else: diff --git a/tests/listeners/empty_listener.py b/tests/listeners/empty_listener.py index 0b298e95fe6..0a69f9ec1fa 100644 --- a/tests/listeners/empty_listener.py +++ b/tests/listeners/empty_listener.py @@ -21,7 +21,7 @@ from airflow.listeners import hookimpl @hookimpl -def on_task_instance_running(previous_state, task_instance, session): +def on_task_instance_running(previous_state, task_instance): pass diff --git a/tests/listeners/file_write_listener.py b/tests/listeners/file_write_listener.py index 0ca026da716..c542ccacab5 100644 --- a/tests/listeners/file_write_listener.py +++ b/tests/listeners/file_write_listener.py @@ -34,17 +34,15 @@ class FileWriteListener: f.write(line + "\n") @hookimpl - def on_task_instance_running(self, previous_state, task_instance, session): + def on_task_instance_running(self, previous_state, task_instance): self.write("on_task_instance_running") @hookimpl - def on_task_instance_success(self, previous_state, task_instance, session): + def on_task_instance_success(self, previous_state, task_instance): self.write("on_task_instance_success") @hookimpl - def on_task_instance_failed( - self, previous_state, task_instance, error: None | str | BaseException, session - ): + def on_task_instance_failed(self, previous_state, task_instance, error: None | str | BaseException): self.write("on_task_instance_failed") @hookimpl diff --git a/tests/listeners/full_listener.py b/tests/listeners/full_listener.py index 229fdab6762..50701c822e6 100644 --- a/tests/listeners/full_listener.py +++ b/tests/listeners/full_listener.py @@ -40,17 +40,17 @@ def before_stopping(component): @hookimpl -def on_task_instance_running(previous_state, task_instance, session): +def on_task_instance_running(previous_state, task_instance): state.append(TaskInstanceState.RUNNING) @hookimpl -def on_task_instance_success(previous_state, task_instance, session): +def on_task_instance_success(previous_state, task_instance): state.append(TaskInstanceState.SUCCESS) @hookimpl -def on_task_instance_failed(previous_state, task_instance, error: None | str | BaseException, session): +def on_task_instance_failed(previous_state, task_instance, error: None | str | BaseException): state.append(TaskInstanceState.FAILED) diff --git a/tests/listeners/partial_listener.py b/tests/listeners/partial_listener.py index b4027e28756..2bf1d117745 100644 --- a/tests/listeners/partial_listener.py +++ b/tests/listeners/partial_listener.py @@ -24,7 +24,7 @@ state: list[State] = [] @hookimpl -def on_task_instance_running(previous_state, task_instance, session): +def on_task_instance_running(previous_state, task_instance): state.append(State.RUNNING) diff --git a/tests/listeners/slow_listener.py b/tests/listeners/slow_listener.py index b366aa4d0cb..b585b19650a 100644 --- a/tests/listeners/slow_listener.py +++ b/tests/listeners/slow_listener.py @@ -22,5 +22,5 @@ from airflow.listeners import hookimpl @hookimpl -def on_task_instance_success(previous_state, task_instance, session): +def on_task_instance_success(previous_state, task_instance): time.sleep(3) diff --git a/tests/listeners/throwing_listener.py b/tests/listeners/throwing_listener.py index ae7345d395a..eeb7d0ee6ed 100644 --- a/tests/listeners/throwing_listener.py +++ b/tests/listeners/throwing_listener.py @@ -21,7 +21,7 @@ from airflow.listeners import hookimpl @hookimpl -def on_task_instance_success(previous_state, task_instance, session): +def on_task_instance_success(previous_state, task_instance): raise RuntimeError() diff --git a/tests/listeners/very_slow_listener.py b/tests/listeners/very_slow_listener.py index 28df43a2b8b..9752b8787b7 100644 --- a/tests/listeners/very_slow_listener.py +++ b/tests/listeners/very_slow_listener.py @@ -22,5 +22,5 @@ from airflow.listeners import hookimpl @hookimpl -def on_task_instance_success(previous_state, task_instance, session): +def on_task_instance_success(previous_state, task_instance): time.sleep(10) diff --git a/tests/listeners/xcom_listener.py b/tests/listeners/xcom_listener.py index a7ffc191785..bbfbbba4e65 100644 --- a/tests/listeners/xcom_listener.py +++ b/tests/listeners/xcom_listener.py @@ -30,13 +30,13 @@ class XComListener: f.write(line + "\n") @hookimpl - def on_task_instance_running(self, previous_state, task_instance, session): + def on_task_instance_running(self, previous_state, task_instance): task_instance.xcom_push(key="listener", value="listener") task_instance.xcom_pull(task_ids=task_instance.task_id, key="listener") self.write("on_task_instance_running") @hookimpl - def on_task_instance_success(self, previous_state, task_instance, session): + def on_task_instance_success(self, previous_state, task_instance): read = task_instance.xcom_pull(task_ids=self.task_id, key="listener") self.write("on_task_instance_success") self.write(read) diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py index 8618af26254..54ab88d2f60 100644 --- a/tests/plugins/test_plugins_manager.py +++ b/tests/plugins/test_plugins_manager.py @@ -358,19 +358,22 @@ class TestPluginsManager: def test_registering_plugin_listeners(self): from airflow import plugins_manager - with mock.patch("airflow.plugins_manager.plugins", []): - plugins_manager.load_plugins_from_plugin_directory() - plugins_manager.integrate_listener_plugins(get_listener_manager()) - - assert get_listener_manager().has_listeners - listeners = get_listener_manager().pm.get_plugins() - listener_names = [el.__name__ if inspect.ismodule(el) else qualname(el) for el in listeners] - # sort names as order of listeners is not guaranteed - assert sorted(listener_names) == [ - "airflow.example_dags.plugins.event_listener", - "tests.listeners.class_listener.ClassBasedListener", - "tests.listeners.empty_listener", - ] + try: + with mock.patch("airflow.plugins_manager.plugins", []): + plugins_manager.load_plugins_from_plugin_directory() + plugins_manager.integrate_listener_plugins(get_listener_manager()) + + assert get_listener_manager().has_listeners + listeners = get_listener_manager().pm.get_plugins() + listener_names = [el.__name__ if inspect.ismodule(el) else qualname(el) for el in listeners] + # sort names as order of listeners is not guaranteed + assert sorted(listener_names) == [ + "airflow.example_dags.plugins.event_listener", + "tests.listeners.class_listener.ClassBasedListener", + "tests.listeners.empty_listener", + ] + finally: + get_listener_manager().clear() def test_should_import_plugin_from_providers(self): from airflow import plugins_manager