This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 26a9ec6581 When marking future tasks, ensure we don't touch other mapped TIs (#23177)
26a9ec6581 is described below

commit 26a9ec65816e3ec7542d63ab4a2a494931a06c9b
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Mon Apr 25 10:03:40 2022 +0100

    When marking future tasks, ensure we don't touch other mapped TIs (#23177)
    
    We had a logic bug where if you selected "Map Index 1" to mark as
    success, the other mapped TIs of that run would get cleared too.
    
    This was because `partial_dag.clear` was also picking up the other
    mapped instances of the task.
    
    Co-authored-by: Jed Cunningham <[email protected]>
---
 airflow/api/common/mark_tasks.py | 20 +++++------
 airflow/models/dag.py            | 21 +++++-------
 tests/models/test_dag.py         | 73 +++++++++++++++++++++++++++++++++++++++-
 tests/test_utils/mapping.py      |  1 +
 4 files changed, 92 insertions(+), 23 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index 02e6d2e19b..83bdb2081f 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -127,15 +127,15 @@ def set_state(
         raise ValueError("Received tasks with no DAG")
 
     if execution_date:
-        run_id = dag.get_dagrun(execution_date=execution_date).run_id
+        run_id = dag.get_dagrun(execution_date=execution_date, 
session=session).run_id
     if not run_id:
         raise ValueError("Received tasks with no run_id")
 
-    dag_run_ids = get_run_ids(dag, run_id, future, past)
+    dag_run_ids = get_run_ids(dag, run_id, future, past, session=session)
     task_id_map_index_list = list(find_task_relatives(tasks, downstream, 
upstream))
     task_ids = [task_id if isinstance(task_id, str) else task_id[0] for 
task_id in task_id_map_index_list]
 
-    confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids))
+    confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids, 
session=session))
     confirmed_dates = [info.logical_date for info in confirmed_infos]
 
     sub_dag_run_ids = list(
@@ -261,10 +261,10 @@ def verify_dagruns(
             session.merge(dag_run)
 
 
-def _iter_existing_dag_run_infos(dag: DAG, run_ids: List[str]) -> 
Iterator[_DagRunInfo]:
-    for dag_run in DagRun.find(dag_id=dag.dag_id, run_id=run_ids):
+def _iter_existing_dag_run_infos(dag: DAG, run_ids: List[str], session: 
SASession) -> Iterator[_DagRunInfo]:
+    for dag_run in DagRun.find(dag_id=dag.dag_id, run_id=run_ids, 
session=session):
         dag_run.dag = dag
-        dag_run.verify_integrity()
+        dag_run.verify_integrity(session=session)
         yield _DagRunInfo(dag_run.logical_date, 
dag.get_run_data_interval(dag_run))
 
 
@@ -318,8 +318,8 @@ def get_execution_dates(
 @provide_session
 def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: 
SASession = NEW_SESSION):
     """Returns run_ids of DAG execution"""
-    last_dagrun = dag.get_last_dagrun(include_externally_triggered=True)
-    current_dagrun = dag.get_dagrun(run_id=run_id)
+    last_dagrun = dag.get_last_dagrun(include_externally_triggered=True, 
session=session)
+    current_dagrun = dag.get_dagrun(run_id=run_id, session=session)
     first_dagrun = (
         session.query(DagRun)
         .filter(DagRun.dag_id == dag.dag_id)
@@ -336,7 +336,7 @@ def get_run_ids(dag: DAG, run_id: str, future: bool, past: 
bool, session: SASess
     if not dag.timetable.can_run:
         # If the DAG never schedules, need to look at existing DagRun if the 
user wants future or
         # past runs.
-        dag_runs = dag.get_dagruns_between(start_date=start_date, 
end_date=end_date)
+        dag_runs = dag.get_dagruns_between(start_date=start_date, 
end_date=end_date, session=session)
         run_ids = sorted({d.run_id for d in dag_runs})
     elif not dag.timetable.periodic:
         run_ids = [run_id]
@@ -344,7 +344,7 @@ def get_run_ids(dag: DAG, run_id: str, future: bool, past: 
bool, session: SASess
         dates = [
             info.logical_date for info in 
dag.iter_dagrun_infos_between(start_date, end_date, align=False)
         ]
-        run_ids = [dr.run_id for dr in DagRun.find(dag_id=dag.dag_id, 
execution_date=dates)]
+        run_ids = [dr.run_id for dr in DagRun.find(dag_id=dag.dag_id, 
execution_date=dates, session=session)]
     return run_ids
 
 
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 212c126c67..a96c24ca29 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1649,25 +1649,14 @@ class DAG(LoggingMixin):
         if not exactly_one(execution_date, run_id):
             raise ValueError("Exactly one of execution_date or run_id must be 
provided")
 
-        if execution_date is None:
-            dag_run = (
-                session.query(DagRun).filter(DagRun.run_id == run_id, 
DagRun.dag_id == self.dag_id).one()
-            )  # Raises an error if not found
-            resolve_execution_date = dag_run.execution_date
-        else:
-            resolve_execution_date = execution_date
-
         task = self.get_task(task_id)
         task.dag = self
 
         tasks_to_set_state: List[Union[Operator, Tuple[Operator, int]]]
-        task_ids_to_exclude_from_clear: Set[Union[str, Tuple[str, int]]]
         if map_indexes is None:
             tasks_to_set_state = [task]
-            task_ids_to_exclude_from_clear = {task_id}
         else:
             tasks_to_set_state = [(task, map_index) for map_index in 
map_indexes]
-            task_ids_to_exclude_from_clear = {(task_id, map_index) for 
map_index in map_indexes}
 
         altered = set_state(
             tasks=tasks_to_set_state,
@@ -1694,6 +1683,14 @@ class DAG(LoggingMixin):
             include_upstream=False,
         )
 
+        if execution_date is None:
+            dag_run = (
+                session.query(DagRun).filter(DagRun.run_id == run_id, 
DagRun.dag_id == self.dag_id).one()
+            )  # Raises an error if not found
+            resolve_execution_date = dag_run.execution_date
+        else:
+            resolve_execution_date = execution_date
+
         end_date = resolve_execution_date if not future else None
         start_date = resolve_execution_date if not past else None
 
@@ -1705,7 +1702,7 @@ class DAG(LoggingMixin):
             only_failed=True,
             session=session,
             # Exclude the task itself from being cleared
-            exclude_task_ids=task_ids_to_exclude_from_clear,
+            exclude_task_ids={task_id},
         )
 
         return altered
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index a4073bf402..9e3c46a602 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -58,7 +58,7 @@ from airflow.timetables.simple import NullTimetable, 
OnceTimetable
 from airflow.utils import timezone
 from airflow.utils.file import list_py_file_paths
 from airflow.utils.session import create_session, provide_session
-from airflow.utils.state import DagRunState, State
+from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.timezone import datetime as datetime_tz
 from airflow.utils.types import DagRunType
 from airflow.utils.weight_rule import WeightRule
@@ -2310,6 +2310,77 @@ def test_set_task_instance_state(run_id, execution_date, 
session, dag_maker):
     assert {t.key for t in altered} == {('test_set_task_instance_state', 
'task_1', dagrun.run_id, 1, -1)}
 
 
+def test_set_task_instance_state_mapped(dag_maker, session):
+    """Test that when setting an individual mapped TI that the other TIs are 
not affected"""
+    task_id = 't1'
+
+    with dag_maker(session=session) as dag:
+
+        @dag.task
+        def make_arg_lists():
+            return [[1], [2], [{'a': 'b'}]]
+
+        def consumer(value):
+            print(value)
+
+        mapped = PythonOperator.partial(task_id=task_id, dag=dag, 
python_callable=consumer).expand(
+            op_args=make_arg_lists()
+        )
+
+        mapped >> BaseOperator(task_id='downstream')
+
+    dr1 = dag_maker.create_dagrun(
+        run_type=DagRunType.SCHEDULED,
+        state=DagRunState.FAILED,
+    )
+    expand_mapped_task(mapped, dr1.run_id, "make_arg_lists", length=2, 
session=session)
+
+    # set_state(future=True) only applies to scheduled runs
+    dr2 = dag_maker.create_dagrun(
+        run_type=DagRunType.SCHEDULED,
+        state=DagRunState.FAILED,
+        execution_date=DEFAULT_DATE + datetime.timedelta(days=1),
+    )
+    expand_mapped_task(mapped, dr2.run_id, "make_arg_lists", length=2, 
session=session)
+
+    session.query(TI).filter_by(dag_id=dag.dag_id).update({'state': 
TaskInstanceState.FAILED})
+
+    ti_query = (
+        session.query(TI.task_id, TI.map_index, TI.run_id, TI.state)
+        .filter(TI.dag_id == dag.dag_id, TI.task_id.in_([task_id, 
'downstream']))
+        .order_by(TI.run_id, TI.task_id, TI.map_index)
+    )
+
+    # Check pre-conditions
+    assert ti_query.all() == [
+        ('downstream', -1, dr1.run_id, TaskInstanceState.FAILED),
+        (task_id, 0, dr1.run_id, TaskInstanceState.FAILED),
+        (task_id, 1, dr1.run_id, TaskInstanceState.FAILED),
+        ('downstream', -1, dr2.run_id, TaskInstanceState.FAILED),
+        (task_id, 0, dr2.run_id, TaskInstanceState.FAILED),
+        (task_id, 1, dr2.run_id, TaskInstanceState.FAILED),
+    ]
+
+    dag.set_task_instance_state(
+        task_id=task_id,
+        map_indexes=[1],
+        future=True,
+        run_id=dr1.run_id,
+        state=TaskInstanceState.SUCCESS,
+        session=session,
+    )
+    assert dr1 in session, "Check session is passed down all the way"
+
+    assert ti_query.all() == [
+        ('downstream', -1, dr1.run_id, None),
+        (task_id, 0, dr1.run_id, TaskInstanceState.FAILED),
+        (task_id, 1, dr1.run_id, TaskInstanceState.SUCCESS),
+        ('downstream', -1, dr2.run_id, None),
+        (task_id, 0, dr2.run_id, TaskInstanceState.FAILED),
+        (task_id, 1, dr2.run_id, TaskInstanceState.SUCCESS),
+    ]
+
+
 @pytest.mark.parametrize(
     "start_date, expected_infos",
     [
diff --git a/tests/test_utils/mapping.py b/tests/test_utils/mapping.py
index c7a27eb76d..e70c874b42 100644
--- a/tests/test_utils/mapping.py
+++ b/tests/test_utils/mapping.py
@@ -41,3 +41,4 @@ def expand_mapped_task(
     session.flush()
 
     mapped.expand_mapped_task(run_id, session=session)
+    mapped.run_time_mapped_ti_count.cache_clear()  # type: ignore[attr-defined]

Reply via email to