This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch tasksdk-call-listeners
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 0d95771ae5b8cacf15de7b3c0614fb1c8119bc6f
Author: Maciej Obuchowski <obuchowski.mac...@gmail.com>
AuthorDate: Wed Jan 8 13:20:23 2025 +0100

    change listener API
    
    Signed-off-by: Maciej Obuchowski <obuchowski.mac...@gmail.com>
---
 airflow/example_dags/plugins/event_listener.py | 44 +++++++++-----------------
 airflow/listeners/listener.py                  |  8 ++++-
 airflow/listeners/spec/taskinstance.py         | 15 +++------
 airflow/models/taskinstance.py                 |  8 ++---
 tests/listeners/class_listener.py              |  8 ++---
 tests/listeners/empty_listener.py              |  2 +-
 tests/listeners/file_write_listener.py         |  8 ++---
 tests/listeners/full_listener.py               |  6 ++--
 tests/listeners/partial_listener.py            |  2 +-
 tests/listeners/slow_listener.py               |  2 +-
 tests/listeners/throwing_listener.py           |  2 +-
 tests/listeners/very_slow_listener.py          |  2 +-
 tests/listeners/xcom_listener.py               |  4 +--
 tests/plugins/test_plugins_manager.py          | 29 +++++++++--------
 14 files changed, 62 insertions(+), 78 deletions(-)

diff --git a/airflow/example_dags/plugins/event_listener.py b/airflow/example_dags/plugins/event_listener.py
index 6d9fe2ff117..b0001b0bc7e 100644
--- a/airflow/example_dags/plugins/event_listener.py
+++ b/airflow/example_dags/plugins/event_listener.py
@@ -23,13 +23,13 @@ from airflow.listeners import hookimpl
 
 if TYPE_CHECKING:
     from airflow.models.dagrun import DagRun
-    from airflow.models.taskinstance import TaskInstance
+    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
     from airflow.utils.state import TaskInstanceState
 
 
 # [START howto_listen_ti_running_task]
 @hookimpl
-def on_task_instance_running(previous_state: TaskInstanceState, task_instance: TaskInstance, session):
+def on_task_instance_running(previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance):
     """
     This method is called when task state changes to RUNNING.
     Through callback, parameters like previous_task_state, task_instance object can be accessed.
@@ -39,14 +39,11 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T
     print("Task instance is in running state")
     print(" Previous state of the Task instance:", previous_state)
 
-    state: TaskInstanceState = task_instance.state
     name: str = task_instance.task_id
-    start_date = task_instance.start_date
 
-    dagrun = task_instance.dag_run
-    dagrun_status = dagrun.state
+    context = task_instance.get_template_context()
 
-    task = task_instance.task
+    task = context["task"]
 
     if TYPE_CHECKING:
         assert task
@@ -55,8 +52,8 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T
     dag_name = None
     if dag:
         dag_name = dag.dag_id
-    print(f"Current task name:{name} state:{state} start_date:{start_date}")
-    print(f"Dag name:{dag_name} and current dag run status:{dagrun_status}")
+    print(f"Current task name:{name}")
+    print(f"Dag name:{dag_name}")
 
 
 # [END howto_listen_ti_running_task]
@@ -64,7 +61,7 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T
 
 # [START howto_listen_ti_success_task]
 @hookimpl
-def on_task_instance_success(previous_state: TaskInstanceState, task_instance: TaskInstance, session):
+def on_task_instance_success(previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance):
     """
     This method is called when task state changes to SUCCESS.
     Through callback, parameters like previous_task_state, task_instance object can be accessed.
