ferruzzi commented on code in PR #43941:
URL: https://github.com/apache/airflow/pull/43941#discussion_r2004447417


##########
tests/jobs/test_scheduler_job.py:
##########
@@ -2367,6 +2370,194 @@ def 
test_runs_are_created_after_max_active_runs_was_reached(self, dag_maker, ses
         dag_runs = DagRun.find(dag_id=dag.dag_id, session=session)
         assert len(dag_runs) == 2
 
+    @pytest.mark.parametrize(
+        "ti_state, final_ti_span_status",
+        [(State.SUCCESS, SpanStatus.ENDED), (State.RUNNING, 
SpanStatus.ACTIVE)],
+    )
+    def test_recreate_unhealthy_scheduler_spans_if_needed(self, ti_state, 
final_ti_span_status, dag_maker):
+        with dag_maker(
+            dag_id="test_recreate_unhealthy_scheduler_spans_if_needed",
+            start_date=DEFAULT_DATE,
+            max_active_runs=1,
+            dagrun_timeout=datetime.timedelta(seconds=60),
+        ):
+            EmptyOperator(task_id="dummy")
+
+        session = settings.Session()
+
+        old_job = Job()
+        old_job.id = 1
+        old_job.job_type = SchedulerJobRunner.job_type
+
+        session.add(old_job)
+        session.commit()
+
+        assert old_job.is_alive() is False
+
+        new_job = Job()
+        new_job.id = 2
+        new_job.job_type = SchedulerJobRunner.job_type
+
+        self.job_runner = SchedulerJobRunner(job=new_job)
+        self.job_runner.active_spans = ThreadSafeDict()
+        assert len(self.job_runner.active_spans.get_all()) == 0
+
+        dr = dag_maker.create_dagrun(external_trigger=True)
+        dr.state = State.RUNNING
+        dr.span_status = SpanStatus.ACTIVE
+        dr.scheduled_by_job_id = old_job.id
+
+        ti = dr.get_task_instances(session=session)[0]
+        ti.state = ti_state
+        ti.start_date = timezone.utcnow()
+        ti.span_status = SpanStatus.ACTIVE
+        ti.queued_by_job_id = old_job.id
+        session.merge(ti)
+        session.merge(dr)
+        session.commit()
+
+        # Given
+        assert dr.scheduled_by_job_id != self.job_runner.job.id
+        assert dr.scheduled_by_job_id == old_job.id
+        assert dr.run_id is not None
+        assert dr.state == State.RUNNING
+        assert dr.span_status == SpanStatus.ACTIVE
+        assert self.job_runner.active_spans.get(dr.run_id) is None
+
+        assert self.job_runner.active_spans.get(ti.key) is None
+        assert ti.state == ti_state
+        assert ti.span_status == SpanStatus.ACTIVE
+
+        # When
+        self.job_runner._recreate_unhealthy_scheduler_spans_if_needed(dr, 
session)
+
+        # Then
+        assert self.job_runner.active_spans.get(dr.run_id) is not None
+
+        if final_ti_span_status == SpanStatus.ACTIVE:
+            assert self.job_runner.active_spans.get(ti.key) is not None
+            assert len(self.job_runner.active_spans.get_all()) == 2
+        else:
+            assert self.job_runner.active_spans.get(ti.key) is None
+            assert len(self.job_runner.active_spans.get_all()) == 1
+
+        assert dr.span_status == SpanStatus.ACTIVE
+        assert ti.span_status == final_ti_span_status
+
+    def test_end_spans_of_externally_ended_ops(self, dag_maker):
+        with dag_maker(
+            dag_id="test_end_spans_of_externally_ended_ops",
+            start_date=DEFAULT_DATE,
+            max_active_runs=1,
+            dagrun_timeout=datetime.timedelta(seconds=60),
+        ):
+            EmptyOperator(task_id="dummy")
+
+        session = settings.Session()
+
+        job = Job()
+        job.id = 1
+        job.job_type = SchedulerJobRunner.job_type
+
+        self.job_runner = SchedulerJobRunner(job=job)
+        self.job_runner.active_spans = ThreadSafeDict()
+        assert len(self.job_runner.active_spans.get_all()) == 0
+
+        dr = dag_maker.create_dagrun(external_trigger=True)
+        dr.state = State.SUCCESS
+        dr.span_status = SpanStatus.SHOULD_END
+
+        ti = dr.get_task_instances(session=session)[0]
+        ti.state = State.SUCCESS
+        ti.span_status = SpanStatus.SHOULD_END
+        ti.context_carrier = {}
+        session.merge(ti)
+        session.merge(dr)
+        session.commit()
+
+        dr_span = Trace.start_root_span(span_name="dag_run_span", 
start_as_current=False)
+        ti_span = Trace.start_child_span(span_name="ti_span", 
start_as_current=False)
+
+        self.job_runner.active_spans.set(dr.run_id, dr_span)
+        self.job_runner.active_spans.set(ti.key, ti_span)
+
+        # Given
+        assert dr.span_status == SpanStatus.SHOULD_END
+        assert ti.span_status == SpanStatus.SHOULD_END
+
+        assert self.job_runner.active_spans.get(dr.run_id) is not None
+        assert self.job_runner.active_spans.get(ti.key) is not None
+
+        # When
+        self.job_runner._end_spans_of_externally_ended_ops(session)
+
+        # Then
+        assert dr.span_status == SpanStatus.ENDED
+        assert ti.span_status == SpanStatus.ENDED
+
+        assert self.job_runner.active_spans.get(dr.run_id) is None
+        assert self.job_runner.active_spans.get(ti.key) is None
+
+    @pytest.mark.parametrize(
+        "state, final_span_status",
+        [(State.SUCCESS, SpanStatus.ENDED), (State.RUNNING, 
SpanStatus.NEEDS_CONTINUANCE)],
+    )
+    def test_end_active_spans(self, state, final_span_status, dag_maker):
+        with dag_maker(
+            dag_id="test_end_active_spans",
+            start_date=DEFAULT_DATE,
+            max_active_runs=1,
+            dagrun_timeout=datetime.timedelta(seconds=60),
+        ):
+            EmptyOperator(task_id="dummy")
+
+        session = settings.Session()
+
+        job = Job()
+        job.id = 1
+        job.job_type = SchedulerJobRunner.job_type
+
+        self.job_runner = SchedulerJobRunner(job=job)
+        self.job_runner.active_spans = ThreadSafeDict()
+        assert len(self.job_runner.active_spans.get_all()) == 0
+
+        dr = dag_maker.create_dagrun(external_trigger=True)
+        dr.state = state
+        dr.span_status = SpanStatus.ACTIVE
+
+        ti = dr.get_task_instances(session=session)[0]
+        ti.state = state
+        ti.span_status = SpanStatus.ACTIVE
+        ti.context_carrier = {}
+        session.merge(ti)
+        session.merge(dr)
+        session.commit()
+
+        dr_span = Trace.start_root_span(span_name="dag_run_span", 
start_as_current=False)
+        ti_span = Trace.start_child_span(span_name="ti_span", 
start_as_current=False)
+
+        self.job_runner.active_spans.set(dr.run_id, dr_span)
+        self.job_runner.active_spans.set(ti.key, ti_span)
+
+        # Given
+        assert dr.span_status == SpanStatus.ACTIVE
+        assert ti.span_status == SpanStatus.ACTIVE
+
+        assert self.job_runner.active_spans.get(dr.run_id) is not None
+        assert self.job_runner.active_spans.get(ti.key) is not None
+        assert len(self.job_runner.active_spans.get_all()) == 2
+
+        # When
+        self.job_runner._end_active_spans(session)
+
+        # Then

Review Comment:
   I totally get it; I used to put them in as well, but dropped the habit to 
match the Airflow style.  To be fair, now that I am used to it, a blank line 
really does do pretty much the same thing, so it's not a big deal.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to