This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit ec60dd799d65bdb80b83893db7df215d98342dde
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Fri Jan 14 09:55:15 2022 +0100

    Handle stuck queued tasks in Celery for db backend (#19769)

    Move the state of stuck queued tasks in Celery to Scheduled so that
    the Scheduler can queue them again. Only applies to DatabaseBackend.

    (cherry picked from commit 14ee831c7ad767e31a3aeccf3edbc519b3b8c923)
---
 airflow/config_templates/config.yml          |   7 +
 airflow/config_templates/default_airflow.cfg |   3 +
 airflow/executors/celery_executor.py         |  52 +++++++
 tests/executors/test_celery_executor.py      | 197 ++++++++++++++++++++++++---
 4 files changed, 242 insertions(+), 17 deletions(-)

diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 6941f03..e061568 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -1663,6 +1663,13 @@
       type: string
       example: ~
       default: "False"
+    - name: stuck_queued_task_check_interval
+      description: |
+        How often to check for stuck queued tasks (in seconds)
+      version_added: 2.3.0
+      type: integer
+      example: ~
+      default: "300"
     - name: celery_broker_transport_options
       description: |
         This section is for specifying options which can be passed to the
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index 6a5449b..4024922 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -830,6 +830,9 @@ task_publish_max_retries = 3
 # Worker initialisation check to validate Metadata Database connection
 worker_precheck = False

+# How often to check for stuck queued tasks (in seconds)
+stuck_queued_task_check_interval = 300
+
 [celery_broker_transport_options]

 # This section is for specifying options which can be passed to the
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index f257b0c..8daced6 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -40,6 +40,7 @@ from celery.backends.database import DatabaseBackend, Task as TaskDb, session_cl
 from celery.result import AsyncResult
 from celery.signals import import_modules as celery_import_modules
 from setproctitle import setproctitle
+from sqlalchemy.orm.session import Session

 import airflow.settings as settings
 from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
@@ -50,6 +51,7 @@ from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
 from airflow.stats import Stats
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 from airflow.utils.timezone import utcnow
@@ -231,6 +233,10 @@ class CeleryExecutor(BaseExecutor):
         self.task_adoption_timeout = datetime.timedelta(
             seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600)
         )
+        self.stuck_tasks_last_check_time: int = time.time()
+        self.stuck_queued_task_check_interval = conf.getint(
+            'celery', 'stuck_queued_task_check_interval', fallback=300
+        )
         self.task_publish_retries: Dict[TaskInstanceKey, int] = OrderedDict()
         self.task_publish_max_retries = conf.getint('celery', 'task_publish_max_retries', fallback=3)
@@ -335,6 +341,8 @@ class CeleryExecutor(BaseExecutor):
         if self.adopted_task_timeouts:
             self._check_for_stalled_adopted_tasks()
