This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v2-1-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 6ab83b5ae83691b0ef5a8fc247a68b51c5f27a05 Author: Ephraim Anierobi <splendidzig...@gmail.com> AuthorDate: Wed Jul 28 15:57:35 2021 +0100 Fix task retries when they receive sigkill and have retries and properly handle sigterm (#16301) Currently, tasks are not retried when they receive SIGKILL or SIGTERM, even if the task has retries configured. This change fixes that and adds tests for both SIGTERM and SIGKILL so we don't experience a regression. Also, SIGTERM sets the task as failed and raises AirflowException, which the heartbeat sometimes sees as the task being externally set to failed, so it does not call failure_callbacks. This commit also fixes this by calling handle_task_exit when a task gets SIGTERM. Co-authored-by: Ash Berlin-Taylor <ash_git...@firemirror.com> (cherry picked from commit 4e2a94c6d1bde5ddf2aa0251190c318ac22f3b17) --- airflow/jobs/local_task_job.py | 24 +++--- tests/jobs/test_local_task_job.py | 166 +++++++++++++++++++++++++++++++++----- tests/models/test_taskinstance.py | 32 ++++++++ 3 files changed, 189 insertions(+), 33 deletions(-) diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index 9249d21..b650ab4 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -78,12 +78,9 @@ class LocalTaskJob(BaseJob): def signal_handler(signum, frame): """Setting kill signal handler""" self.log.error("Received SIGTERM.
Terminating subprocesses") - self.on_kill() - self.task_instance.refresh_from_db() - if self.task_instance.state not in State.finished: - self.task_instance.set_state(State.FAILED) - self.task_instance._run_finished_callback(error="task received sigterm") - raise AirflowException("LocalTaskJob received SIGTERM signal") + self.task_runner.terminate() + self.handle_task_exit(128 + signum) + return signal.signal(signal.SIGTERM, signal_handler) @@ -148,16 +145,19 @@ class LocalTaskJob(BaseJob): self.on_kill() def handle_task_exit(self, return_code: int) -> None: - """Handle case where self.task_runner exits by itself""" + """Handle case where self.task_runner exits by itself or is externally killed""" + # Without setting this, heartbeat may get us + self.terminating = True self.log.info("Task exited with return code %s", return_code) self.task_instance.refresh_from_db() - # task exited by itself, so we need to check for error file + + if self.task_instance.state == State.RUNNING: + # This is for a case where the task received a SIGKILL + # while running or the task runner received a sigterm + self.task_instance.handle_failure(error=None) + # We need to check for error file # in case it failed due to runtime exception/error error = None - if self.task_instance.state == State.RUNNING: - # This is for a case where the task received a sigkill - # while running - self.task_instance.set_state(State.FAILED) if self.task_instance.state != State.SUCCESS: error = self.task_runner.deserialize_run_error() self.task_instance._run_finished_callback(error=error) # pylint: disable=protected-access diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index d9f1398..94f894d 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -21,6 +21,7 @@ import os import signal import time import uuid +from datetime import timedelta from multiprocessing import Lock, Value from unittest import mock from unittest.mock import patch @@ 
-272,7 +273,6 @@ class TestLocalTaskJob: delta = (time2 - time1).total_seconds() assert abs(delta - job.heartrate) < 0.5 - @pytest.mark.quarantined def test_mark_success_no_kill(self): """ Test that ensures that mark_success in the UI doesn't cause @@ -300,7 +300,6 @@ class TestLocalTaskJob: job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True) process = multiprocessing.Process(target=job1.run) process.start() - ti.refresh_from_db() for _ in range(0, 50): if ti.state == State.RUNNING: break @@ -510,7 +509,6 @@ class TestLocalTaskJob: assert ti.state == State.FAILED # task exits with failure state assert failure_callback_called.value == 1 - @pytest.mark.quarantined def test_mark_success_on_success_callback(self, dag_maker): """ Test that ensures that where a task is marked success in the UI @@ -567,15 +565,9 @@ class TestLocalTaskJob: assert task_terminated_externally.value == 1 assert not process.is_alive() - @parameterized.expand( - [ - (signal.SIGTERM,), - (signal.SIGKILL,), - ] - ) - def test_process_kill_calls_on_failure_callback(self, signal_type, dag_maker): + def test_task_sigkill_calls_on_failure_callback(self, dag_maker): """ - Test that ensures that when a task is killed with sigterm or sigkill + Test that ensures that when a task is killed with sigkill on_failure_callback gets executed """ # use shared memory value so we can properly track value change even if @@ -587,10 +579,50 @@ class TestLocalTaskJob: def failure_callback(context): with shared_mem_lock: failure_callback_called.value += 1 - assert context['dag_run'].dag_id == 'test_mark_failure' + assert context['dag_run'].dag_id == 'test_send_sigkill' def task_function(ti): + os.kill(os.getpid(), signal.SIGKILL) + # This should not happen -- the state change should be noticed and the task should get killed + with shared_mem_lock: + task_terminated_externally.value = 0 + + with dag_maker(dag_id='test_send_sigkill'): + task = PythonOperator( + task_id='test_on_failure', + 
python_callable=task_function, + on_failure_callback=failure_callback, + ) + + ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) + ti.refresh_from_db() + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) + settings.engine.dispose() + process = multiprocessing.Process(target=job1.run) + process.start() + time.sleep(0.3) + process.join(timeout=10) + assert failure_callback_called.value == 1 + assert task_terminated_externally.value == 1 + assert not process.is_alive() + + def test_process_sigterm_calls_on_failure_callback(self, dag_maker): + """ + Test that ensures that when a task runner is killed with sigterm + on_failure_callback gets executed + """ + # use shared memory value so we can properly track value change even if + # it's been updated across processes. + failure_callback_called = Value('i', 0) + task_terminated_externally = Value('i', 1) + shared_mem_lock = Lock() + def failure_callback(context): + with shared_mem_lock: + failure_callback_called.value += 1 + assert context['dag_run'].dag_id == 'test_mark_failure' + + def task_function(ti): time.sleep(60) # This should not happen -- the state change should be noticed and the task should get killed with shared_mem_lock: @@ -605,20 +637,16 @@ class TestLocalTaskJob: ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) - job1.task_runner = StandardTaskRunner(job1) - settings.engine.dispose() process = multiprocessing.Process(target=job1.run) process.start() - - for _ in range(0, 20): + for _ in range(0, 25): ti.refresh_from_db() - if ti.state == State.RUNNING and ti.pid is not None: + if ti.state == State.RUNNING: break time.sleep(0.2) - assert ti.pid is not None - assert ti.state == State.RUNNING - os.kill(ti.pid, signal_type) + os.kill(process.pid, signal.SIGTERM) + ti.refresh_from_db() process.join(timeout=10) assert 
failure_callback_called.value == 1 assert task_terminated_externally.value == 1 @@ -726,6 +754,102 @@ class TestLocalTaskJob: if scheduler_job.processor_agent: scheduler_job.processor_agent.end() + def test_task_sigkill_works_with_retries(self, dag_maker): + """ + Test that ensures that tasks are retried when they receive sigkill + """ + # use shared memory value so we can properly track value change even if + # it's been updated across processes. + retry_callback_called = Value('i', 0) + task_terminated_externally = Value('i', 1) + shared_mem_lock = Lock() + + def retry_callback(context): + with shared_mem_lock: + retry_callback_called.value += 1 + assert context['dag_run'].dag_id == 'test_mark_failure_2' + + def task_function(ti): + os.kill(os.getpid(), signal.SIGKILL) + # This should not happen -- the state change should be noticed and the task should get killed + with shared_mem_lock: + task_terminated_externally.value = 0 + + with dag_maker( + dag_id='test_mark_failure_2', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'} + ): + task = PythonOperator( + task_id='test_on_failure', + python_callable=task_function, + retries=1, + retry_delay=timedelta(seconds=2), + on_retry_callback=retry_callback, + ) + ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) + ti.refresh_from_db() + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) + job1.task_runner = StandardTaskRunner(job1) + job1.task_runner.start() + settings.engine.dispose() + process = multiprocessing.Process(target=job1.run) + process.start() + time.sleep(0.4) + process.join() + ti.refresh_from_db() + assert ti.state == State.UP_FOR_RETRY + assert retry_callback_called.value == 1 + assert task_terminated_externally.value == 1 + + def test_process_sigterm_works_with_retries(self, dag_maker): + """ + Test that ensures that task runner sets tasks to retry when they(task runner) + receive sigterm + """ + # use shared memory value so we can properly track 
value change even if + # it's been updated across processes. + retry_callback_called = Value('i', 0) + task_terminated_externally = Value('i', 1) + shared_mem_lock = Lock() + + def retry_callback(context): + with shared_mem_lock: + retry_callback_called.value += 1 + assert context['dag_run'].dag_id == 'test_mark_failure_2' + + def task_function(ti): + time.sleep(60) + # This should not happen -- the state change should be noticed and the task should get killed + with shared_mem_lock: + task_terminated_externally.value = 0 + + with dag_maker(dag_id='test_mark_failure_2'): + task = PythonOperator( + task_id='test_on_failure', + python_callable=task_function, + retries=1, + retry_delay=timedelta(seconds=2), + on_retry_callback=retry_callback, + ) + ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) + ti.refresh_from_db() + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) + job1.task_runner = StandardTaskRunner(job1) + job1.task_runner.start() + settings.engine.dispose() + process = multiprocessing.Process(target=job1.run) + process.start() + for _ in range(0, 25): + ti.refresh_from_db() + if ti.state == State.RUNNING and ti.pid is not None: + break + time.sleep(0.2) + os.kill(process.pid, signal.SIGTERM) + process.join() + ti.refresh_from_db() + assert ti.state == State.UP_FOR_RETRY + assert retry_callback_called.value == 1 + assert task_terminated_externally.value == 1 + def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self): """Test that with DAG paused, DagRun state will update when the tasks finishes the run""" dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE) @@ -788,5 +912,5 @@ class TestLocalTaskJobPerformance: mock_get_task_runner.return_value.return_code.side_effects = return_codes job = LocalTaskJob(task_instance=ti, executor=MockExecutor()) - with assert_queries_count(16): + with assert_queries_count(18): job.run() diff --git a/tests/models/test_taskinstance.py 
b/tests/models/test_taskinstance.py index c1882e1..db23271 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -18,6 +18,7 @@ import datetime import os +import signal import time import unittest import urllib @@ -522,6 +523,37 @@ class TestTaskInstance(unittest.TestCase): ti.run() assert State.SKIPPED == ti.state + def test_task_sigterm_works_with_retries(self): + """ + Test that ensures that tasks are retried when they receive sigterm + """ + dag = DAG(dag_id='test_mark_failure_2', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) + + def task_function(ti): + # pylint: disable=unused-argument + os.kill(ti.pid, signal.SIGTERM) + + task = PythonOperator( + task_id='test_on_failure', + python_callable=task_function, + retries=1, + retry_delay=datetime.timedelta(seconds=2), + dag=dag, + ) + + dag.create_dagrun( + run_id="test", + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + ) + ti = TI(task=task, execution_date=DEFAULT_DATE) + ti.refresh_from_db() + with self.assertRaises(AirflowException): + ti.run() + ti.refresh_from_db() + assert ti.state == State.UP_FOR_RETRY + def test_retry_delay(self): """ Test that retry delays are respected