This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 8892587cce270aa504fc1a9e25d8d2279f0c71b8
Author: Ephraim Anierobi <splendidzig...@gmail.com>
AuthorDate: Sat Jun 18 08:32:38 2022 +0100

    Fix mapped task immutability after clear (#23667)

    We should be able to detect if the structure of a mapped task has changed
    and verify the integrity. This PR ensures this.

    Co-authored-by: Tzu-ping Chung <uranu...@gmail.com>
    (cherry picked from commit b692517ce3aafb276e9d23570e9734c30a5f3d1f)
---
 airflow/models/dagrun.py    | 114 +++++++++++++++++++++++++------
 tests/models/test_dagrun.py | 161 +++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 251 insertions(+), 24 deletions(-)

diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index b71cd03eec..3be82b9b6d 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -642,15 +642,9 @@ class DagRun(Base, LoggingMixin):
         tis = list(self.get_task_instances(session=session, state=State.task_states))
         self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
         dag = self.get_dag()
-        for ti in tis:
-            try:
-                ti.task = dag.get_task(ti.task_id)
-            except TaskNotFound:
-                self.log.warning(
-                    "Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, ti.dag_id
-                )
-                ti.state = State.REMOVED
-                session.flush()
+        missing_indexes = self._find_missing_task_indexes(dag, tis, session=session)
+        if missing_indexes:
+            self.verify_integrity(missing_indexes=missing_indexes, session=session)
 
         unfinished_tis = [t for t in tis if t.state in State.unfinished]
         finished_tis = [t for t in tis if t.state in State.finished]
@@ -811,11 +805,17 @@ class DagRun(Base, LoggingMixin):
             Stats.timing(f'dagrun.duration.failed.{self.dag_id}', duration)
 
     @provide_session
-    def verify_integrity(self, session: Session = NEW_SESSION):
+    def verify_integrity(
+        self,
+        *,
+        missing_indexes: Optional[Dict["MappedOperator", Sequence[int]]] = None,
+        session: Session = NEW_SESSION,
+    ):
         """
         Verifies the DagRun by checking for removed tasks or tasks that are not in the database
         yet. It will set state to removed or add the task if required.
 
+        :missing_indexes: A dictionary of task vs indexes that are missing.
         :param session: Sqlalchemy ORM Session
         """
         from airflow.settings import task_instance_mutation_hook
@@ -824,9 +824,16 @@ class DagRun(Base, LoggingMixin):
         hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)
 
         dag = self.get_dag()
-        task_ids = self._check_for_removed_or_restored_tasks(
-            dag, task_instance_mutation_hook, session=session
-        )
+        task_ids: Set[str] = set()
+        if missing_indexes:
+            tis = self.get_task_instances(session=session)
+            for ti in tis:
+                task_instance_mutation_hook(ti)
+                task_ids.add(ti.task_id)
+        else:
+            task_ids, missing_indexes = self._check_for_removed_or_restored_tasks(
+                dag, task_instance_mutation_hook, session=session
+            )
 
         def task_filter(task: "Operator") -> bool:
             return task.task_id not in task_ids and (
@@ -841,27 +848,29 @@ class DagRun(Base, LoggingMixin):
         task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)
 
         # Create the missing tasks, including mapped tasks
-        tasks = self._create_missing_tasks(dag, task_creator, task_filter, session=session)
+        tasks = self._create_missing_tasks(dag, task_creator, task_filter, missing_indexes, session=session)
 
         self._create_task_instances(dag.dag_id, tasks, created_counts, hook_is_noop, session=session)
 
     def _check_for_removed_or_restored_tasks(
         self, dag: "DAG", ti_mutation_hook, *, session: Session
-    ) -> Set[str]:
+    ) -> Tuple[Set[str], Dict["MappedOperator", Sequence[int]]]:
         """
-        Check for removed tasks/restored tasks.
+        Check for removed tasks/restored/missing tasks.
 
         :param dag: DAG object corresponding to the dagrun
         :param ti_mutation_hook: task_instance_mutation_hook function
         :param session: Sqlalchemy ORM Session
-        :return: List of task_ids in the dagrun
+        :return: List of task_ids in the dagrun and missing task indexes
         """
         tis = self.get_task_instances(session=session)
 
         # check for removed or restored tasks
         task_ids = set()
+        existing_indexes: Dict["MappedOperator", List[int]] = defaultdict(list)
+        expected_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
         for ti in tis:
             ti_mutation_hook(ti)
             task_ids.add(ti.task_id)
@@ -902,7 +911,8 @@ class DagRun(Base, LoggingMixin):
                 else:
                     self.log.info("Restoring mapped task '%s'", ti)
                     Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
