This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v2-1-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit a3879b2e1bc6c8b30443fd41f02ff580a3001966 Author: Ephraim Anierobi <splendidzig...@gmail.com> AuthorDate: Tue Jul 20 18:48:35 2021 +0100 Add Pytest fixture to create dag and dagrun and use it on local task job tests (#16889) This change adds pytest fixture to create dag and dagrun then use it on local task job tests Co-authored-by: Tzu-ping Chung <uranu...@gmail.com> (cherry picked from commit 7c0d8a2f83cc6db25bdddcf6cecb6fb56f05f02f) --- tests/conftest.py | 50 +++++++++ tests/jobs/test_local_task_job.py | 215 +++++++++++++++++--------------------- 2 files changed, 148 insertions(+), 117 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 55e1593..f2c5345 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -425,3 +425,53 @@ def app(): from airflow.www import app return app.create_app(testing=True) + + +@pytest.fixture +def dag_maker(request): + from airflow.models import DAG + from airflow.utils import timezone + from airflow.utils.state import State + + DEFAULT_DATE = timezone.datetime(2016, 1, 1) + + class DagFactory: + def __enter__(self): + self.dag.__enter__() + return self.dag + + def __exit__(self, type, value, traceback): + dag = self.dag + dag.__exit__(type, value, traceback) + if type is None: + dag.clear() + self.dag_run = dag.create_dagrun( + run_id=self.kwargs.get("run_id", "test"), + state=self.kwargs.get('state', State.RUNNING), + execution_date=self.kwargs.get('execution_date', self.kwargs['start_date']), + start_date=self.kwargs['start_date'], + ) + + def __call__(self, dag_id='test_dag', **kwargs): + self.kwargs = kwargs + if "start_date" not in kwargs: + if hasattr(request.module, 'DEFAULT_DATE'): + kwargs['start_date'] = getattr(request.module, 'DEFAULT_DATE') + else: + kwargs['start_date'] = DEFAULT_DATE + dagrun_fields_not_in_dag = [ + 'state', + 'execution_date', + 'run_type', + 'queued_at', + "run_id", + "creating_job_id", + "external_trigger", + "last_scheduling_decision", + "dag_hash", + ] + kwargs = {k: v for k, v in kwargs.items() if k not in dagrun_fields_not_in_dag} + self.dag = DAG(dag_id, **kwargs) + return self + + return DagFactory() diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index 9d80647..3475ef1 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -20,7 +20,6 @@ import multiprocessing import os import signal import time -import unittest import uuid from multiprocessing import Lock, Value from unittest import mock @@ -55,21 +54,30 @@ DEFAULT_DATE = timezone.datetime(2016, 1, 1) TEST_DAG_FOLDER = os.environ['AIRFLOW__CORE__DAGS_FOLDER'] -class TestLocalTaskJob(unittest.TestCase): - def setUp(self): - db.clear_db_dags() - db.clear_db_jobs() - db.clear_db_runs() - db.clear_db_task_fail() - patcher = patch('airflow.jobs.base_job.sleep') - self.addCleanup(patcher.stop) - self.mock_base_job_sleep = patcher.start() +@pytest.fixture +def clear_db(): + db.clear_db_dags() + db.clear_db_jobs() + db.clear_db_runs() + db.clear_db_task_fail() + yield + + +@pytest.fixture(scope='class') +def clear_db_class(): + yield + db.clear_db_dags() + db.clear_db_jobs() + db.clear_db_runs() + db.clear_db_task_fail() + - def tearDown(self) -> None: - db.clear_db_dags() - db.clear_db_jobs() - db.clear_db_runs() - db.clear_db_task_fail() +@pytest.mark.usefixtures('clear_db_class', 'clear_db') +class TestLocalTaskJob: + @pytest.fixture(autouse=True) + def set_instance_attrs(self): + with patch('airflow.jobs.base_job.sleep') as self.mock_base_job_sleep: + yield def validate_ti_states(self, dag_run, ti_state_mapping, error_message): for task_id, expected_state in ti_state_mapping.items(): @@ -77,23 +85,19 @@ class TestLocalTaskJob(unittest.TestCase): task_instance.refresh_from_db() assert task_instance.state == expected_state, error_message - def test_localtaskjob_essential_attr(self): + def test_localtaskjob_essential_attr(self, dag_maker): """ Check whether essential attributes of LocalTaskJob can be assigned with proper values without intervention """ - dag = DAG( + with dag_maker( 'test_localtaskjob_essential_attr', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'} - ) - - with dag: + ): op1 = DummyOperator(task_id='op1') - dag.clear() - dr = dag.create_dagrun( - run_id="test", state=State.SUCCESS, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE - ) + dr = dag_maker.dag_run + ti = dr.get_task_instance(task_id=op1.task_id) job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) @@ -106,21 +110,12 @@ class TestLocalTaskJob(unittest.TestCase): check_result_2 = [getattr(job1, attr) is not None for attr in essential_attr] assert all(check_result_2) - def test_localtaskjob_heartbeat(self): + def test_localtaskjob_heartbeat(self, dag_maker): session = settings.Session() - dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) - - with dag: + with dag_maker('test_localtaskjob_heartbeat'): op1 = DummyOperator(task_id='op1') - dag.clear() - dr = dag.create_dagrun( - run_id="test", - state=State.SUCCESS, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session, - ) + dr = dag_maker.dag_run ti = dr.get_task_instance(task_id=op1.task_id, session=session) ti.state = State.RUNNING ti.hostname = "blablabla" @@ -148,22 +143,11 @@ class TestLocalTaskJob(unittest.TestCase): job1.heartbeat_callback() @mock.patch('airflow.jobs.local_task_job.psutil') - def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock): + def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock, dag_maker): session = settings.Session() - dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) - - with dag: + with dag_maker('test_localtaskjob_heartbeat'): op1 = DummyOperator(task_id='op1', run_as_user='myuser') - - dag.clear() - dr = dag.create_dagrun( - run_id="test", - state=State.SUCCESS, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session, - ) - + dr = dag_maker.dag_run ti = dr.get_task_instance(task_id=op1.task_id, session=session) ti.state = State.RUNNING ti.pid = 2 @@ -246,7 +230,8 @@ class TestLocalTaskJob(unittest.TestCase): Test that task heartbeat will sleep when it fails fast """ self.mock_base_job_sleep.side_effect = time.sleep - + dag_id = 'test_heartbeat_failed_fast' + task_id = 'test_heartbeat_failed_fast_op' with create_session() as session: dagbag = DagBag( dag_folder=TEST_DAG_FOLDER, @@ -264,6 +249,7 @@ class TestLocalTaskJob(unittest.TestCase): start_date=DEFAULT_DATE, session=session, ) + ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() ti.state = State.RUNNING @@ -329,6 +315,7 @@ class TestLocalTaskJob(unittest.TestCase): assert State.SUCCESS == ti.state def test_localtaskjob_double_trigger(self): + dagbag = DagBag( dag_folder=TEST_DAG_FOLDER, include_examples=False, @@ -346,6 +333,7 @@ class TestLocalTaskJob(unittest.TestCase): start_date=DEFAULT_DATE, session=session, ) + ti = dr.get_task_instance(task_id=task.task_id, session=session) ti.state = State.RUNNING ti.hostname = get_hostname() @@ -416,7 +404,7 @@ class TestLocalTaskJob(unittest.TestCase): assert time_end - time_start < job1.heartrate session.close() - def test_mark_failure_on_failure_callback(self): + def test_mark_failure_on_failure_callback(self, dag_maker): """ Test that ensures that mark_failure in the UI fails the task, and executes on_failure_callback @@ -445,22 +433,12 @@ class TestLocalTaskJob(unittest.TestCase): with task_terminated_externally.get_lock(): task_terminated_externally.value = 0 - with DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE) as dag: + with dag_maker("test_mark_failure", start_date=DEFAULT_DATE): task = PythonOperator( task_id='test_state_succeeded1', python_callable=task_function, on_failure_callback=check_failure, ) - - dag.clear() - with create_session() as session: - dag.create_dagrun( - run_id="test", - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session, - ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() @@ -477,7 +455,7 @@ class TestLocalTaskJob(unittest.TestCase): @patch('airflow.utils.process_utils.subprocess.check_call') @patch.object(StandardTaskRunner, 'return_code') - def test_failure_callback_only_called_once(self, mock_return_code, _check_call): + def test_failure_callback_only_called_once(self, mock_return_code, _check_call, dag_maker): """ Test that ensures that when a task exits with failure by itself, failure callback is only called once @@ -496,22 +474,11 @@ class TestLocalTaskJob(unittest.TestCase): def task_function(ti): raise AirflowFailException() - dag = DAG(dag_id='test_failure_callback_race', start_date=DEFAULT_DATE) - task = PythonOperator( - task_id='test_exit_on_failure', - python_callable=task_function, - on_failure_callback=failure_callback, - dag=dag, - ) - - dag.clear() - with create_session() as session: - dag.create_dagrun( - run_id="test", - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session, + with dag_maker("test_failure_callback_race"): + task = PythonOperator( + task_id='test_exit_on_failure', + python_callable=task_function, + on_failure_callback=failure_callback, ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() @@ -542,7 +509,7 @@ class TestLocalTaskJob(unittest.TestCase): assert failure_callback_called.value == 1 @pytest.mark.quarantined - def test_mark_success_on_success_callback(self): + def test_mark_success_on_success_callback(self, dag_maker): """ Test that ensures that where a task is marked success in the UI on_success_callback gets executed @@ -558,8 +525,6 @@ class TestLocalTaskJob(unittest.TestCase): success_callback_called.value += 1 assert context['dag_run'].dag_id == 'test_mark_success' - dag = DAG(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) - def task_function(ti): time.sleep(60) @@ -567,23 +532,15 @@ class TestLocalTaskJob(unittest.TestCase): with shared_mem_lock: task_terminated_externally.value = 0 - task = PythonOperator( - task_id='test_state_succeeded1', - python_callable=task_function, - on_success_callback=success_callback, - dag=dag, - ) + with dag_maker(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}): + task = PythonOperator( + task_id='test_state_succeeded1', + python_callable=task_function, + on_success_callback=success_callback, + ) session = settings.Session() - dag.clear() - dag.create_dagrun( - run_id="test", - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session, - ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) @@ -614,7 +571,7 @@ class TestLocalTaskJob(unittest.TestCase): (signal.SIGKILL,), ] ) - def test_process_kill_calls_on_failure_callback(self, signal_type): + def test_process_kill_calls_on_failure_callback(self, signal_type, dag_maker): """ Test that ensures that when a task is killed with sigterm or sigkill on_failure_callback gets executed @@ -630,8 +587,6 @@ class TestLocalTaskJob(unittest.TestCase): failure_callback_called.value += 1 assert context['dag_run'].dag_id == 'test_mark_failure' - dag = DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) - def task_function(ti): time.sleep(60) @@ -639,23 +594,12 @@ class TestLocalTaskJob(unittest.TestCase): with shared_mem_lock: task_terminated_externally.value = 0 - task = PythonOperator( - task_id='test_on_failure', - python_callable=task_function, - on_failure_callback=failure_callback, - dag=dag, - ) - - session = settings.Session() - - dag.clear() - dag.create_dagrun( - run_id="test", - state=State.RUNNING, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session, - ) + with dag_maker(dag_id='test_mark_failure', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}): + task = PythonOperator( + task_id='test_on_failure', + python_callable=task_function, + on_failure_callback=failure_callback, + ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) @@ -780,6 +724,43 @@ class TestLocalTaskJob(unittest.TestCase): if scheduler_job.processor_agent: scheduler_job.processor_agent.end() + def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self): + """Test that with DAG paused, DagRun state will update when the tasks finishes the run""" + dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE) + op1 = PythonOperator(task_id='dummy', dag=dag, owner='airflow', python_callable=lambda: True) + + session = settings.Session() + orm_dag = DagModel( + dag_id=dag.dag_id, + has_task_concurrency_limits=False, + next_dagrun=dag.start_date, + next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE), + is_active=True, + is_paused=True, + ) + session.add(orm_dag) + session.flush() + # Write Dag to DB + dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False) + dagbag.bag_dag(dag, root_dag=dag) + dagbag.sync_to_db() + + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) + assert dr.state == State.RUNNING + ti = TaskInstance(op1, dr.execution_date) + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) + job1.task_runner = StandardTaskRunner(job1) + job1.run() + session.add(dr) + session.refresh(dr) + assert dr.state == State.SUCCESS + @pytest.fixture() def clean_db_helper():