This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit a59a567c1dd99464c725e06b2d2472b37bcb600d
Author: Ephraim Anierobi <splendidzig...@gmail.com>
AuthorDate: Tue Jul 20 18:48:35 2021 +0100

    Add Pytest fixture to create dag and dagrun and use it on local task job tests (#16889)

    This change adds pytest fixture to create dag and dagrun then use it
    on local task job tests

    Co-authored-by: Tzu-ping Chung <uranu...@gmail.com>
    (cherry picked from commit 7c0d8a2f83cc6db25bdddcf6cecb6fb56f05f02f)
---
 tests/conftest.py                 |  50 +++++++++++
 tests/jobs/test_local_task_job.py | 178 +++++++++++++-------------------------
 2 files changed, 111 insertions(+), 117 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 55e1593..f2c5345 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -425,3 +425,53 @@ def app():
     from airflow.www import app

     return app.create_app(testing=True)
+
+
+@pytest.fixture
+def dag_maker(request):
+    from airflow.models import DAG
+    from airflow.utils import timezone
+    from airflow.utils.state import State
+
+    DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+
+    class DagFactory:
+        def __enter__(self):
+            self.dag.__enter__()
+            return self.dag
+
+        def __exit__(self, type, value, traceback):
+            dag = self.dag
+            dag.__exit__(type, value, traceback)
+            if type is None:
+                dag.clear()
+                self.dag_run = dag.create_dagrun(
+                    run_id=self.kwargs.get("run_id", "test"),
+                    state=self.kwargs.get('state', State.RUNNING),
+                    execution_date=self.kwargs.get('execution_date', self.kwargs['start_date']),
+                    start_date=self.kwargs['start_date'],
+                )
+
+        def __call__(self, dag_id='test_dag', **kwargs):
+            self.kwargs = kwargs
+            if "start_date" not in kwargs:
+                if hasattr(request.module, 'DEFAULT_DATE'):
+                    kwargs['start_date'] = getattr(request.module, 'DEFAULT_DATE')
+                else:
+                    kwargs['start_date'] = DEFAULT_DATE
+            dagrun_fields_not_in_dag = [
+                'state',
+                'execution_date',
+                'run_type',
+                'queued_at',
+                "run_id",
+                "creating_job_id",
+                "external_trigger",
+                "last_scheduling_decision",
+                "dag_hash",
+            ]
+            kwargs = {k: v for k, v in kwargs.items() if k not in dagrun_fields_not_in_dag}
+            self.dag = DAG(dag_id, **kwargs)
+            return self
+
+    return DagFactory()
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index 11e9adf..d9f1398 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -20,7 +20,6 @@ import multiprocessing
 import os
 import signal
 import time
-import unittest
 import uuid
 from multiprocessing import Lock, Value
 from unittest import mock
@@ -57,21 +56,30 @@ DEFAULT_DATE = timezone.datetime(2016, 1, 1)
 TEST_DAG_FOLDER = os.environ['AIRFLOW__CORE__DAGS_FOLDER']


-class TestLocalTaskJob(unittest.TestCase):
-    def setUp(self):
-        db.clear_db_dags()
-        db.clear_db_jobs()
-        db.clear_db_runs()
-        db.clear_db_task_fail()
-        patcher = patch('airflow.jobs.base_job.sleep')
-        self.addCleanup(patcher.stop)
-        self.mock_base_job_sleep = patcher.start()
+@pytest.fixture
+def clear_db():
+    db.clear_db_dags()
+    db.clear_db_jobs()
+    db.clear_db_runs()
+    db.clear_db_task_fail()
+    yield
+
+
+@pytest.fixture(scope='class')
+def clear_db_class():
+    yield
+    db.clear_db_dags()
+    db.clear_db_jobs()
+    db.clear_db_runs()
+    db.clear_db_task_fail()
+

-    def tearDown(self) -> None:
-        db.clear_db_dags()
-        db.clear_db_jobs()
-        db.clear_db_runs()
-        db.clear_db_task_fail()
+@pytest.mark.usefixtures('clear_db_class', 'clear_db')
+class TestLocalTaskJob:
+    @pytest.fixture(autouse=True)
+    def set_instance_attrs(self):
+        with patch('airflow.jobs.base_job.sleep') as self.mock_base_job_sleep:
+            yield

     def validate_ti_states(self, dag_run, ti_state_mapping, error_message):
         for task_id, expected_state in ti_state_mapping.items():
@@ -79,23 +87,19 @@ class TestLocalTaskJob(unittest.TestCase):
             task_instance.refresh_from_db()
             assert task_instance.state == expected_state, error_message

-    def test_localtaskjob_essential_attr(self):
+    def test_localtaskjob_essential_attr(self, dag_maker):
         """
         Check whether essential attributes
         of LocalTaskJob can be assigned with
         proper values without intervention
         """
-        dag = DAG(
+        with dag_maker(
             'test_localtaskjob_essential_attr', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}
-        )
-
-        with dag:
+        ):
             op1 = DummyOperator(task_id='op1')

-        dag.clear()
-        dr = dag.create_dagrun(
-            run_id="test", state=State.SUCCESS, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE
-        )
+        dr = dag_maker.dag_run
+
         ti = dr.get_task_instance(task_id=op1.task_id)

         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
@@ -108,21 +112,12 @@ class TestLocalTaskJob(unittest.TestCase):
         check_result_2 = [getattr(job1, attr) is not None for attr in essential_attr]
         assert all(check_result_2)

-    def test_localtaskjob_heartbeat(self):
+    def test_localtaskjob_heartbeat(self, dag_maker):
         session = settings.Session()
-        dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
-        with dag:
+        with dag_maker('test_localtaskjob_heartbeat'):
             op1 = DummyOperator(task_id='op1')

-        dag.clear()
-        dr = dag.create_dagrun(
-            run_id="test",
-            state=State.SUCCESS,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
+        dr = dag_maker.dag_run
         ti = dr.get_task_instance(task_id=op1.task_id, session=session)
         ti.state = State.RUNNING
         ti.hostname = "blablabla"
@@ -150,22 +145,11 @@ class TestLocalTaskJob(unittest.TestCase):
             job1.heartbeat_callback()

     @mock.patch('airflow.jobs.local_task_job.psutil')
-    def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock):
+    def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock, dag_maker):
         session = settings.Session()
-        dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
-        with dag:
+        with dag_maker('test_localtaskjob_heartbeat'):
             op1 = DummyOperator(task_id='op1', run_as_user='myuser')
-
-        dag.clear()
-        dr = dag.create_dagrun(
-            run_id="test",
-            state=State.SUCCESS,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
-
+        dr = dag_maker.dag_run
         ti = dr.get_task_instance(task_id=op1.task_id, session=session)
         ti.state = State.RUNNING
         ti.pid = 2
@@ -248,7 +232,8 @@ class TestLocalTaskJob(unittest.TestCase):
         Test that task heartbeat will sleep when it fails fast
         """
         self.mock_base_job_sleep.side_effect = time.sleep
-
+        dag_id = 'test_heartbeat_failed_fast'
+        task_id = 'test_heartbeat_failed_fast_op'
         with create_session() as session:
             dagbag = DagBag(
                 dag_folder=TEST_DAG_FOLDER,
@@ -266,6 +251,7 @@ class TestLocalTaskJob(unittest.TestCase):
                 start_date=DEFAULT_DATE,
                 session=session,
             )
+
             ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
             ti.refresh_from_db()
             ti.state = State.RUNNING
@@ -331,6 +317,7 @@ class TestLocalTaskJob(unittest.TestCase):
         assert State.SUCCESS == ti.state

     def test_localtaskjob_double_trigger(self):
+
         dagbag = DagBag(
             dag_folder=TEST_DAG_FOLDER,
             include_examples=False,
@@ -348,6 +335,7 @@ class TestLocalTaskJob(unittest.TestCase):
                 start_date=DEFAULT_DATE,
                 session=session,
             )
+
             ti = dr.get_task_instance(task_id=task.task_id, session=session)
             ti.state = State.RUNNING
             ti.hostname = get_hostname()
@@ -418,7 +406,7 @@ class TestLocalTaskJob(unittest.TestCase):
         assert time_end - time_start < job1.heartrate
         session.close()

-    def test_mark_failure_on_failure_callback(self):
+    def test_mark_failure_on_failure_callback(self, dag_maker):
         """
         Test that ensures that mark_failure in the UI fails
         the task, and executes on_failure_callback
@@ -447,22 +435,12 @@ class TestLocalTaskJob(unittest.TestCase):
             with task_terminated_externally.get_lock():
                 task_terminated_externally.value = 0

-        with DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE) as dag:
+        with dag_maker("test_mark_failure", start_date=DEFAULT_DATE):
             task = PythonOperator(
                 task_id='test_state_succeeded1',
                 python_callable=task_function,
                 on_failure_callback=check_failure,
             )
-
-        dag.clear()
-        with create_session() as session:
-            dag.create_dagrun(
-                run_id="test",
-                state=State.RUNNING,
-                execution_date=DEFAULT_DATE,
-                start_date=DEFAULT_DATE,
-                session=session,
-            )

         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
@@ -479,7 +457,7 @@ class TestLocalTaskJob(unittest.TestCase):

     @patch('airflow.utils.process_utils.subprocess.check_call')
     @patch.object(StandardTaskRunner, 'return_code')
-    def test_failure_callback_only_called_once(self, mock_return_code, _check_call):
+    def test_failure_callback_only_called_once(self, mock_return_code, _check_call, dag_maker):
         """
         Test that ensures that when a task exits with failure by
         itself, failure callback is only called once
@@ -498,22 +476,11 @@ class TestLocalTaskJob(unittest.TestCase):
         def task_function(ti):
             raise AirflowFailException()

-        dag = DAG(dag_id='test_failure_callback_race', start_date=DEFAULT_DATE)
-        task = PythonOperator(
-            task_id='test_exit_on_failure',
-            python_callable=task_function,
-            on_failure_callback=failure_callback,
-            dag=dag,
-        )
-
-        dag.clear()
-        with create_session() as session:
-            dag.create_dagrun(
-                run_id="test",
-                state=State.RUNNING,
-                execution_date=DEFAULT_DATE,
-                start_date=DEFAULT_DATE,
-                session=session,
+        with dag_maker("test_failure_callback_race"):
+            task = PythonOperator(
+                task_id='test_exit_on_failure',
+                python_callable=task_function,
+                on_failure_callback=failure_callback,
             )
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
@@ -544,7 +511,7 @@ class TestLocalTaskJob(unittest.TestCase):
         assert failure_callback_called.value == 1

     @pytest.mark.quarantined
-    def test_mark_success_on_success_callback(self):
+    def test_mark_success_on_success_callback(self, dag_maker):
         """
         Test that ensures that where a task is marked success in the UI
         on_success_callback gets executed
@@ -560,8 +527,6 @@ class TestLocalTaskJob(unittest.TestCase):
             success_callback_called.value += 1
             assert context['dag_run'].dag_id == 'test_mark_success'

-        dag = DAG(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
         def task_function(ti):
             time.sleep(60)

@@ -569,23 +534,15 @@ class TestLocalTaskJob(unittest.TestCase):
             with shared_mem_lock:
                 task_terminated_externally.value = 0

-        task = PythonOperator(
-            task_id='test_state_succeeded1',
-            python_callable=task_function,
-            on_success_callback=success_callback,
-            dag=dag,
-        )
+        with dag_maker(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}):
+            task = PythonOperator(
+                task_id='test_state_succeeded1',
+                python_callable=task_function,
+                on_success_callback=success_callback,
+            )

         session = settings.Session()

-        dag.clear()
-        dag.create_dagrun(
-            run_id="test",
-            state=State.RUNNING,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
@@ -616,7 +573,7 @@ class TestLocalTaskJob(unittest.TestCase):
             (signal.SIGKILL,),
         ]
     )
-    def test_process_kill_calls_on_failure_callback(self, signal_type):
+    def test_process_kill_calls_on_failure_callback(self, signal_type, dag_maker):
         """
         Test that ensures that when a task is killed with sigterm or sigkill
         on_failure_callback gets executed
@@ -632,8 +589,6 @@ class TestLocalTaskJob(unittest.TestCase):
             failure_callback_called.value += 1
             assert context['dag_run'].dag_id == 'test_mark_failure'

-        dag = DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
         def task_function(ti):
             time.sleep(60)

@@ -641,23 +596,12 @@ class TestLocalTaskJob(unittest.TestCase):
             with shared_mem_lock:
                 task_terminated_externally.value = 0

-        task = PythonOperator(
-            task_id='test_on_failure',
-            python_callable=task_function,
-            on_failure_callback=failure_callback,
-            dag=dag,
-        )
-
-        session = settings.Session()
-
-        dag.clear()
-        dag.create_dagrun(
-            run_id="test",
-            state=State.RUNNING,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
+        with dag_maker(dag_id='test_mark_failure', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}):
+            task = PythonOperator(
+                task_id='test_on_failure',
+                python_callable=task_function,
+                on_failure_callback=failure_callback,
+            )
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
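
For reference, a minimal sketch (not part of the commit) of how a test might use the new dag_maker fixture. The module name, DAG id, and task id below are hypothetical; it assumes the fixture above is registered in tests/conftest.py and that the usual Airflow test database is initialised:

    # hypothetical test module, e.g. tests/jobs/test_dag_maker_usage.py
    from airflow.operators.dummy import DummyOperator
    from airflow.utils import timezone
    from airflow.utils.state import State

    # The fixture looks this constant up via request.module before falling
    # back to its own default of 2016-01-01.
    DEFAULT_DATE = timezone.datetime(2016, 1, 1)


    def test_dag_maker_creates_dag_and_dagrun(dag_maker):
        # Calling the factory configures the DAG; DagRun-only kwargs such as
        # 'state' are stripped before being passed to the DAG constructor.
        with dag_maker('example_dag', state=State.SUCCESS):
            task = DummyOperator(task_id='op1')

        # On exiting the context the factory clears the DAG and creates a
        # DagRun (run_id defaults to "test"), exposed as dag_maker.dag_run.
        dr = dag_maker.dag_run
        assert dr.run_id == 'test'
        assert dr.get_task_instance(task_id=task.task_id) is not None

The split between __call__ and __exit__ is what lets a single set of keyword arguments drive both the DAG constructor and the DagRun created afterwards, which is why the tests above no longer call dag.clear() and dag.create_dagrun() by hand.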