-                    ti.state = State.NONE
+                    existing_indexes[task].append(ti.map_index)
+                    expected_indexes[task] = range(num_mapped_tis)
             else:
                 # What if it is _now_ dynamically mapped, but wasn't before?
                 total_length = task.run_time_mapped_ti_count(self.run_id, session=session)
@@ -923,8 +933,16 @@ class DagRun(Base, LoggingMixin):
                         total_length,
                     )
                     ti.state = State.REMOVED
-                ...
-        return task_ids
+                else:
+                    self.log.info("Restoring mapped task '%s'", ti)
+                    Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
+                    existing_indexes[task].append(ti.map_index)
+                    expected_indexes[task] = range(total_length)
+        # Check if we have some missing indexes to create ti for
+        missing_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
+        for k, v in existing_indexes.items():
+            missing_indexes.update({k: list(set(expected_indexes[k]).difference(v))})
+        return task_ids, missing_indexes
 
     def _get_task_creator(
         self, created_counts: Dict[str, int], ti_mutation_hook: Callable, hook_is_noop: bool
@@ -961,7 +979,13 @@ class DagRun(Base, LoggingMixin):
         return creator
 
     def _create_missing_tasks(
-        self, dag: "DAG", task_creator: Callable, task_filter: Callable, *, session: Session
+        self,
+        dag: "DAG",
+        task_creator: Callable,
+        task_filter: Callable,
+        missing_indexes: Optional[Dict["MappedOperator", Sequence[int]]],
+        *,
+        session: Session,
     ) -> Iterable["Operator"]:
         """
         Create missing tasks -- and expand any MappedOperator that _only_ have literals as input
@@ -972,7 +996,9 @@ class DagRun(Base, LoggingMixin):
         :param session: the session to use
         """
 
-        def expand_mapped_literals(task: "Operator") -> Tuple["Operator", Sequence[int]]:
+        def expand_mapped_literals(
+            task: "Operator", sequence: Union[Sequence[int], None] = None
+        ) -> Tuple["Operator", Sequence[int]]:
             if not task.is_mapped:
                 return (task, (-1,))
             task = cast("MappedOperator", task)
@@ -981,11 +1007,19 @@ class DagRun(Base, LoggingMixin):
             )
             if not count:
                 return (task, (-1,))
+            if sequence:
+                return (task, sequence)
             return (task, range(count))
 
         tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))
         tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, tasks_and_map_idxs))
 
