dstandish commented on code in PR #33903:
URL: https://github.com/apache/airflow/pull/33903#discussion_r1312636933


##########
airflow/ti_deps/deps/trigger_rule_dep.py:
##########

@@ -222,224 +212,356 @@ def _iter_upstream_conditions() -> Iterator[ColumnOperators]:
             else:
                 yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes)
 
-        # Optimization: Don't need to hit the database if all upstreams are
-        # "simple" tasks (no task or task group mapping involved).
-        if not any(needs_expansion(t) for t in upstream_tasks.values()):
-            upstream = len(upstream_tasks)
-            upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup)
-        else:
-            task_id_counts = session.execute(
-                select(TaskInstance.task_id, func.count(TaskInstance.task_id))
-                .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
-                .where(or_(*_iter_upstream_conditions()))
-                .group_by(TaskInstance.task_id)
-            ).all()
-            upstream = sum(count for _, count in task_id_counts)
-            upstream_setup = sum(c for t, c in task_id_counts if upstream_tasks[t].is_setup)
-
-        upstream_done = done >= upstream
-
-        changed = False
-        new_state = None
-        if dep_context.flag_upstream_failed:
-            if trigger_rule == TR.ALL_SUCCESS:
-                if upstream_failed or failed:
-                    new_state = TaskInstanceState.UPSTREAM_FAILED
-                elif skipped:
-                    new_state = TaskInstanceState.SKIPPED
-                elif removed and success and ti.map_index > -1:
-                    if ti.map_index >= success:
-                        new_state = TaskInstanceState.REMOVED
-            elif trigger_rule == TR.ALL_FAILED:
-                if success or skipped:
-                    new_state = TaskInstanceState.SKIPPED
-            elif trigger_rule == TR.ONE_SUCCESS:
-                if upstream_done and done == skipped:
-                    # if upstream is done and all are skipped mark as skipped
-                    new_state = TaskInstanceState.SKIPPED
-                elif upstream_done and success <= 0:
-                    # if upstream is done and there are no success mark as upstream failed
-                    new_state = TaskInstanceState.UPSTREAM_FAILED
-            elif trigger_rule == TR.ONE_FAILED:
-                if upstream_done and not (failed or upstream_failed):
-                    new_state = TaskInstanceState.SKIPPED
-            elif trigger_rule == TR.ONE_DONE:
-                if upstream_done and not (failed or success):
-                    new_state = TaskInstanceState.SKIPPED
-            elif trigger_rule == TR.NONE_FAILED:
-                if upstream_failed or failed:
-                    new_state = TaskInstanceState.UPSTREAM_FAILED
-            elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
-                if upstream_failed or failed:
-                    new_state = TaskInstanceState.UPSTREAM_FAILED
-                elif skipped == upstream:
-                    new_state = TaskInstanceState.SKIPPED
-            elif trigger_rule == TR.NONE_SKIPPED:
-                if skipped:
-                    new_state = TaskInstanceState.SKIPPED
-            elif trigger_rule == TR.ALL_SKIPPED:
-                if success or failed:
-                    new_state = TaskInstanceState.SKIPPED
-            elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS:
-                if upstream_done and upstream_setup and skipped_setup >= upstream_setup:
-                    # when there is an upstream setup and they have all skipped, then skip
-                    new_state = TaskInstanceState.SKIPPED
-                elif upstream_done and upstream_setup and success_setup == 0:
-                    # when there is an upstream setup, if none succeeded, mark upstream failed
-                    # if at least one setup ran, we'll let it run
-                    new_state = TaskInstanceState.UPSTREAM_FAILED
-        if new_state is not None:
-            if new_state == TaskInstanceState.SKIPPED and dep_context.wait_for_past_depends_before_skipping:
-                past_depends_met = ti.xcom_pull(
-                    task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
-                )
-                if not past_depends_met:
-                    yield self._failing_status(
-                        reason=("Task should be skipped but the past depends are not met")
-                    )
-                    return
-            changed = ti.set_state(new_state, session)
-
-        if changed:
-            dep_context.have_changed_ti_states = True
-
-        if trigger_rule == TR.ONE_SUCCESS:
-            if success <= 0:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires one upstream task success, "
-                        f"but none were found. upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
-                    )
-                )
-        elif trigger_rule == TR.ONE_FAILED:
-            if not failed and not upstream_failed:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires one upstream task failure, "
-                        f"but none were found. upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
-                    )
-                )
-        elif trigger_rule == TR.ONE_DONE:
-            if success + failed <= 0:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}'"
-                        "requires at least one upstream task failure or success"
-                        f"but none were failed or success. upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
-                    )
-                )
-        elif trigger_rule == TR.ALL_SUCCESS:
-            num_failures = upstream - success
-            if ti.map_index > -1:
-                num_failures -= removed
-            if num_failures > 0:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
-                        f"succeeded, but found {num_failures} non-success(es). "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
-                    )
-                )
-        elif trigger_rule == TR.ALL_FAILED:
-            num_success = upstream - failed - upstream_failed
-            if ti.map_index > -1:
-                num_success -= removed
-            if num_success > 0:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have failed, "
-                        f"but found {num_success} non-failure(s). "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
-                    )
-                )
-        elif trigger_rule == TR.ALL_DONE:
-            if not upstream_done:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
-                        f"completed, but found {len(upstream_tasks) - done} task(s) that were not done. "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
-                    )
-                )
-        elif trigger_rule == TR.NONE_FAILED:
-            num_failures = upstream - success - skipped
-            if ti.map_index > -1:
-                num_failures -= removed
-            if num_failures > 0:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
-                        f"succeeded or been skipped, but found {num_failures} non-success(es). "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
-                    )
-                )
-        elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
-            num_failures = upstream - success - skipped
-            if ti.map_index > -1:
-                num_failures -= removed
-            if num_failures > 0:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
-                        f"succeeded or been skipped, but found {num_failures} non-success(es). "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
-                    )
-                )
-        elif trigger_rule == TR.NONE_SKIPPED:
-            if not upstream_done or (skipped > 0):
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to not have been "
-                        f"skipped, but found {skipped} task(s) skipped. "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
-                    )
-                )
-        elif trigger_rule == TR.ALL_SKIPPED:
-            num_non_skipped = upstream - skipped
-            if num_non_skipped > 0:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have been "
-                        f"skipped, but found {num_non_skipped} task(s) in non skipped state. "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
-                    )
-                )
-        elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS:
-            if not upstream_done:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
-                        f"completed, but found {len(upstream_tasks) - done} task(s) that were not done. "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
-                    )
-                )
-            elif upstream_setup is None:  # for now, None only happens in mapped case
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' cannot have mapped tasks as upstream. "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
-                    )
-                )
-            elif upstream_setup and not success_setup >= 1:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires at least one upstream setup task be "
-                        f"successful, but found {upstream_setup - success_setup} task(s) that were not. "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
-                    )
-                )
-        else:
-            yield self._failing_status(reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.")
+        def _evaluate_setup_constraint(*, relevant_setups) -> Iterator[tuple[TIDepStatus, bool]]:
+            """Evaluate whether ``ti``'s trigger rule was met.
+
+            :param ti: Task instance to evaluate the trigger rule of.
+            :param dep_context: The current dependency context.
+            :param session: Database session.
+            """
+            task = ti.task
+
+            indirect_setups = {k: v for k, v in relevant_setups.items() if k not in task.upstream_task_ids}
+            finished_upstream_tis = (
+                x
+                for x in dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
+                if _is_relevant_upstream(upstream=x, relevant_ids=indirect_setups.keys())
+            )
+            upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis)
+
+            # all of these counts reflect indirect setups which are relevant for this ti
+            success = upstream_states.success
+            skipped = upstream_states.skipped
+            failed = upstream_states.failed
+            upstream_failed = upstream_states.upstream_failed
+            removed = upstream_states.removed
+
+            # Optimization: Don't need to hit the database if all upstreams are
+            # "simple" tasks (no task or task group mapping involved).
+            if not any(needs_expansion(t) for t in indirect_setups.values()):
+                upstream = len(indirect_setups)
+            else:
+                task_id_counts = session.execute(
+                    select(TaskInstance.task_id, func.count(TaskInstance.task_id))
+                    .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
+                    .where(or_(*_iter_upstream_conditions(relevant_tasks=indirect_setups)))
+                    .group_by(TaskInstance.task_id)
+                ).all()
+                upstream = sum(count for _, count in task_id_counts)
+
+            new_state = None
+            changed = False
+
+            # if there's a failure, we mark upstream_failed; if there's a skip, we mark skipped
+            # in either case, we don't wait for all relevant setups to complete
+            if dep_context.flag_upstream_failed:
+                if upstream_failed or failed:
+                    new_state = TaskInstanceState.UPSTREAM_FAILED
+                elif skipped:
+                    new_state = TaskInstanceState.SKIPPED
+                elif removed and success and ti.map_index > -1:
+                    if ti.map_index >= success:
+                        new_state = TaskInstanceState.REMOVED
+
+            if new_state is not None:
+                if (
+                    new_state == TaskInstanceState.SKIPPED
+                    and dep_context.wait_for_past_depends_before_skipping
+                ):
+                    past_depends_met = ti.xcom_pull(
+                        task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
+                    )
+                    if not past_depends_met:
+                        yield self._failing_status(
+                            reason="Task should be skipped but the past depends are not met"
+                        ), changed
+                        return
+                changed = ti.set_state(new_state, session)
+
+            if changed:
+                dep_context.have_changed_ti_states = True
+
+            non_successes = upstream - success
+            if ti.map_index > -1:
+                non_successes -= removed
+            if non_successes > 0:
+                yield self._failing_status(
+                    reason=(
+                        f"All setup tasks must complete successfully. Relevant setups: {relevant_setups}: "
+                        f"upstream_states={upstream_states}, "
+                        f"upstream_task_ids={task.upstream_task_ids}"
+                    ),
+                ), changed
+
+        def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
+            """Evaluate whether ``ti``'s trigger rule was met.
+
+            :param ti: Task instance to evaluate the trigger rule of.
+            :param dep_context: The current dependency context.
+            :param session: Database session.
+            """
+            finished_upstream_tis = (
+                finished_ti
+                for finished_ti in dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
+                if _is_relevant_upstream(upstream=finished_ti, relevant_ids=ti.task.upstream_task_ids)
+            )
+            upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis)
+
+            success = upstream_states.success
+            skipped = upstream_states.skipped
+            failed = upstream_states.failed
+            upstream_failed = upstream_states.upstream_failed
+            removed = upstream_states.removed
+            done = upstream_states.done
+            success_setup = upstream_states.success_setup
+            skipped_setup = upstream_states.skipped_setup
+
+            # Optimization: Don't need to hit the database if all upstreams are
+            # "simple" tasks (no task or task group mapping involved).
+            if not any(needs_expansion(t) for t in upstream_tasks.values()):
+                upstream = len(upstream_tasks)
+                upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup)
+            else:
+                task_id_counts = session.execute(
+                    select(TaskInstance.task_id, func.count(TaskInstance.task_id))
+                    .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
+                    .where(or_(*_iter_upstream_conditions(relevant_tasks=upstream_tasks)))
+                    .group_by(TaskInstance.task_id)
+                ).all()
+                upstream = sum(count for _, count in task_id_counts)
+                upstream_setup = sum(c for t, c in task_id_counts if upstream_tasks[t].is_setup)
+
+            upstream_done = done >= upstream
+
+            changed = False
+            new_state = None
+            if dep_context.flag_upstream_failed:
+                if trigger_rule == TR.ALL_SUCCESS:
+                    if upstream_failed or failed:
+                        new_state = TaskInstanceState.UPSTREAM_FAILED
+                    elif skipped:
+                        new_state = TaskInstanceState.SKIPPED
+                    elif removed and success and ti.map_index > -1:
+                        if ti.map_index >= success:
+                            new_state = TaskInstanceState.REMOVED
+                elif trigger_rule == TR.ALL_FAILED:
+                    if success or skipped:
+                        new_state = TaskInstanceState.SKIPPED
+                elif trigger_rule == TR.ONE_SUCCESS:
+                    if upstream_done and done == skipped:
+                        # if upstream is done and all are skipped mark as skipped
+                        new_state = TaskInstanceState.SKIPPED
+                    elif upstream_done and success <= 0:
+                        # if upstream is done and there are no success mark as upstream failed
+                        new_state = TaskInstanceState.UPSTREAM_FAILED
+                elif trigger_rule == TR.ONE_FAILED:
+                    if upstream_done and not (failed or upstream_failed):
+                        new_state = TaskInstanceState.SKIPPED
+                elif trigger_rule == TR.ONE_DONE:
+                    if upstream_done and not (failed or success):
+                        new_state = TaskInstanceState.SKIPPED
+                elif trigger_rule == TR.NONE_FAILED:
+                    if upstream_failed or failed:
+                        new_state = TaskInstanceState.UPSTREAM_FAILED
+                elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
+                    if upstream_failed or failed:
+                        new_state = TaskInstanceState.UPSTREAM_FAILED
+                    elif skipped == upstream:
+                        new_state = TaskInstanceState.SKIPPED
+                elif trigger_rule == TR.NONE_SKIPPED:
+                    if skipped:
+                        new_state = TaskInstanceState.SKIPPED
+                elif trigger_rule == TR.ALL_SKIPPED:
+                    if success or failed:
+                        new_state = TaskInstanceState.SKIPPED
+                elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS:
+                    if upstream_done and upstream_setup and skipped_setup >= upstream_setup:
+                        # when there is an upstream setup and they have all skipped, then skip
+                        new_state = TaskInstanceState.SKIPPED
+                    elif upstream_done and upstream_setup and success_setup == 0:
+                        # when there is an upstream setup, if none succeeded, mark upstream failed
+                        # if at least one setup ran, we'll let it run
+                        new_state = TaskInstanceState.UPSTREAM_FAILED
+            if new_state is not None:
+                if (
+                    new_state == TaskInstanceState.SKIPPED
+                    and dep_context.wait_for_past_depends_before_skipping
+                ):
+                    past_depends_met = ti.xcom_pull(
+                        task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
+                    )
+                    if not past_depends_met:
+                        yield self._failing_status(
+                            reason=("Task should be skipped but the past depends are not met")
+                        )
+                        return
+                changed = ti.set_state(new_state, session)
+
+            if changed:
+                dep_context.have_changed_ti_states = True
+
+            if trigger_rule == TR.ONE_SUCCESS:
+                if success <= 0:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires one upstream task success, "
+                            f"but none were found. upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
+                    )
+            elif trigger_rule == TR.ONE_FAILED:
+                if not failed and not upstream_failed:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires one upstream task failure, "
+                            f"but none were found. upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
+                    )
+            elif trigger_rule == TR.ONE_DONE:
+                if success + failed <= 0:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}'"
+                            "requires at least one upstream task failure or success"
+                            f"but none were failed or success. upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
+                    )
+            elif trigger_rule == TR.ALL_SUCCESS:
+                num_failures = upstream - success
+                if ti.map_index > -1:
+                    num_failures -= removed
+                if num_failures > 0:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
+                            f"succeeded, but found {num_failures} non-success(es). "
+                            f"upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
+                    )
+            elif trigger_rule == TR.ALL_FAILED:
+                num_success = upstream - failed - upstream_failed
+                if ti.map_index > -1:
+                    num_success -= removed
+                if num_success > 0:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks "
+                            f"to have failed, but found {num_success} non-failure(s). "
+                            f"upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
+                    )
+            elif trigger_rule == TR.ALL_DONE:
+                if not upstream_done:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
+                            f"completed, but found {len(upstream_tasks) - done} task(s) that were "
+                            f"not done. upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
+                    )
+            elif trigger_rule == TR.NONE_FAILED:
+                num_failures = upstream - success - skipped
+                if ti.map_index > -1:
+                    num_failures -= removed
+                if num_failures > 0:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
+                            f"succeeded or been skipped, but found {num_failures} non-success(es). "
+                            f"upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
+                    )
+            elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
+                num_failures = upstream - success - skipped
+                if ti.map_index > -1:
+                    num_failures -= removed
+                if num_failures > 0:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
+                            f"succeeded or been skipped, but found {num_failures} non-success(es). "
+                            f"upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
+                    )
+            elif trigger_rule == TR.NONE_SKIPPED:
+                if not upstream_done or (skipped > 0):
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to not "
+                            f"have been skipped, but found {skipped} task(s) skipped. "
+                            f"upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
+                    )
+            elif trigger_rule == TR.ALL_SKIPPED:
+                num_non_skipped = upstream - skipped
+                if num_non_skipped > 0:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have been "
+                            f"skipped, but found {num_non_skipped} task(s) in non skipped state. "
+                            f"upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
+                    )
" - f"upstream_states={upstream_states}, " - f"upstream_task_ids={task.upstream_task_ids}" + elif upstream_setup is None: # for now, None only happens in mapped case + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}' cannot have mapped tasks as upstream. " + f"upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) + ) + elif upstream_setup and not success_setup >= 1: + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}' requires at least one upstream setup task " + f"be successful, but found {upstream_setup - success_setup} task(s) that were " + f"not. upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) ) - ) - else: - yield self._failing_status(reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.") + else: + yield self._failing_status(reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.") + + # this dictionary enables use of task_id as the key for lru caching + # in function _get_relevant_upstream_map_indexes + tasks_dict = {} Review Comment: no, cus the bug we're solving here, is the case of _indirect_ setups -- setups that are not direct upstream relatives. in normal operation this is a non issue. but when you clear a work task, it clears its setups; if you don't add a check for the indirect setups then they will run concurrently. that's the problem here. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org