This is an automated email from the ASF dual-hosted git repository. jedcunningham pushed a commit to branch v2-9-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit c6c589f5a495a19769e3274e7bfe04dde50f62a7 Author: Steven Schaerer <53116297+stevenschae...@users.noreply.github.com> AuthorDate: Thu Apr 4 21:55:28 2024 +0200 Use async db calls in WorkflowTrigger (#38689) * Use async db calls in WorkflowTrigger * address PR comments * deprecate TaskStateTrigger with proper category (cherry picked from commit e6eec0cfad424e402fe2a03b42818e706f0685ba) --- airflow/triggers/external_task.py | 46 ++++++----- tests/triggers/test_external_task.py | 153 ++++++++++++++++++++++------------- 2 files changed, 122 insertions(+), 77 deletions(-) diff --git a/airflow/triggers/external_task.py b/airflow/triggers/external_task.py index 5c7361a15b..a5de817c35 100644 --- a/airflow/triggers/external_task.py +++ b/airflow/triggers/external_task.py @@ -21,8 +21,10 @@ import typing from typing import Any from asgiref.sync import sync_to_async +from deprecated import deprecated from sqlalchemy import func +from airflow.exceptions import RemovedInAirflow3Warning from airflow.models import DagRun, TaskInstance from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.utils.sensor_helper import _get_count @@ -98,13 +100,7 @@ class WorkflowTrigger(BaseTrigger): """Check periodically tasks, task group or dag status.""" while True: if self.failed_states: - failed_count = _get_count( - self.execution_dates, - self.external_task_ids, - self.external_task_group_id, - self.external_dag_id, - self.failed_states, - ) + failed_count = await self._get_count(self.failed_states) if failed_count > 0: yield TriggerEvent({"status": "failed"}) return @@ -112,30 +108,38 @@ class WorkflowTrigger(BaseTrigger): yield TriggerEvent({"status": "success"}) return if self.skipped_states: - skipped_count = _get_count( - self.execution_dates, - self.external_task_ids, - self.external_task_group_id, - self.external_dag_id, - self.skipped_states, - ) + skipped_count = await self._get_count(self.skipped_states) if skipped_count > 0: yield TriggerEvent({"status": "skipped"}) return - allowed_count = _get_count( - self.execution_dates, - self.external_task_ids, - self.external_task_group_id, - self.external_dag_id, - self.allowed_states, - ) + allowed_count = await self._get_count(self.allowed_states) if allowed_count == len(self.execution_dates): yield TriggerEvent({"status": "success"}) return self.log.info("Sleeping for %s seconds", self.poke_interval) await asyncio.sleep(self.poke_interval) + @sync_to_async + def _get_count(self, states: typing.Iterable[str] | None) -> int: + """ + Get the count of records against dttm filter and states. Async wrapper for _get_count. + + :param states: task or dag states + :return The count of records. + """ + return _get_count( + dttm_filter=self.execution_dates, + external_task_ids=self.external_task_ids, + external_task_group_id=self.external_task_group_id, + external_dag_id=self.external_dag_id, + states=states, + ) + +@deprecated( + reason="TaskStateTrigger has been deprecated and will be removed in future.", + category=RemovedInAirflow3Warning, +) class TaskStateTrigger(BaseTrigger): """ Waits asynchronously for a task in a different DAG to complete for a specific logical date. diff --git a/tests/triggers/test_external_task.py b/tests/triggers/test_external_task.py index fe773049fa..8ce6d89a3a 100644 --- a/tests/triggers/test_external_task.py +++ b/tests/triggers/test_external_task.py @@ -18,11 +18,13 @@ from __future__ import annotations import asyncio import datetime +import time from unittest import mock import pytest from sqlalchemy.exc import SQLAlchemyError +from airflow.exceptions import RemovedInAirflow3Warning from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance @@ -41,11 +43,10 @@ class TestWorkflowTrigger: STATES = ["success", "fail"] @mock.patch("airflow.triggers.external_task._get_count") - @mock.patch("asyncio.sleep") @pytest.mark.asyncio - async def test_task_workflow_trigger_success(self, mock_sleep, mock_get_count): + async def test_task_workflow_trigger_success(self, mock_get_count): """check the db count get called correctly.""" - mock_get_count.return_value = 1 + mock_get_count.side_effect = mocked_get_count trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, execution_dates=[timezone.datetime(2022, 1, 1)], @@ -54,19 +55,29 @@ class TestWorkflowTrigger: poke_interval=0.2, ) - generator = trigger.run() - await generator.asend(None) + gen = trigger.run() + trigger_task = asyncio.create_task(gen.__anext__()) + fake_task = asyncio.create_task(fake_async_fun()) + await trigger_task + assert fake_task.done() # confirm that get_count is done in an async fashion + assert trigger_task.done() + result = trigger_task.result() + assert result.payload == {"status": "success"} mock_get_count.assert_called_once_with( - [timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"] + dttm_filter=[timezone.datetime(2022, 1, 1)], + external_task_ids=["external_task_op"], + external_task_group_id=None, + external_dag_id="external_task", + states=["success", "fail"], ) # test that it returns after yielding with pytest.raises(StopAsyncIteration): - await generator.__anext__() + await gen.__anext__() @mock.patch("airflow.triggers.external_task._get_count") @pytest.mark.asyncio async def test_task_workflow_trigger_failed(self, mock_get_count): - mock_get_count.return_value = 1 + mock_get_count.side_effect = mocked_get_count trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, execution_dates=[timezone.datetime(2022, 1, 1)], @@ -77,13 +88,19 @@ class TestWorkflowTrigger: gen = trigger.run() trigger_task = asyncio.create_task(gen.__anext__()) + fake_task = asyncio.create_task(fake_async_fun()) await trigger_task - assert trigger_task.done() is True + assert fake_task.done() # confirm that get_count is done in an async fashion + assert trigger_task.done() result = trigger_task.result() assert isinstance(result, TriggerEvent) assert result.payload == {"status": "failed"} mock_get_count.assert_called_once_with( - [timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"] + dttm_filter=[timezone.datetime(2022, 1, 1)], + external_task_ids=["external_task_op"], + external_task_group_id=None, + external_dag_id="external_task", + states=["success", "fail"], ) # test that it returns after yielding with pytest.raises(StopAsyncIteration): @@ -104,12 +121,16 @@ class TestWorkflowTrigger: gen = trigger.run() trigger_task = asyncio.create_task(gen.__anext__()) await trigger_task - assert trigger_task.done() is True + assert trigger_task.done() result = trigger_task.result() assert isinstance(result, TriggerEvent) assert result.payload == {"status": "success"} mock_get_count.assert_called_once_with( - [timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"] + dttm_filter=[timezone.datetime(2022, 1, 1)], + external_task_ids=["external_task_op"], + external_task_group_id=None, + external_dag_id="external_task", + states=["success", "fail"], ) # test that it returns after yielding with pytest.raises(StopAsyncIteration): @@ -118,7 +139,7 @@ class TestWorkflowTrigger: @mock.patch("airflow.triggers.external_task._get_count") @pytest.mark.asyncio async def test_task_workflow_trigger_skipped(self, mock_get_count): - mock_get_count.return_value = 1 + mock_get_count.side_effect = mocked_get_count trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, execution_dates=[timezone.datetime(2022, 1, 1)], @@ -129,13 +150,19 @@ class TestWorkflowTrigger: gen = trigger.run() trigger_task = asyncio.create_task(gen.__anext__()) + fake_task = asyncio.create_task(fake_async_fun()) await trigger_task - assert trigger_task.done() is True + assert fake_task.done() # confirm that get_count is done in an async fashion + assert trigger_task.done() result = trigger_task.result() assert isinstance(result, TriggerEvent) assert result.payload == {"status": "skipped"} mock_get_count.assert_called_once_with( - [timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"] + dttm_filter=[timezone.datetime(2022, 1, 1)], + external_task_ids=["external_task_op"], + external_task_group_id=None, + external_dag_id="external_task", + states=["success", "fail"], ) @mock.patch("airflow.triggers.external_task._get_count") @@ -153,7 +180,7 @@ class TestWorkflowTrigger: gen = trigger.run() trigger_task = asyncio.create_task(gen.__anext__()) await trigger_task - assert trigger_task.done() is True + assert trigger_task.done() result = trigger_task.result() assert isinstance(result, TriggerEvent) assert result.payload == {"status": "success"} @@ -222,14 +249,15 @@ class TestTaskStateTrigger: session.add(instance) session.commit() - trigger = TaskStateTrigger( - dag_id=dag.dag_id, - task_id=instance.task_id, - states=self.STATES, - execution_dates=[timezone.datetime(2022, 1, 1)], - poll_interval=0.2, - trigger_start_time=trigger_start_time, - ) + with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"): + trigger = TaskStateTrigger( + dag_id=dag.dag_id, + task_id=instance.task_id, + states=self.STATES, + execution_dates=[timezone.datetime(2022, 1, 1)], + poll_interval=0.2, + trigger_start_time=trigger_start_time, + ) task = asyncio.create_task(trigger.run().__anext__()) await asyncio.sleep(0.5) @@ -252,14 +280,15 @@ class TestTaskStateTrigger: trigger_start_time = utcnow() mock_utcnow.return_value = trigger_start_time + datetime.timedelta(seconds=61) - trigger = TaskStateTrigger( - dag_id="dag1", - task_id="task1", - states=self.STATES, - execution_dates=[timezone.datetime(2022, 1, 1)], - poll_interval=0.2, - trigger_start_time=trigger_start_time, - ) + with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"): + trigger = TaskStateTrigger( + dag_id="dag1", + task_id="task1", + states=self.STATES, + execution_dates=[timezone.datetime(2022, 1, 1)], + poll_interval=0.2, + trigger_start_time=trigger_start_time, + ) trigger.count_running_dags = mock.AsyncMock() trigger.count_running_dags.return_value = 0 @@ -284,14 +313,15 @@ class TestTaskStateTrigger: trigger_start_time = utcnow() mock_utcnow.return_value = trigger_start_time + datetime.timedelta(seconds=20) - trigger = TaskStateTrigger( - dag_id="dag1", - task_id="task1", - states=self.STATES, - execution_dates=[timezone.datetime(2022, 1, 1)], - poll_interval=0.2, - trigger_start_time=trigger_start_time, - ) + with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"): + trigger = TaskStateTrigger( + dag_id="dag1", + task_id="task1", + states=self.STATES, + execution_dates=[timezone.datetime(2022, 1, 1)], + poll_interval=0.2, + trigger_start_time=trigger_start_time, + ) trigger.count_running_dags = mock.AsyncMock() trigger.count_running_dags.return_value = 0 @@ -331,14 +361,15 @@ class TestTaskStateTrigger: trigger_start_time + datetime.timedelta(seconds=20), ] - trigger = TaskStateTrigger( - dag_id="dag1", - task_id="task1", - states=self.STATES, - execution_dates=[timezone.datetime(2022, 1, 1)], - poll_interval=0.2, - trigger_start_time=trigger_start_time, - ) + with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"): + trigger = TaskStateTrigger( + dag_id="dag1", + task_id="task1", + states=self.STATES, + execution_dates=[timezone.datetime(2022, 1, 1)], + poll_interval=0.2, + trigger_start_time=trigger_start_time, + ) trigger.count_running_dags = mock.AsyncMock() trigger.count_running_dags.side_effect = [SQLAlchemyError] @@ -358,14 +389,15 @@ class TestTaskStateTrigger: and classpath. """ trigger_start_time = utcnow() - trigger = TaskStateTrigger( - dag_id=self.DAG_ID, - task_id=self.TASK_ID, - states=self.STATES, - execution_dates=[timezone.datetime(2022, 1, 1)], - poll_interval=5, - trigger_start_time=trigger_start_time, - ) + with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"): + trigger = TaskStateTrigger( + dag_id=self.DAG_ID, + task_id=self.TASK_ID, + states=self.STATES, + execution_dates=[timezone.datetime(2022, 1, 1)], + poll_interval=5, + trigger_start_time=trigger_start_time, + ) classpath, kwargs = trigger.serialize() assert classpath == "airflow.triggers.external_task.TaskStateTrigger" assert kwargs == { @@ -438,3 +470,12 @@ class TestDagStateTrigger: "execution_dates": [timezone.datetime(2022, 1, 1)], "poll_interval": 5, } + + +def mocked_get_count(*args, **kwargs): + time.sleep(0.0001) + return 1 + + +async def fake_async_fun(): + await asyncio.sleep(0.00005)