Nataneljpwd commented on code in PR #61274:
URL: https://github.com/apache/airflow/pull/61274#discussion_r2767500941
##########
airflow-core/src/airflow/models/dagrun.py:
##########
@@ -1410,24 +1416,44 @@ def notify_dagrun_state_changed(self, msg: str):
# or LocalTaskJob, so we don't want to "falsely advertise" we notify
about that
@provide_session
- def get_last_ti(self, dag: SerializedDAG, session: Session = NEW_SESSION)
-> TI | None:
- """Get Last TI from the dagrun to build and pass Execution context
object from server to then run callbacks."""
+ def get_first_ti_causing_failure(self, dag: SerializedDAG, session:
Session = NEW_SESSION) -> TI | None:
+ """
+ Get the first task instance that would cause a leaf task to fail the
run.
+ """
+
tis = self.get_task_instances(session=session)
- # tis from a dagrun may not be a part of dag.partial_subset,
- # since dag.partial_subset is a subset of the dag.
- # This ensures that we will only use the accessible TI
- # context for the callback.
+
+ failed_leaf_tis = [
+ ti for ti in self._tis_for_dagrun_state(dag=dag, tis=tis)
+ if ti.state in State.failed_states
+ ]
+
+ if not failed_leaf_tis:
+ return None
+
Review Comment:
Why is the "if" necessary?
Maybe we can just return an empty array at the end if the check yields no
results?
##########
airflow-core/src/airflow/models/dagrun.py:
##########
@@ -1410,24 +1416,44 @@ def notify_dagrun_state_changed(self, msg: str):
# or LocalTaskJob, so we don't want to "falsely advertise" we notify
about that
@provide_session
- def get_last_ti(self, dag: SerializedDAG, session: Session = NEW_SESSION)
-> TI | None:
- """Get Last TI from the dagrun to build and pass Execution context
object from server to then run callbacks."""
+ def get_first_ti_causing_failure(self, dag: SerializedDAG, session:
Session = NEW_SESSION) -> TI | None:
+ """
+ Get the first task instance that would cause a leaf task to fail the
run.
+ """
+
tis = self.get_task_instances(session=session)
- # tis from a dagrun may not be a part of dag.partial_subset,
- # since dag.partial_subset is a subset of the dag.
- # This ensures that we will only use the accessible TI
- # context for the callback.
+
+ failed_leaf_tis = [
+ ti for ti in self._tis_for_dagrun_state(dag=dag, tis=tis)
+ if ti.state in State.failed_states
+ ]
+
+ if not failed_leaf_tis:
+ return None
+
if dag.partial:
- tis = [ti for ti in tis if not ti.state == State.NONE]
- # filter out removed tasks
- tis = natsorted(
- (ti for ti in tis if ti.state != TaskInstanceState.REMOVED),
- key=lambda ti: ti.task_id,
- )
- if not tis:
- return None
- ti = tis[-1] # get last TaskInstance of DagRun
- return ti
+ tis = [
+ ti for ti in tis if not ti.state in (
+ State.NONE, TaskInstanceState.REMOVED
+ )
+ ]
+
+ # Collect all task IDs on failure paths
+ failure_path_task_ids = set()
+ for failed_leaf in failed_leaf_tis:
+ leaf_task = dag.get_task(failed_leaf.task_id)
+ upstream_ids = leaf_task.get_flat_relative_ids(upstream=True)
+ failure_path_task_ids.update(upstream_ids)
+ failure_path_task_ids.add(failed_leaf.task_id)
+
+ # Find failed tasks on possible failure paths
+ failed_on_paths = [
+ ti for ti in tis
+ if ti.task_id in failure_path_task_ids and ti.state ==
State.FAILED
+ ]
+
+ return min(failed_on_paths, key=lambda ti: ti.end_date, default=None)
Review Comment:
Why shouldn't we return multiple tasks if a few of them were running
concurrently?
And why filter by end_date and not start_date?
##########
airflow-core/src/airflow/models/dagrun.py:
##########
@@ -1287,6 +1290,9 @@ def recalculate(self) -> _UnfinishedStates:
reason="all_tasks_deadlocked",
)
elif dag.has_on_failure_callback:
+ last_finished_ti: TI | None = (
+ max(info.finished_tis, key=lambda ti: ti.end_date,
default=None)
Review Comment:
Why do we make this check over and over again?
Can't we do it once?
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -2282,7 +2282,10 @@ def _schedule_dag_run(
select(TI)
.where(TI.dag_id == dag_run.dag_id)
.where(TI.run_id == dag_run.run_id)
- .where(TI.state.in_(State.unfinished))
+ .where(TI.state.in_(State.unfinished) |
(TI.state.is_(None)))
+ ).all()
+ last_unfinished_ti = (
Review Comment:
What if it was a mapped task? Or what if there were concurrently running tasks
— how do you decide?
Since the end date is not the only factor, the start date may be a better
option here. Alternatively, return several tasks if they were running
concurrently, and check the dependencies of the task rather than relying
solely on the end date.
##########
airflow-core/src/airflow/models/dagrun.py:
##########
@@ -1410,24 +1416,44 @@ def notify_dagrun_state_changed(self, msg: str):
# or LocalTaskJob, so we don't want to "falsely advertise" we notify
about that
@provide_session
- def get_last_ti(self, dag: SerializedDAG, session: Session = NEW_SESSION)
-> TI | None:
- """Get Last TI from the dagrun to build and pass Execution context
object from server to then run callbacks."""
+ def get_first_ti_causing_failure(self, dag: SerializedDAG, session:
Session = NEW_SESSION) -> TI | None:
+ """
+ Get the first task instance that would cause a leaf task to fail the
run.
+ """
+
tis = self.get_task_instances(session=session)
- # tis from a dagrun may not be a part of dag.partial_subset,
- # since dag.partial_subset is a subset of the dag.
- # This ensures that we will only use the accessible TI
- # context for the callback.
+
+ failed_leaf_tis = [
+ ti for ti in self._tis_for_dagrun_state(dag=dag, tis=tis)
+ if ti.state in State.failed_states
+ ]
+
+ if not failed_leaf_tis:
+ return None
+
if dag.partial:
- tis = [ti for ti in tis if not ti.state == State.NONE]
- # filter out removed tasks
- tis = natsorted(
- (ti for ti in tis if ti.state != TaskInstanceState.REMOVED),
- key=lambda ti: ti.task_id,
- )
- if not tis:
- return None
- ti = tis[-1] # get last TaskInstance of DagRun
- return ti
+ tis = [
+ ti for ti in tis if not ti.state in (
+ State.NONE, TaskInstanceState.REMOVED
+ )
+ ]
+
+ # Collect all task IDs on failure paths
+ failure_path_task_ids = set()
+ for failed_leaf in failed_leaf_tis:
+ leaf_task = dag.get_task(failed_leaf.task_id)
+ upstream_ids = leaf_task.get_flat_relative_ids(upstream=True)
+ failure_path_task_ids.update(upstream_ids)
+ failure_path_task_ids.add(failed_leaf.task_id)
+
+ # Find failed tasks on possible failure paths
+ failed_on_paths = [
+ ti for ti in tis
+ if ti.task_id in failure_path_task_ids and ti.state ==
State.FAILED
+ ]
Review Comment:
This part confuses me a little: why do we get ALL task instances (`tis`),
then collect all of the upstream task IDs of the failed leaf tasks
(`failure_path_task_ids`), only to filter back down to the failed task
instances again?
##########
airflow-core/src/airflow/models/dagrun.py:
##########
@@ -1410,24 +1416,44 @@ def notify_dagrun_state_changed(self, msg: str):
# or LocalTaskJob, so we don't want to "falsely advertise" we notify
about that
@provide_session
- def get_last_ti(self, dag: SerializedDAG, session: Session = NEW_SESSION)
-> TI | None:
- """Get Last TI from the dagrun to build and pass Execution context
object from server to then run callbacks."""
+ def get_first_ti_causing_failure(self, dag: SerializedDAG, session:
Session = NEW_SESSION) -> TI | None:
+ """
+ Get the first task instance that would cause a leaf task to fail the
run.
+ """
+
tis = self.get_task_instances(session=session)
- # tis from a dagrun may not be a part of dag.partial_subset,
- # since dag.partial_subset is a subset of the dag.
- # This ensures that we will only use the accessible TI
- # context for the callback.
+
+ failed_leaf_tis = [
Review Comment:
So "leaf" here means the leaf task of the given dagrun?
If so, maybe a better name would be `last_failed_tasks`.
##########
airflow-core/src/airflow/models/dagrun.py:
##########
@@ -1410,24 +1416,44 @@ def notify_dagrun_state_changed(self, msg: str):
# or LocalTaskJob, so we don't want to "falsely advertise" we notify
about that
@provide_session
- def get_last_ti(self, dag: SerializedDAG, session: Session = NEW_SESSION)
-> TI | None:
- """Get Last TI from the dagrun to build and pass Execution context
object from server to then run callbacks."""
+ def get_first_ti_causing_failure(self, dag: SerializedDAG, session:
Session = NEW_SESSION) -> TI | None:
Review Comment:
Minor nit, but maybe the method could be renamed to
`get_first_ti_causing_dagrun_failure`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]