This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6afde788f58 dag run state change endpoints notify listeners about state change (#45652)
6afde788f58 is described below

commit 6afde788f58a980507e93962a3c619d76981cdcf
Author: Maciej Obuchowski <obuchowski.mac...@gmail.com>
AuthorDate: Wed Jan 15 17:24:34 2025 +0100

    dag run state change endpoints notify listeners about state change (#45652)
    
    Signed-off-by: Maciej Obuchowski <obuchowski.mac...@gmail.com>
---
 .../api_fastapi/core_api/routes/public/dag_run.py  |  4 ++
 .../core_api/routes/public/test_dag_run.py         | 24 +++++++++++
 tests/listeners/class_listener.py                  | 47 +++++++++++++++++++++-
 3 files changed, 73 insertions(+), 2 deletions(-)

diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py
index 9b8cfd6727b..cb846caaa2e 100644
--- a/airflow/api_fastapi/core_api/routes/public/dag_run.py
+++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py
@@ -61,6 +61,7 @@ from airflow.api_fastapi.core_api.datamodels.task_instances import (
 )
 from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
 from airflow.exceptions import ParamValidationError
+from airflow.listeners.listener import get_listener_manager
 from airflow.models import DAG, DagModel, DagRun
 from airflow.models.dag_version import DagVersion
 from airflow.timetables.base import DataInterval
@@ -159,10 +160,13 @@ def patch_dag_run(
             attr_value = getattr(patch_body, "state")
             if attr_value == DAGRunPatchStates.SUCCESS:
                 set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
+                get_listener_manager().hook.on_dag_run_success(dag_run=dag_run, msg="")
             elif attr_value == DAGRunPatchStates.QUEUED:
                 set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
+                # Not notifying on queued - only notifying on RUNNING, this is happening in scheduler
             elif attr_value == DAGRunPatchStates.FAILED:
                 set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
+                get_listener_manager().hook.on_dag_run_failed(dag_run=dag_run, msg="")
         elif attr_name == "note":
             # Once Authentication is implemented in this FastAPI app,
             # user id will be added when updating dag run note
diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py
index 9b7c9211fdd..fc171150534 100644
--- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py
+++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py
@@ -24,6 +24,7 @@ import pytest
 import time_machine
 from sqlalchemy import select
 
+from airflow.listeners.listener import get_listener_manager
 from airflow.models import DagModel, DagRun
 from airflow.models.asset import AssetEvent, AssetModel
 from airflow.models.param import Param
@@ -943,6 +944,29 @@ class TestPatchDagRun:
         body = response.json()
         assert body["detail"][0]["msg"] == "Input should be 'queued', 'success' or 'failed'"
 
+    @pytest.fixture(autouse=True)
+    def clean_listener_manager(self):
+        get_listener_manager().clear()
+        yield
+        get_listener_manager().clear()
+
+    @pytest.mark.parametrize(
+        "state, listener_state",
+        [
+            ("queued", []),
+            ("success", [DagRunState.SUCCESS]),
+            ("failed", [DagRunState.FAILED]),
+        ],
+    )
+    def test_patch_dag_run_notifies_listeners(self, test_client, state, listener_state):
+        from tests.listeners.class_listener import ClassBasedListener
+
+        listener = ClassBasedListener()
+        get_listener_manager().add_listener(listener)
+        response = test_client.patch(f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", json={"state": state})
+        assert response.status_code == 200
+        assert listener.state == listener_state
+
 
 class TestDeleteDagRun:
     def test_delete_dag_run(self, test_client):
diff --git a/tests/listeners/class_listener.py b/tests/listeners/class_listener.py
index 90ff6ab975e..b39f7278546 100644
--- a/tests/listeners/class_listener.py
+++ b/tests/listeners/class_listener.py
@@ -20,9 +20,9 @@ from __future__ import annotations
 from airflow.listeners import hookimpl
 from airflow.utils.state import DagRunState, TaskInstanceState
 
-from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
 
-if AIRFLOW_V_2_10_PLUS:
+if AIRFLOW_V_3_0_PLUS:
 
     class ClassBasedListener:
         def __init__(self):
@@ -41,6 +41,49 @@ if AIRFLOW_V_2_10_PLUS:
             stopped_component = component
             self.state.append(DagRunState.SUCCESS)
 
+        @hookimpl
+        def on_task_instance_running(self, previous_state, task_instance):
+            self.state.append(TaskInstanceState.RUNNING)
+
+        @hookimpl
+        def on_task_instance_success(self, previous_state, task_instance):
+            self.state.append(TaskInstanceState.SUCCESS)
+
+        @hookimpl
+        def on_task_instance_failed(self, previous_state, task_instance, error: None | str | BaseException):
+            self.state.append(TaskInstanceState.FAILED)
+
+        @hookimpl
+        def on_dag_run_running(self, dag_run, msg: str):
+            self.state.append(DagRunState.RUNNING)
+
+        @hookimpl
+        def on_dag_run_success(self, dag_run, msg: str):
+            self.state.append(DagRunState.SUCCESS)
+
+        @hookimpl
+        def on_dag_run_failed(self, dag_run, msg: str):
+            self.state.append(DagRunState.FAILED)
+
+elif AIRFLOW_V_2_10_PLUS:
+
+    class ClassBasedListener:  # type: ignore[no-redef]
+        def __init__(self):
+            self.started_component = None
+            self.stopped_component = None
+            self.state = []
+
+        @hookimpl
+        def on_starting(self, component):
+            self.started_component = component
+            self.state.append(DagRunState.RUNNING)
+
+        @hookimpl
+        def before_stopping(self, component):
+            global stopped_component
+            stopped_component = component
+            self.state.append(DagRunState.SUCCESS)
+
         @hookimpl
         def on_task_instance_running(self, previous_state, task_instance, session):
             self.state.append(TaskInstanceState.RUNNING)

Reply via email to