uranusjr commented on code in PR #38666:
URL: https://github.com/apache/airflow/pull/38666#discussion_r1547052295


##########
airflow/jobs/triggerer_job_runner.py:
##########
@@ -672,6 +672,17 @@ def update_triggers(self, requested_trigger_ids: set[int]):
                 self.failed_triggers.append((new_id, e))
                 continue
 
+            # If new_trigger_orm.task_instance is None, this means the 
TaskInstance
+            # row was updated by either Trigger.submit_event or 
Trigger.submit_failure
+            # and can happen when a single trigger Job is being run on 
multiple TriggerRunners
+            # in a High-Availability setup.
+            if new_trigger_orm.task_instance is None:
+                self.log.warning((
+                    "TaskInstance for Trigger ID %s is None. It was likely 
updated by another trigger job. "
+                    "Skipping trigger instantiation."
+                ), new_id)

Review Comment:
   Does this warrant a warning? There’s nothing the user can do about it, and judging from the original issue, this is a fairly normal occurrence. A `debug` (or at most `info`) log seems more appropriate.



##########
tests/jobs/test_triggerer_job.py:
##########
@@ -309,6 +309,140 @@ def test_update_trigger_with_triggerer_argument_change(
         assert "got an unexpected keyword argument 'not_exists_arg'" in 
caplog.text
 
 
+def test_trigger_create_race_condition_38599(session, tmp_path):
+    """
+    This verifies the resolution of race condition documented in github issue 
#38599.
+    More details in the issue description.
+
+    The race condition may occur in the following scenario:
+        1. TaskInstance TI1 defers itself, which creates Trigger T1, which 
holds a
+            reference to TI1.
+        2. T1 gets picked up by TriggererJobRunner TJR1 and starts running T1.
+        3. TJR1 misses a heartbeat, most likely due to high host load causing 
delays in
+            each TriggererJobRunner._run_trigger_loop loop.
+        4. A second TriggererJobRunner TJR2 notices that T1 has missed its 
heartbeat,
+            so it starts the process of picking up any Triggers that TJR1 may 
have had,
+            including T1.
+        5. Before TJR2 starts executing T1, TJR1 finishes execution of T1 and 
cleans it
+            up by clearing the trigger_id of TI1.
+        6. TJR2 tries to execute T1, but it crashes (with the above error) 
while trying to
+            look up TI1 (because T1 no longer has a TaskInstance linked to it).
+    """
+    path = tmp_path / "test_trigger_create_after_completion.txt"
+    trigger = TimeDeltaTrigger_(delta=datetime.timedelta(microseconds=1), 
filename=path.as_posix())
+    trigger_orm = Trigger.from_object(trigger)
+    trigger_orm.id = 1
+    session.add(trigger_orm)
+
+    dag = DagModel(dag_id="test-dag")
+    dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none")
+    ti = TaskInstance(
+        PythonOperator(task_id="dummy-task", python_callable=print),
+        run_id=dag_run.run_id,
+        state=TaskInstanceState.DEFERRED,
+    )
+    ti.dag_id = dag.dag_id
+    ti.trigger_id = 1
+    session.add(dag)
+    session.add(dag_run)
+    session.add(ti)
+
+    job1 = Job()
+    job2 = Job()
+    session.add(job1)
+    session.add(job2)
+
+    session.commit()
+
+    class TriggerRunnerWithCreateCount_(TriggerRunner):
+        async def create_triggers(self):
+            num_triggers_to_create = len(self.to_create)
+            await super().create_triggers()
+            self.trigger_creation_count = getattr(self, 
"trigger_creation_count", 0) + num_triggers_to_create
+
+    class TriggerRunnerWithUpdateDelay_(TriggerRunnerWithCreateCount_):
+        """TriggerRunner with a 5 second delay added at the beginning of 
update_triggers
+        to increase the window that the race condition may occur.
+        """
+
+        def update_triggers(self, *args, **kwargs):
+            # Delay calling update_triggers to increase the window of 
opportunity
+            time.sleep(5)
+            super().update_triggers(*args, **kwargs)
+
+        async def create_triggers(self):
+            await super().create_triggers()
+            self.create_triggers_count = getattr(self, 
"create_triggers_count", 0) + 1
+
+    class TriggererJobRunner_(TriggererJobRunner):
+        """TriggererJobRunner whose handle_events blocks until there is an 
event."""
+
+        def load_triggers(self):
+            super().load_triggers()
+            self.load_triggers_count = getattr(self, "load_triggers_count", 0) 
+ 1
+
+        def handle_events(self):
+            # Wait for event during the first loop
+            while not self.trigger_runner.events and getattr(self, 
"handle_events_count", 0) == 0:
+                time.sleep(0.1)
+            super().handle_events()
+            self.handle_events_count = getattr(self, "handle_events_count", 0) 
+ 1
+            # Prevent Trigger.clean_unused() from deleting the trigger
+            time.sleep(5)

Review Comment:
   Can this sleep be shorter? Taking five seconds to run this test is much too long. Ideally, all the 0.1-second sleeps should be shorter as well.



##########
tests/jobs/test_triggerer_job.py:
##########
@@ -309,6 +309,140 @@ def test_update_trigger_with_triggerer_argument_change(
         assert "got an unexpected keyword argument 'not_exists_arg'" in 
caplog.text
 
 
+def test_trigger_create_race_condition_38599(session, tmp_path):
+    """
+    This verifies the resolution of race condition documented in github issue 
#38599.
+    More details in the issue description.
+
+    The race condition may occur in the following scenario:
+        1. TaskInstance TI1 defers itself, which creates Trigger T1, which 
holds a
+            reference to TI1.
+        2. T1 gets picked up by TriggererJobRunner TJR1 and starts running T1.
+        3. TJR1 misses a heartbeat, most likely due to high host load causing 
delays in
+            each TriggererJobRunner._run_trigger_loop loop.
+        4. A second TriggererJobRunner TJR2 notices that T1 has missed its 
heartbeat,
+            so it starts the process of picking up any Triggers that TJR1 may 
have had,
+            including T1.
+        5. Before TJR2 starts executing T1, TJR1 finishes execution of T1 and 
cleans it
+            up by clearing the trigger_id of TI1.
+        6. TJR2 tries to execute T1, but it crashes (with the above error) 
while trying to
+            look up TI1 (because T1 no longer has a TaskInstance linked to it).
+    """
+    path = tmp_path / "test_trigger_create_after_completion.txt"
+    trigger = TimeDeltaTrigger_(delta=datetime.timedelta(microseconds=1), 
filename=path.as_posix())
+    trigger_orm = Trigger.from_object(trigger)
+    trigger_orm.id = 1
+    session.add(trigger_orm)
+
+    dag = DagModel(dag_id="test-dag")
+    dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none")
+    ti = TaskInstance(
+        PythonOperator(task_id="dummy-task", python_callable=print),
+        run_id=dag_run.run_id,
+        state=TaskInstanceState.DEFERRED,
+    )
+    ti.dag_id = dag.dag_id
+    ti.trigger_id = 1
+    session.add(dag)
+    session.add(dag_run)
+    session.add(ti)
+
+    job1 = Job()
+    job2 = Job()
+    session.add(job1)
+    session.add(job2)
+
+    session.commit()
+
+    class TriggerRunnerWithCreateCount_(TriggerRunner):
+        async def create_triggers(self):
+            num_triggers_to_create = len(self.to_create)
+            await super().create_triggers()
+            self.trigger_creation_count = getattr(self, 
"trigger_creation_count", 0) + num_triggers_to_create
+
+    class TriggerRunnerWithUpdateDelay_(TriggerRunnerWithCreateCount_):
+        """TriggerRunner with a 5 second delay added at the beginning of 
update_triggers
+        to increase the window that the race condition may occur.
+        """
+
+        def update_triggers(self, *args, **kwargs):
+            # Delay calling update_triggers to increase the window of 
opportunity
+            time.sleep(5)
+            super().update_triggers(*args, **kwargs)
+
+        async def create_triggers(self):
+            await super().create_triggers()
+            self.create_triggers_count = getattr(self, 
"create_triggers_count", 0) + 1
+
+    class TriggererJobRunner_(TriggererJobRunner):
+        """TriggererJobRunner whose handle_events blocks until there is an 
event."""
+
+        def load_triggers(self):
+            super().load_triggers()
+            self.load_triggers_count = getattr(self, "load_triggers_count", 0) 
+ 1
+
+        def handle_events(self):
+            # Wait for event during the first loop
+            while not self.trigger_runner.events and getattr(self, 
"handle_events_count", 0) == 0:
+                time.sleep(0.1)
+            super().handle_events()
+            self.handle_events_count = getattr(self, "handle_events_count", 0) 
+ 1
+            # Prevent Trigger.clean_unused() from deleting the trigger
+            time.sleep(5)
+
+    # Start first TriggererJobRunner immediately.
+    # This TriggererJobRunner will immediately load the trigger and start 
running it.
+    # Once the trigger is finished, it will, however, stall after the first 
loop in handle_events,
+    # preventing the trigger from being cleaned up. This simulates what may 
happen during high load.
+    job_runner1 = TriggererJobRunner_(job1)
+    job_runner1.trigger_runner = TriggerRunnerWithCreateCount_()
+    thread1 = Thread(target=job_runner1._execute)
+    thread1.start()
+
+    # Simulate a missed heartbeat by job_runner1 by setting it to an hour ago
+    # This enables the second TriggererJobRunner to pick up the trigger.
+    for _ in range(20):
+        time.sleep(0.1)
+        if getattr(job_runner1, "load_triggers_count", 0) >= 1:
+            job1.latest_heartbeat = timezone.utcnow() - 
datetime.timedelta(hours=1)
+            session.commit()
+            break
+
+    # Start second TriggererJobRunner.
+    # This TriggererJobRunner will pick up the trigger and try to run it,
+    # but the job_runner1.handle_events already unlinked the trigger from the 
task instance,
+    # so trigger.task_instance is None.
+    job_runner2 = TriggererJobRunner(job2)
+    job_runner2.trigger_runner = TriggerRunnerWithUpdateDelay_()
+    thread2 = Thread(target=job_runner2._execute)
+    thread2.start()
+
+    try:
+        for _ in range(100):
+            time.sleep(0.1)
+            if not thread1.is_alive():
+                pytest.fail("job_runner1 is not alive")
+            if not thread2.is_alive():
+                pytest.fail("job_runner2 is not alive")
+
+            if getattr(job_runner2.trigger_runner, "create_triggers_count", 0) 
>= 1:
+                break
+    finally:
+        job_runner1.trigger_runner.stop = True
+        job_runner1.trigger_runner.join(10)
+        thread1.join()
+
+        job_runner2.trigger_runner.stop = True
+        job_runner2.trigger_runner.join(10)
+        thread2.join()
+
+    assert job_runner1.trigger_runner.trigger_creation_count == 1
+    assert job_runner2.trigger_runner.trigger_creation_count == 0
+
+    instances = path.read_text().splitlines()
+    assert len(instances) == 1

Review Comment:
   This should check the file’s content instead of only counting its lines.



##########
tests/jobs/test_triggerer_job.py:
##########
@@ -309,6 +309,140 @@ def test_update_trigger_with_triggerer_argument_change(
         assert "got an unexpected keyword argument 'not_exists_arg'" in 
caplog.text
 
 
+def test_trigger_create_race_condition_38599(session, tmp_path):
+    """
+    This verifies the resolution of race condition documented in github issue 
#38599.
+    More details in the issue description.
+
+    The race condition may occur in the following scenario:
+        1. TaskInstance TI1 defers itself, which creates Trigger T1, which 
holds a
+            reference to TI1.
+        2. T1 gets picked up by TriggererJobRunner TJR1 and starts running T1.
+        3. TJR1 misses a heartbeat, most likely due to high host load causing 
delays in
+            each TriggererJobRunner._run_trigger_loop loop.
+        4. A second TriggererJobRunner TJR2 notices that T1 has missed its 
heartbeat,
+            so it starts the process of picking up any Triggers that TJR1 may 
have had,
+            including T1.
+        5. Before TJR2 starts executing T1, TJR1 finishes execution of T1 and 
cleans it
+            up by clearing the trigger_id of TI1.
+        6. TJR2 tries to execute T1, but it crashes (with the above error) 
while trying to
+            look up TI1 (because T1 no longer has a TaskInstance linked to it).
+    """
+    path = tmp_path / "test_trigger_create_after_completion.txt"
+    trigger = TimeDeltaTrigger_(delta=datetime.timedelta(microseconds=1), 
filename=path.as_posix())
+    trigger_orm = Trigger.from_object(trigger)
+    trigger_orm.id = 1
+    session.add(trigger_orm)
+
+    dag = DagModel(dag_id="test-dag")
+    dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none")
+    ti = TaskInstance(
+        PythonOperator(task_id="dummy-task", python_callable=print),
+        run_id=dag_run.run_id,
+        state=TaskInstanceState.DEFERRED,
+    )
+    ti.dag_id = dag.dag_id
+    ti.trigger_id = 1
+    session.add(dag)
+    session.add(dag_run)
+    session.add(ti)
+
+    job1 = Job()
+    job2 = Job()
+    session.add(job1)
+    session.add(job2)
+
+    session.commit()
+
+    class TriggerRunnerWithCreateCount_(TriggerRunner):
+        async def create_triggers(self):
+            num_triggers_to_create = len(self.to_create)
+            await super().create_triggers()
+            self.trigger_creation_count = getattr(self, 
"trigger_creation_count", 0) + num_triggers_to_create
+
+    class TriggerRunnerWithUpdateDelay_(TriggerRunnerWithCreateCount_):
+        """TriggerRunner with a 5 second delay added at the beginning of 
update_triggers
+        to increase the window that the race condition may occur.
+        """
+
+        def update_triggers(self, *args, **kwargs):
+            # Delay calling update_triggers to increase the window of 
opportunity
+            time.sleep(5)
+            super().update_triggers(*args, **kwargs)
+
+        async def create_triggers(self):
+            await super().create_triggers()
+            self.create_triggers_count = getattr(self, 
"create_triggers_count", 0) + 1
+
+    class TriggererJobRunner_(TriggererJobRunner):
+        """TriggererJobRunner whose handle_events blocks until there is an 
event."""
+
+        def load_triggers(self):
+            super().load_triggers()
+            self.load_triggers_count = getattr(self, "load_triggers_count", 0) 
+ 1
+
+        def handle_events(self):
+            # Wait for event during the first loop
+            while not self.trigger_runner.events and getattr(self, 
"handle_events_count", 0) == 0:
+                time.sleep(0.1)
+            super().handle_events()
+            self.handle_events_count = getattr(self, "handle_events_count", 0) 
+ 1
+            # Prevent Trigger.clean_unused() from deleting the trigger
+            time.sleep(5)
+
+    # Start first TriggererJobRunner immediately.
+    # This TriggererJobRunner will immediately load the trigger and start 
running it.
+    # Once the trigger is finished, it will, however, stall after the first 
loop in handle_events,
+    # preventing the trigger from being cleaned up. This simulates what may 
happen during high load.
+    job_runner1 = TriggererJobRunner_(job1)
+    job_runner1.trigger_runner = TriggerRunnerWithCreateCount_()
+    thread1 = Thread(target=job_runner1._execute)
+    thread1.start()
+
+    # Simulate a missed heartbeat by job_runner1 by setting it to an hour ago
+    # This enables the second TriggererJobRunner to pick up the trigger.
+    for _ in range(20):
+        time.sleep(0.1)
+        if getattr(job_runner1, "load_triggers_count", 0) >= 1:
+            job1.latest_heartbeat = timezone.utcnow() - 
datetime.timedelta(hours=1)
+            session.commit()
+            break
+
+    # Start second TriggererJobRunner.
+    # This TriggererJobRunner will pick up the trigger and try to run it,
+    # but the job_runner1.handle_events already unlinked the trigger from the 
task instance,
+    # so trigger.task_instance is None.
+    job_runner2 = TriggererJobRunner(job2)
+    job_runner2.trigger_runner = TriggerRunnerWithUpdateDelay_()
+    thread2 = Thread(target=job_runner2._execute)
+    thread2.start()
+
+    try:
+        for _ in range(100):
+            time.sleep(0.1)
+            if not thread1.is_alive():
+                pytest.fail("job_runner1 is not alive")

Review Comment:
   Could this simply be `assert thread1.is_alive(), "job_runner1 is not alive"` instead of the `if` + `pytest.fail(...)` pair?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to