@@ -74,14 +71,10 @@ def on_task_instance_success(previous_state: TaskInstanceState, task_instance: T
     print("Task instance in success state")
     print(" Previous state of the Task instance:", previous_state)
 
-    dag_id = task_instance.dag_id
-    hostname = task_instance.hostname
-    operator = task_instance.operator
+    context = task_instance.get_template_context()
+    operator = context["task"]
 
-    dagrun = task_instance.dag_run
-    queued_at = dagrun.queued_at
-    print(f"Dag name:{dag_id} queued_at:{queued_at}")
-    print(f"Task hostname:{hostname} operator:{operator}")
+    print(f"Task operator:{operator}")
 
 
 # [END howto_listen_ti_success_task]
@@ -90,7 +83,7 @@ def on_task_instance_success(previous_state: TaskInstanceState, task_instance: T
 # [START howto_listen_ti_failure_task]
 @hookimpl
 def on_task_instance_failed(
-    previous_state: TaskInstanceState, task_instance: TaskInstance, error: None | str | BaseException, session
+    previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance, error: None | str | BaseException
 ):
     """
     This method is called when task state changes to FAILED.
@@ -100,21 +93,14 @@ def on_task_instance_failed(
     """
     print("Task instance in failure state")
 
-    start_date = task_instance.start_date
-    end_date = task_instance.end_date
-    duration = task_instance.duration
-
-    dagrun = task_instance.dag_run
-
-    task = task_instance.task
+    context = task_instance.get_template_context()
+    task = context["task"]
 
     if TYPE_CHECKING:
         assert task
 
-    dag = task.dag
-
-    print(f"Task start:{start_date} end:{end_date} duration:{duration}")
-    print(f"Task:{task} dag:{dag} dagrun:{dagrun}")
+    print("Task start")
+    print(f"Task:{task}")
     if error:
         print(f"Failure caused by {error}")
 
diff --git a/airflow/listeners/listener.py b/airflow/listeners/listener.py
index 5e8fba55d43..11918527ef2 100644
--- a/airflow/listeners/listener.py
+++ b/airflow/listeners/listener.py
@@ -46,7 +46,13 @@ class ListenerManager:
     """Manage listener registration and provides hook property for calling 
them."""
 
     def __init__(self):
-        from airflow.listeners.spec import asset, dagrun, importerrors, lifecycle, taskinstance
+        from airflow.listeners.spec import (
+            asset,
+            dagrun,
+            importerrors,
+            lifecycle,
+            taskinstance,
+        )
 
         self.pm = pluggy.PluginManager("airflow")
         self.pm.add_hookcall_monitoring(_before_hookcall, _after_hookcall)
diff --git a/airflow/listeners/spec/taskinstance.py b/airflow/listeners/spec/taskinstance.py
index f012de0aac8..d66d6c83ce3 100644
--- a/airflow/listeners/spec/taskinstance.py
+++ b/airflow/listeners/spec/taskinstance.py
@@ -22,33 +22,26 @@ from typing import TYPE_CHECKING
 from pluggy import HookspecMarker
 
 if TYPE_CHECKING:
-    from sqlalchemy.orm.session import Session
-
-    from airflow.models.taskinstance import TaskInstance
+    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
     from airflow.utils.state import TaskInstanceState
 
 hookspec = HookspecMarker("airflow")
 
 
 @hookspec
-def on_task_instance_running(
-    previous_state: TaskInstanceState | None, task_instance: TaskInstance, session: Session | None
-):
+def on_task_instance_running(previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance):
     """Execute when task state changes to RUNNING. previous_state can be None."""
 
 
 @hookspec
-def on_task_instance_success(
-    previous_state: TaskInstanceState | None, task_instance: TaskInstance, session: Session | None
-):
+def on_task_instance_success(previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance):
     """Execute when task state changes to SUCCESS. previous_state can be None."""
 
 
 @hookspec
 def on_task_instance_failed(
     previous_state: TaskInstanceState | None,
-    task_instance: TaskInstance,
+    task_instance: RuntimeTaskInstance,
     error: None | str | BaseException,
-    session: Session | None,
 ):
     """Execute when task state changes to FAIL. previous_state can be None."""
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 2f3fa4e8fb4..5653f99f166 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -356,12 +356,12 @@ def _run_raw_task(
         if not test_mode:
             _add_log(event=ti.state, task_instance=ti, session=session)
             if ti.state == TaskInstanceState.SUCCESS:
-                ti._register_asset_changes(events=context["outlet_events"], session=session)
+                ti._register_asset_changes(events=context["outlet_events"])
 
             TaskInstance.save_to_db(ti=ti, session=session)
             if ti.state == TaskInstanceState.SUCCESS:
                 get_listener_manager().hook.on_task_instance_success(
-                    previous_state=TaskInstanceState.RUNNING, task_instance=ti, session=session
+                    previous_state=TaskInstanceState.RUNNING, task_instance=ti
                 )
 
         return None
@@ -2907,7 +2907,7 @@ class TaskInstance(Base, LoggingMixin):
 
             # Run on_task_instance_running event
             get_listener_manager().hook.on_task_instance_running(
-                previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
+                previous_state=TaskInstanceState.QUEUED, task_instance=self
             )
 
             def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None:
@@ -3149,7 +3149,7 @@ class TaskInstance(Base, LoggingMixin):
             callbacks = task.on_retry_callback if task else None
 
         get_listener_manager().hook.on_task_instance_failed(
-            previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error, session=session
+            previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
         )
 
         return {
diff --git a/tests/listeners/class_listener.py b/tests/listeners/class_listener.py
index b39f7278546..de235abbd40 100644
--- a/tests/listeners/class_listener.py
+++ b/tests/listeners/class_listener.py
@@ -85,17 +85,15 @@ elif AIRFLOW_V_2_10_PLUS:
             self.state.append(DagRunState.SUCCESS)
 
         @hookimpl
-        def on_task_instance_running(self, previous_state, task_instance, session):
+        def on_task_instance_running(self, previous_state, task_instance):
             self.state.append(TaskInstanceState.RUNNING)
 
         @hookimpl
-        def on_task_instance_success(self, previous_state, task_instance, session):
+        def on_task_instance_success(self, previous_state, task_instance):
             self.state.append(TaskInstanceState.SUCCESS)
 
         @hookimpl
-        def on_task_instance_failed(
-            self, previous_state, task_instance, error: None | str | BaseException, session
-        ):
+        def on_task_instance_failed(self, previous_state, task_instance, error: None | str | BaseException):
             self.state.append(TaskInstanceState.FAILED)
 else:
 
diff --git a/tests/listeners/empty_listener.py b/tests/listeners/empty_listener.py
index 0b298e95fe6..0a69f9ec1fa 100644
--- a/tests/listeners/empty_listener.py
+++ b/tests/listeners/empty_listener.py
@@ -21,7 +21,7 @@ from airflow.listeners import hookimpl
 
 
 @hookimpl
-def on_task_instance_running(previous_state, task_instance, session):
+def on_task_instance_running(previous_state, task_instance):
     pass
 
 
diff --git a/tests/listeners/file_write_listener.py b/tests/listeners/file_write_listener.py
index 0ca026da716..c542ccacab5 100644
--- a/tests/listeners/file_write_listener.py
+++ b/tests/listeners/file_write_listener.py
@@ -34,17 +34,15 @@ class FileWriteListener:
             f.write(line + "\n")
 
     @hookimpl
-    def on_task_instance_running(self, previous_state, task_instance, session):
+    def on_task_instance_running(self, previous_state, task_instance):
         self.write("on_task_instance_running")
 
     @hookimpl
-    def on_task_instance_success(self, previous_state, task_instance, session):
+    def on_task_instance_success(self, previous_state, task_instance):
         self.write("on_task_instance_success")
 
     @hookimpl
-    def on_task_instance_failed(
-        self, previous_state, task_instance, error: None | str | BaseException, session
-    ):
+    def on_task_instance_failed(self, previous_state, task_instance, error: None | str | BaseException):
         self.write("on_task_instance_failed")
 
     @hookimpl
diff --git a/tests/listeners/full_listener.py b/tests/listeners/full_listener.py
index 229fdab6762..50701c822e6 100644
--- a/tests/listeners/full_listener.py
+++ b/tests/listeners/full_listener.py
@@ -40,17 +40,17 @@ def before_stopping(component):
 
 
 @hookimpl
-def on_task_instance_running(previous_state, task_instance, session):
+def on_task_instance_running(previous_state, task_instance):
     state.append(TaskInstanceState.RUNNING)
 
 
 @hookimpl
-def on_task_instance_success(previous_state, task_instance, session):
+def on_task_instance_success(previous_state, task_instance):
     state.append(TaskInstanceState.SUCCESS)
 
 
 @hookimpl
-def on_task_instance_failed(previous_state, task_instance, error: None | str | BaseException, session):
+def on_task_instance_failed(previous_state, task_instance, error: None | str | BaseException):
     state.append(TaskInstanceState.FAILED)
 
 
diff --git a/tests/listeners/partial_listener.py b/tests/listeners/partial_listener.py
index b4027e28756..2bf1d117745 100644
--- a/tests/listeners/partial_listener.py
+++ b/tests/listeners/partial_listener.py
@@ -24,7 +24,7 @@ state: list[State] = []
 
 
 @hookimpl
-def on_task_instance_running(previous_state, task_instance, session):
+def on_task_instance_running(previous_state, task_instance):
     state.append(State.RUNNING)
 
 
diff --git a/tests/listeners/slow_listener.py b/tests/listeners/slow_listener.py
index b366aa4d0cb..b585b19650a 100644
--- a/tests/listeners/slow_listener.py
+++ b/tests/listeners/slow_listener.py
@@ -22,5 +22,5 @@ from airflow.listeners import hookimpl
 
 
 @hookimpl
-def on_task_instance_success(previous_state, task_instance, session):
+def on_task_instance_success(previous_state, task_instance):
     time.sleep(3)
diff --git a/tests/listeners/throwing_listener.py b/tests/listeners/throwing_listener.py
index ae7345d395a..eeb7d0ee6ed 100644
--- a/tests/listeners/throwing_listener.py
+++ b/tests/listeners/throwing_listener.py
@@ -21,7 +21,7 @@ from airflow.listeners import hookimpl
 
 
 @hookimpl
-def on_task_instance_success(previous_state, task_instance, session):
+def on_task_instance_success(previous_state, task_instance):
     raise RuntimeError()
 
 
diff --git a/tests/listeners/very_slow_listener.py b/tests/listeners/very_slow_listener.py
index 28df43a2b8b..9752b8787b7 100644
--- a/tests/listeners/very_slow_listener.py
+++ b/tests/listeners/very_slow_listener.py
@@ -22,5 +22,5 @@ from airflow.listeners import hookimpl
 
 
 @hookimpl
-def on_task_instance_success(previous_state, task_instance, session):
+def on_task_instance_success(previous_state, task_instance):
     time.sleep(10)
diff --git a/tests/listeners/xcom_listener.py b/tests/listeners/xcom_listener.py
index a7ffc191785..bbfbbba4e65 100644
--- a/tests/listeners/xcom_listener.py
+++ b/tests/listeners/xcom_listener.py
@@ -30,13 +30,13 @@ class XComListener:
             f.write(line + "\n")
 
     @hookimpl
-    def on_task_instance_running(self, previous_state, task_instance, session):
+    def on_task_instance_running(self, previous_state, task_instance):
         task_instance.xcom_push(key="listener", value="listener")
         task_instance.xcom_pull(task_ids=task_instance.task_id, key="listener")
         self.write("on_task_instance_running")
 
     @hookimpl
-    def on_task_instance_success(self, previous_state, task_instance, session):
+    def on_task_instance_success(self, previous_state, task_instance):
         read = task_instance.xcom_pull(task_ids=self.task_id, key="listener")
         self.write("on_task_instance_success")
         self.write(read)
diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py
index 8618af26254..54ab88d2f60 100644
--- a/tests/plugins/test_plugins_manager.py
+++ b/tests/plugins/test_plugins_manager.py
@@ -358,19 +358,22 @@ class TestPluginsManager:
     def test_registering_plugin_listeners(self):
         from airflow import plugins_manager
 
-        with mock.patch("airflow.plugins_manager.plugins", []):
-            plugins_manager.load_plugins_from_plugin_directory()
-            plugins_manager.integrate_listener_plugins(get_listener_manager())
-
-            assert get_listener_manager().has_listeners
-            listeners = get_listener_manager().pm.get_plugins()
-            listener_names = [el.__name__ if inspect.ismodule(el) else qualname(el) for el in listeners]
-            # sort names as order of listeners is not guaranteed
-            assert sorted(listener_names) == [
-                "airflow.example_dags.plugins.event_listener",
-                "tests.listeners.class_listener.ClassBasedListener",
-                "tests.listeners.empty_listener",
-            ]
+        try:
+            with mock.patch("airflow.plugins_manager.plugins", []):
+                plugins_manager.load_plugins_from_plugin_directory()
+                plugins_manager.integrate_listener_plugins(get_listener_manager())
+
+                assert get_listener_manager().has_listeners
+                listeners = get_listener_manager().pm.get_plugins()
+                listener_names = [el.__name__ if inspect.ismodule(el) else qualname(el) for el in listeners]
+                # sort names as order of listeners is not guaranteed
+                assert sorted(listener_names) == [
+                    "airflow.example_dags.plugins.event_listener",
+                    "tests.listeners.class_listener.ClassBasedListener",
+                    "tests.listeners.empty_listener",
+                ]
+        finally:
+            get_listener_manager().clear()
 
     def test_should_import_plugin_from_providers(self):
         from airflow import plugins_manager

Reply via email to