+        if missing_indexes:
+            # If there are missing indexes, override the tasks to create
+            new_tasks_and_map_idxs = itertools.starmap(
+                expand_mapped_literals, [(k, v) for k, v in missing_indexes.items() if len(v) > 0]
+            )
+            tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, new_tasks_and_map_idxs))
         return tasks
 
     def _create_task_instances(
@@ -1027,6 +1061,42 @@ class DagRun(Base, LoggingMixin):
             # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
             session.rollback()
 
+    def _find_missing_task_indexes(self, dag, tis, *, session) -> Dict["MappedOperator", Sequence[int]]:
+        """
+        Here we check if the length of the mapped task instances changed
+        at runtime. If so, we find the missing indexes.
+
+        This function also marks task instances with missing tasks as REMOVED.
+
+        :param dag: DAG object corresponding to the dagrun
+        :param tis: task instances to check
+        :param session: the session to use
+        """
+        existing_indexes: Dict["MappedOperator", list] = defaultdict(list)
+        new_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
+        for ti in tis:
+            try:
+                task = ti.task = dag.get_task(ti.task_id)
+            except TaskNotFound:
+                self.log.error("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, ti.dag_id)
+
+                ti.state = State.REMOVED
+                session.flush()
+                continue
+            if not task.is_mapped:
+                continue
+            # skip unexpanded tasks and also tasks that expands with literal arguments
+            if ti.map_index < 0 or task.parse_time_mapped_ti_count:
+                continue
+            existing_indexes[task].append(ti.map_index)
+            task.run_time_mapped_ti_count.cache_clear()
+            new_length = task.run_time_mapped_ti_count(self.run_id, session=session) or 0
+            new_indexes[task] = range(new_length)
+        missing_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
+        for k, v in existing_indexes.items():
+            missing_indexes.update({k: list(set(new_indexes[k]).difference(v))})
+        return missing_indexes
+
     @staticmethod
     def get_run(session: Session, dag_id: str, execution_date: datetime) -> Optional['DagRun']:
         """
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index f73f5d1c45..d45fd41370 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -41,7 +41,7 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import DagRunType
 from tests.models import DEFAULT_DATE as _DEFAULT_DATE
-from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs
+from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, clear_db_variables
 from tests.test_utils.mock_operators import MockOperator
 
 DEFAULT_DATE = pendulum.instance(_DEFAULT_DATE)
@@ -54,11 +54,13 @@ class TestDagRun:
         clear_db_runs()
         clear_db_pools()
         clear_db_dags()
+        clear_db_variables()
 
     def teardown_method(self) -> None:
         clear_db_runs()
         clear_db_pools()
         clear_db_dags()
+        clear_db_variables()
 
     def create_dag_run(
         self,
@@ -899,7 +901,7 @@ def test_verify_integrity_task_start_and_end_date(Stats_incr, session, run_type,
     session.add(dag_run)
     session.flush()
 
-    dag_run.verify_integrity(session)
+    dag_run.verify_integrity(session=session)
 
     tis = dag_run.task_instances
     assert len(tis) == expected_tis
@@ -1027,6 +1029,161 @@ def test_mapped_literal_to_xcom_arg_verify_integrity(dag_maker, session):
     ]
 
 
+def test_mapped_literal_length_increase_adds_additional_ti(dag_maker, session):
+    """Test that when the length of mapped literal increases, additional ti is added"""
+
+    with dag_maker(session=session) as dag:
+
+        @task
+        def task_2(arg2):
+            ...
+
+        task_2.expand(arg2=[1, 2, 3, 4])
+
+    dr = dag_maker.create_dagrun()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+        (3, State.NONE),
+    ]
+
+    # Now "increase" the length of literal
+    dag._remove_task('task_2')
+
+    with dag:
+        task_2.expand(arg2=[1, 2, 3, 4, 5]).operator
+
+    # At this point, we need to test that the change works on the serialized
+    # DAG (which is what the scheduler operates on)
+    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    dr.dag = serialized_dag
+    # Since we change the literal on the dag file itself, the dag_hash will
+    # change which will have the scheduler verify the dr integrity
+    dr.verify_integrity()
+
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+        (3, State.NONE),
+        (4, State.NONE),
+    ]
+
+
+def test_mapped_literal_length_reduction_adds_removed_state(dag_maker, session):
+    """Test that when the length of mapped literal reduces, removed state is added"""
+
+    with dag_maker(session=session) as dag:
+
+        @task
+        def task_2(arg2):
+            ...
+
+        task_2.expand(arg2=[1, 2, 3, 4])
+
+    dr = dag_maker.create_dagrun()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+        (3, State.NONE),
+    ]
+
+    # Now "reduce" the length of literal
+    dag._remove_task('task_2')
+
+    with dag:
+        task_2.expand(arg2=[1, 2]).operator
+
+    # At this point, we need to test that the change works on the serialized
+    # DAG (which is what the scheduler operates on)
+    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    dr.dag = serialized_dag
+    # Since we change the literal on the dag file itself, the dag_hash will
+    # change which will have the scheduler verify the dr integrity
+    dr.verify_integrity()
+
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.REMOVED),
+        (3, State.REMOVED),
+    ]
+
+
+def test_mapped_literal_length_increase_at_runtime_adds_additional_tis(dag_maker, session):
+    """Test that when the length of mapped literal increases at runtime, additional ti is added"""
+    from airflow.models import Variable
+
+    Variable.set(key='arg1', value=[1, 2, 3])
+
+    @task
+    def task_1():
+        return Variable.get('arg1', deserialize_json=True)
+
+    with dag_maker(session=session) as dag:
+
+        @task
+        def task_2(arg2):
+            ...
+
+        task_2.expand(arg2=task_1())
+
+    dr = dag_maker.create_dagrun()
+    ti = dr.get_task_instance(task_id='task_1')
+    ti.run()
+    dr.task_instance_scheduling_decisions()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+    ]
+
+    # Now "clear" and "increase" the length of literal
+    dag.clear()
+    Variable.set(key='arg1', value=[1, 2, 3, 4])
+
+    with dag:
+        task_2.expand(arg2=task_1()).operator
+
+    # At this point, we need to test that the change works on the serialized
+    # DAG (which is what the scheduler operates on)
+    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    dr.dag = serialized_dag
+
+    # Run the first task again to get the new lengths
+    ti = dr.get_task_instance(task_id='task_1')
+    task1 = dag.get_task('task_1')
+    ti.refresh_from_task(task1)
+    ti.run()
+
+    # this would be called by the localtask job
+    dr.task_instance_scheduling_decisions()
+    tis = dr.get_task_instances()
+
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+        (3, State.NONE),
+    ]
+
+
 @pytest.mark.need_serialized_dag
 def test_mapped_mixed__literal_not_expanded_at_create(dag_maker, session):
     literal = [1, 2, 3, 4]