This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 02cdb2cd153ca0e89f4bc2062d38a1af3bc1d2c5
Author: Daniel Standish <15932138+dstand...@users.noreply.github.com>
AuthorDate: Thu Jan 6 15:16:02 2022 -0800

    Fix duplicate trigger creation race condition (#20699)

    The process for queueing up a trigger, for execution by the TriggerRunner,
    is handled by the TriggererJob's `load_triggers` method. It fetches the
    triggers that should be running according to the database, checks whether
    they are running, and if not, adds them to `TriggerRunner.to_create`.

    The problem is that there's a small window of time between the moment a
    trigger (upon termination) is purged from the `TriggerRunner.triggers` set,
    and the time that the database is updated to reflect t [...]

    To resolve this, before adding a trigger to the `to_create` queue, instead
    of comparing against the "running" triggers we now compare against all
    triggers known to the TriggerRunner instance. When triggers move out of the
    `triggers` set they move into other data structures such as `events`,
    `failed_triggers`, and `to_cancel`. So we union all of these and only create
    those triggers which the database indicates should exist _and_ which are not
    already being handled [...]

    (cherry picked from commit 16b8c476518ed76e3689966ec4b0b788be935410)
---
 airflow/jobs/triggerer_job.py    |  12 +++-
 tests/jobs/test_triggerer_job.py | 136 ++++++++++++++++++++++++++++++++++++++-
 2 files changed, 142 insertions(+), 6 deletions(-)

diff --git a/airflow/jobs/triggerer_job.py b/airflow/jobs/triggerer_job.py
index 25a4c79..dff0e0f 100644
--- a/airflow/jobs/triggerer_job.py
+++ b/airflow/jobs/triggerer_job.py
@@ -381,10 +381,16 @@ class TriggerRunner(threading.Thread, LoggingMixin):
         # line's execution, but we consider that safe, since there's a strict
         # add -> remove -> never again lifecycle this function is already
         # handling.
-        current_trigger_ids = set(self.triggers.keys())
+        running_trigger_ids = set(self.triggers.keys())
+        known_trigger_ids = (
+            running_trigger_ids.union(x[0] for x in self.events)
+            .union(self.to_cancel)
+            .union(x[0] for x in self.to_create)
+            .union(self.failed_triggers)
+        )
         # Work out the two difference sets
-        new_trigger_ids = requested_trigger_ids.difference(current_trigger_ids)
-        cancel_trigger_ids = current_trigger_ids.difference(requested_trigger_ids)
+        new_trigger_ids = requested_trigger_ids - known_trigger_ids
+        cancel_trigger_ids = running_trigger_ids - requested_trigger_ids
         # Bulk-fetch new trigger records
         new_triggers = Trigger.bulk_fetch(new_trigger_ids)
         # Add in new triggers
diff --git a/tests/jobs/test_triggerer_job.py b/tests/jobs/test_triggerer_job.py
index 5adc91f..870116a 100644
--- a/tests/jobs/test_triggerer_job.py
+++ b/tests/jobs/test_triggerer_job.py
@@ -16,29 +16,53 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import asyncio
 import datetime
 import sys
 import time
+from threading import Thread
 
 import pytest
 
-from airflow.jobs.triggerer_job import TriggererJob
-from airflow.models import Trigger
+from airflow.jobs.triggerer_job import TriggererJob, TriggerRunner
+from airflow.models import DagModel, DagRun, TaskInstance, Trigger
 from airflow.operators.dummy import DummyOperator
+from airflow.operators.python import PythonOperator
 from airflow.triggers.base import TriggerEvent
 from airflow.triggers.temporal import TimeDeltaTrigger
 from airflow.triggers.testing import FailureTrigger, SuccessTrigger
 from airflow.utils import timezone
 from airflow.utils.session import create_session
 from airflow.utils.state import State, TaskInstanceState
-from tests.test_utils.db import clear_db_runs
+from tests.test_utils.db import clear_db_dags, clear_db_runs
+
+
+class TimeDeltaTrigger_(TimeDeltaTrigger):
+    def __init__(self, delta, filename):
+        super().__init__(delta=delta)
+        self.filename = filename
+        self.delta = delta
+
+    async def run(self):
+        with open(self.filename, 'at') as f:
+            f.write('hi\n')
+        async for event in super().run():
+            yield event
+
+    def serialize(self):
+        return (
+            "tests.jobs.test_triggerer_job.TimeDeltaTrigger_",
+            {"delta": self.delta, "filename": self.filename},
+        )
 
 
 @pytest.fixture(autouse=True)
 def clean_database():
     """Fixture that cleans the database before and after every test."""
     clear_db_runs()
+    clear_db_dags()
     yield  # Test runs here
+    clear_db_dags()
     clear_db_runs()
 
 
@@ -160,6 +184,112 @@ def test_trigger_lifecycle(session):
 
 
 @pytest.mark.skipif(sys.version_info.minor <= 6 and sys.version_info.major <= 3, reason="No triggerer on 3.6")
+def test_trigger_create_race_condition_18392(session, tmp_path):
+    """
+    This verifies the resolution of the race condition documented in GitHub issue #18392.
+    Triggers are queued for creation by TriggererJob.load_triggers.
+    There was a race condition where multiple triggers would be created unnecessarily.
+    What happens is that the runner completes the trigger and purges it from the "running" list.
+    Then job.load_triggers is called; the trigger looks like it is not running but should be,
+    so it queues the trigger again.
+
+    The scenario is as follows:
+        1. job.load_triggers (trigger now queued)
+        2. runner.create_triggers (trigger now running)
+        3. job.handle_events (trigger still appears running so state not updated in DB)
+        4. runner.cleanup_finished_triggers (trigger completed at this point; removed from "running" set)
+        5. job.load_triggers (trigger not running, but also not purged from DB, so it is queued again)
+        6. runner.create_triggers (trigger created again)
+
+    This test verifies that under this scenario only one trigger is created.
+    """
+    path = tmp_path / 'test_trigger_bad_respawn.txt'
+
+    class TriggerRunner_(TriggerRunner):
+        """We do some waiting for main thread looping"""
+
+        async def wait_for_job_method_count(self, method, count):
+            for _ in range(30):
+                await asyncio.sleep(0.1)
+                if getattr(self, f'{method}_count', 0) >= count:
+                    break
+            else:
+                pytest.fail(f"did not observe count {count} in job method {method}")
+
+        async def create_triggers(self):
+            """
+            On first run, wait for job.load_triggers to make sure they are queued
+            """
+            if getattr(self, 'loop_count', 0) == 0:
+                await self.wait_for_job_method_count('load_triggers', 1)
+            await super().create_triggers()
+            self.loop_count = getattr(self, 'loop_count', 0) + 1
+
+        async def cleanup_finished_triggers(self):
+            """On loop 1, make sure that job.handle_events was already called"""
+            if self.loop_count == 1:
+                await self.wait_for_job_method_count('handle_events', 1)
+            await super().cleanup_finished_triggers()
+
+    class TriggererJob_(TriggererJob):
+        """We do some waiting for runner thread looping (and track calls in job thread)"""
+
+        def wait_for_runner_loop(self, runner_loop_count):
+            for _ in range(30):
+                time.sleep(0.1)
+                if getattr(self.runner, 'call_count', 0) >= runner_loop_count:
+                    break
+            else:
+                pytest.fail("did not observe 2 loops in the runner thread")
+
+        def load_triggers(self):
+            """On second run, make sure that runner has called create_triggers in its second loop"""
+            super().load_triggers()
+            self.runner.load_triggers_count = getattr(self.runner, 'load_triggers_count', 0) + 1
+            if self.runner.load_triggers_count == 2:
+                self.wait_for_runner_loop(runner_loop_count=2)
+
+        def handle_events(self):
+            super().handle_events()
+            self.runner.handle_events_count = getattr(self.runner, 'handle_events_count', 0) + 1
+
+    trigger = TimeDeltaTrigger_(delta=datetime.timedelta(microseconds=1), filename=path.as_posix())
+    trigger_orm = Trigger.from_object(trigger)
+    trigger_orm.id = 1
+    session.add(trigger_orm)
+
+    dag = DagModel(dag_id='test-dag')
+    dag_run = DagRun(dag.dag_id, run_id='abc', run_type='none')
+    ti = TaskInstance(PythonOperator(task_id='dummy-task', python_callable=print), run_id=dag_run.run_id)
+    ti.dag_id = dag.dag_id
+    ti.trigger_id = 1
+    session.add(dag)
+    session.add(dag_run)
+    session.add(ti)
+
+    session.commit()
+
+    job = TriggererJob_()
+    job.runner = TriggerRunner_()
+    thread = Thread(target=job._execute)
+    thread.start()
+    try:
+        for _ in range(40):
+            time.sleep(0.1)
+            # ready to evaluate after 2 loops
+            if getattr(job.runner, 'loop_count', 0) >= 2:
+                break
+        else:
+            pytest.fail("did not observe 2 loops in the runner thread")
+    finally:
+        job.runner.stop = True
+        job.runner.join()
+        thread.join()
+    instances = path.read_text().splitlines()
+    assert len(instances) == 1
+
+
+@pytest.mark.skipif(sys.version_info.minor <= 6 and sys.version_info.major <= 3, reason="No triggerer on 3.6")
 def test_trigger_from_dead_triggerer(session):
     """
     Checks that the triggerer will correctly claim a Trigger that is assigned to a