+        if time.time() - self.stuck_tasks_last_check_time > self.stuck_queued_task_check_interval:
+            self._clear_stuck_queued_tasks()

     def _check_for_stalled_adopted_tasks(self):
         """
@@ -375,6 +383,50 @@ class CeleryExecutor(BaseExecutor):
         for key in timedout_keys:
             self.change_state(key, State.FAILED)

+    @provide_session
+    def _clear_stuck_queued_tasks(self, session: Session = NEW_SESSION) -> None:
+        """
+        Tasks can get stuck in the queued state in the DB while never reaching
+        a worker. This happens when the worker is autoscaled down and the task
+        was queued but had not been picked up by any worker prior to the scaling.
+
+        In such a situation, we update the task instance state to scheduled so that
+        it can be queued again. We chose to use task_adoption_timeout to decide when
+        a queued task is considered stuck and should be rescheduled.
+        """
+        if not isinstance(app.backend, DatabaseBackend):
+            # We only want to do this for database backends where
+            # this case has been spotted
+            return
+        # We use this instead of using bulk_state_fetcher because we
+        # may not have the stuck task in self.tasks and we don't want
+        # to clear tasks in self.tasks either
+        session_ = app.backend.ResultSession()
+        task_cls = getattr(app.backend, "task_cls", TaskDb)
+        with session_cleanup(session_):
+            celery_task_ids = [t.task_id for t in session_.query(task_cls.task_id).all()]
+        self.log.debug("Checking for stuck queued tasks")
+
+        max_allowed_time = utcnow() - self.task_adoption_timeout
+
+        for task in session.query(TaskInstance).filter(
+            TaskInstance.state == State.QUEUED, TaskInstance.queued_dttm < max_allowed_time
+        ):
+            if task.key in self.queued_tasks or task.key in self.running:
+                continue
+
+            if task.external_executor_id in celery_task_ids:
+                # The task is still running in the worker
+                continue
+
+            self.log.info(
+                'TaskInstance: %s found in queued state for more than %s seconds, rescheduling',
+                task,
+                self.task_adoption_timeout.total_seconds(),
+            )
+            task.state = State.SCHEDULED
+            session.merge(task)
+
     def debug_dump(self) -> None:
         """Called in response to SIGUSR2 by the scheduler"""
         super().debug_dump()
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index db63b18..5632f7d 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -17,10 +17,13 @@
 # under the License.
 import contextlib
 import json
+import logging
 import os
 import signal
 import sys
+import time
 import unittest
+from collections import namedtuple
 from datetime import datetime, timedelta
 from unittest import mock

@@ -32,6 +35,7 @@ from celery.backends.base import BaseBackend, BaseKeyValueStoreBackend
 from celery.backends.database import DatabaseBackend
 from celery.contrib.testing.worker import start_worker
 from celery.result import AsyncResult
+from freezegun import freeze_time
 from kombu.asynchronous import set_event_loop
 from parameterized import parameterized

@@ -94,12 +98,12 @@ def _prepare_app(broker_url=None, execute=None):
         set_event_loop(None)


-class TestCeleryExecutor(unittest.TestCase):
-    def setUp(self) -> None:
+class TestCeleryExecutor:
+    def setup_method(self) -> None:
         db.clear_db_runs()
         db.clear_db_jobs()

-    def tearDown(self) -> None:
+    def teardown_method(self) -> None:
         db.clear_db_runs()
         db.clear_db_jobs()

@@ -196,10 +200,10 @@ class TestCeleryExecutor(unittest.TestCase):
     @pytest.mark.integration("redis")
     @pytest.mark.integration("rabbitmq")
     @pytest.mark.backend("mysql", "postgres")
-    def test_retry_on_error_sending_task(self):
+    def test_retry_on_error_sending_task(self, caplog):
         """Test that Airflow retries publishing tasks to Celery Broker at least 3 times"""

-        with _prepare_app(), self.assertLogs(celery_executor.log) as cm, mock.patch.object(
+        with _prepare_app(), caplog.at_level(logging.INFO), mock.patch.object(
             # Mock `with timeout()` to _instantly_ fail.
             celery_executor.timeout,
             "__enter__",
@@ -227,28 +231,19 @@ class TestCeleryExecutor(unittest.TestCase):
             assert dict(executor.task_publish_retries) == {key: 2}
             assert 1 == len(executor.queued_tasks), "Task should remain in queue"
             assert executor.event_buffer == {}
-            assert (
-                "INFO:airflow.executors.celery_executor.CeleryExecutor:"
-                f"[Try 1 of 3] Task Timeout Error for Task: ({key})." in cm.output
-            )
+            assert f"[Try 1 of 3] Task Timeout Error for Task: ({key})." in caplog.text

             executor.heartbeat()
             assert dict(executor.task_publish_retries) == {key: 3}
             assert 1 == len(executor.queued_tasks), "Task should remain in queue"
             assert executor.event_buffer == {}
-            assert (
-                "INFO:airflow.executors.celery_executor.CeleryExecutor:"
-                f"[Try 2 of 3] Task Timeout Error for Task: ({key})." in cm.output
-            )
+            assert f"[Try 2 of 3] Task Timeout Error for Task: ({key})." in caplog.text

             executor.heartbeat()
             assert dict(executor.task_publish_retries) == {key: 4}
             assert 1 == len(executor.queued_tasks), "Task should remain in queue"
             assert executor.event_buffer == {}
-            assert (
-                "INFO:airflow.executors.celery_executor.CeleryExecutor:"
-                f"[Try 3 of 3] Task Timeout Error for Task: ({key})." in cm.output
-            )
+            assert f"[Try 3 of 3] Task Timeout Error for Task: ({key})." in caplog.text
             executor.heartbeat()
             assert dict(executor.task_publish_retries) == {}
@@ -411,6 +406,174 @@ class TestCeleryExecutor(unittest.TestCase):
         assert executor.running == {key_2}
         assert executor.adopted_task_timeouts == {key_2: queued_dttm_2 + executor.task_adoption_timeout}

+    @pytest.mark.backend("mysql", "postgres")
+    @pytest.mark.parametrize(
+        "state, queued_dttm, executor_id",
+        [
+            (State.SCHEDULED, timezone.utcnow() - timedelta(days=2), '231'),
+            (State.QUEUED, timezone.utcnow(), '231'),
+            (State.QUEUED, timezone.utcnow(), None),
+        ],
+    )
+    def test_stuck_queued_tasks_are_cleared(
+        self, state, queued_dttm, executor_id, session, dag_maker, create_dummy_dag, create_task_instance
+    ):
+        """Test that clear_stuck_queued_tasks works"""
+        ti = create_task_instance(state=State.QUEUED)
+        ti.queued_dttm = queued_dttm
+        ti.external_executor_id = executor_id
+        session.merge(ti)
+        session.flush()
+        executor = celery_executor.CeleryExecutor()
+        executor._clear_stuck_queued_tasks()
+        session.flush()
+        ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+        assert ti.state == state
+
+    @pytest.mark.backend("mysql", "postgres")
+    def test_task_in_queued_tasks_dict_are_not_cleared(
+        self, session, dag_maker, create_dummy_dag, create_task_instance
+    ):
+        """Test that clear_stuck_queued_tasks doesn't clear tasks in executor.queued_tasks"""
+        ti = create_task_instance(state=State.QUEUED)
+        ti.queued_dttm = timezone.utcnow() - timedelta(days=2)
+        ti.external_executor_id = '231'
+        session.merge(ti)
+        session.flush()
+        executor = celery_executor.CeleryExecutor()
+        executor.queued_tasks = {ti.key: AsyncResult("231")}
+        executor._clear_stuck_queued_tasks()
+        session.flush()
+        ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+        assert executor.queued_tasks == {ti.key: AsyncResult("231")}
+        assert ti.state == State.QUEUED
+
+    @pytest.mark.backend("mysql", "postgres")
+    def test_task_in_running_dict_are_not_cleared(
+        self, session, dag_maker, create_dummy_dag, create_task_instance
+    ):
+        """Test that clear_stuck_queued_tasks doesn't clear tasks in executor.running"""
+        ti = create_task_instance(state=State.QUEUED)
+        ti.queued_dttm = timezone.utcnow() - timedelta(days=2)
+        ti.external_executor_id = '231'
+        session.merge(ti)
+        session.flush()
+        executor = celery_executor.CeleryExecutor()
+        executor.running = {ti.key: AsyncResult("231")}
+        executor._clear_stuck_queued_tasks()
+        session.flush()
+        ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+        assert executor.running == {ti.key: AsyncResult("231")}
+        assert ti.state == State.QUEUED
+
+    @pytest.mark.backend("mysql", "postgres")
+    def test_only_database_result_backend_supports_clearing_queued_task(
+        self, session, dag_maker, create_dummy_dag, create_task_instance
+    ):
+        with _prepare_app():
+            mock_backend = BaseKeyValueStoreBackend(app=celery_executor.app)
+            with mock.patch('airflow.executors.celery_executor.Celery.backend', mock_backend):
+                ti = create_task_instance(state=State.QUEUED)
+                ti.queued_dttm = timezone.utcnow() - timedelta(days=2)
+                ti.external_executor_id = '231'
+                session.merge(ti)
+                session.flush()
+                executor = celery_executor.CeleryExecutor()
+                executor.tasks = {ti.key: AsyncResult("231")}
+                executor._clear_stuck_queued_tasks()
+                session.flush()
+                ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+                # Not cleared
+                assert ti.state == State.QUEUED
+                assert executor.tasks == {ti.key: AsyncResult("231")}
+
@mock.patch("celery.backends.database.DatabaseBackend.ResultSession") + @pytest.mark.backend("mysql", "postgres") + @freeze_time("2020-01-01") + @pytest.mark.parametrize( + "state", + [ + (State.SCHEDULED), + (State.QUEUED), + ], + ) + def test_the_check_interval_to_clear_stuck_queued_task_is_correct( + self, + mock_result_session, + state, + session, + dag_maker, + create_dummy_dag, + create_task_instance, + ): + with _prepare_app(): + mock_backend = DatabaseBackend(app=celery_executor.app, url="sqlite3://") + with mock.patch('airflow.executors.celery_executor.Celery.backend', mock_backend): + mock_session = mock_backend.ResultSession.return_value + mock_session.query.return_value.all.return_value = [ + mock.MagicMock(**{"to_dict.return_value": {"status": "SUCCESS", "task_id": "123"}}) + ] + if state == State.SCHEDULED: + last_check_time = time.time() - 302 # should clear ti state + else: + last_check_time = time.time() - 298 # should not clear ti state + + ti = create_task_instance(state=State.QUEUED) + ti.queued_dttm = timezone.utcnow() - timedelta(days=2) + ti.external_executor_id = '231' + session.merge(ti) + session.flush() + executor = celery_executor.CeleryExecutor() + executor.tasks = {ti.key: AsyncResult("231")} + executor.stuck_tasks_last_check_time = last_check_time + executor.sync() + session.flush() + ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one() + assert ti.state == state + + @mock.patch("celery.backends.database.DatabaseBackend.ResultSession") + @pytest.mark.backend("mysql", "postgres") + @freeze_time("2020-01-01") + @pytest.mark.parametrize( + "task_id, state", + [ + ('231', State.QUEUED), + ('111', State.SCHEDULED), + ], + ) + def test_the_check_interval_to_clear_stuck_queued_task_is_correct_for_db_query( + self, + mock_result_session, + task_id, + state, + session, + dag_maker, + create_dummy_dag, + create_task_instance, + ): + """Here we test that task are not cleared if found in celery database""" + result_obj = namedtuple('Result', ['status', 'task_id']) + with _prepare_app(): + mock_backend = DatabaseBackend(app=celery_executor.app, url="sqlite3://") + with mock.patch('airflow.executors.celery_executor.Celery.backend', mock_backend): + mock_session = mock_backend.ResultSession.return_value + mock_session.query.return_value.all.return_value = [result_obj("SUCCESS", task_id)] + + last_check_time = time.time() - 302 # should clear ti state + + ti = create_task_instance(state=State.QUEUED) + ti.queued_dttm = timezone.utcnow() - timedelta(days=2) + ti.external_executor_id = '231' + session.merge(ti) + session.flush() + executor = celery_executor.CeleryExecutor() + executor.tasks = {ti.key: AsyncResult("231")} + executor.stuck_tasks_last_check_time = last_check_time + executor.sync() + session.flush() + ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one() + assert ti.state == state + def test_operation_timeout_config(): assert celery_executor.OPERATION_TIMEOUT == 1
