This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 26a9ec6581 When marking future tasks, ensure we don't touch other
mapped TIs (#23177)
26a9ec6581 is described below
commit 26a9ec65816e3ec7542d63ab4a2a494931a06c9b
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Mon Apr 25 10:03:40 2022 +0100
When marking future tasks, ensure we don't touch other mapped TIs (#23177)
We had a logic bug where if you selected "Map Index 1" to mark as
success, the other mapped TIs of that run would get cleared too.
This was because the `partial_dag.clear` was picking up the other
mapped instances of the task.
Co-authored-by: Jed Cunningham <[email protected]>
---
airflow/api/common/mark_tasks.py | 20 +++++------
airflow/models/dag.py | 21 +++++-------
tests/models/test_dag.py | 73 +++++++++++++++++++++++++++++++++++++++-
tests/test_utils/mapping.py | 1 +
4 files changed, 92 insertions(+), 23 deletions(-)
diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index 02e6d2e19b..83bdb2081f 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -127,15 +127,15 @@ def set_state(
     raise ValueError("Received tasks with no DAG")
     if execution_date:
-        run_id = dag.get_dagrun(execution_date=execution_date).run_id
+        run_id = dag.get_dagrun(execution_date=execution_date, session=session).run_id
     if not run_id:
         raise ValueError("Received tasks with no run_id")
-    dag_run_ids = get_run_ids(dag, run_id, future, past)
+    dag_run_ids = get_run_ids(dag, run_id, future, past, session=session)
     task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream))
     task_ids = [task_id if isinstance(task_id, str) else task_id[0] for task_id in task_id_map_index_list]
-    confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids))
+    confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids, session=session))
     confirmed_dates = [info.logical_date for info in confirmed_infos]
     sub_dag_run_ids = list(
@@ -261,10 +261,10 @@ def verify_dagruns(
session.merge(dag_run)
-def _iter_existing_dag_run_infos(dag: DAG, run_ids: List[str]) -> Iterator[_DagRunInfo]:
-    for dag_run in DagRun.find(dag_id=dag.dag_id, run_id=run_ids):
+def _iter_existing_dag_run_infos(dag: DAG, run_ids: List[str], session: SASession) -> Iterator[_DagRunInfo]:
+    for dag_run in DagRun.find(dag_id=dag.dag_id, run_id=run_ids, session=session):
         dag_run.dag = dag
-        dag_run.verify_integrity()
+        dag_run.verify_integrity(session=session)
         yield _DagRunInfo(dag_run.logical_date, dag.get_run_data_interval(dag_run))
@@ -318,8 +318,8 @@ def get_execution_dates(
@provide_session
 def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASession = NEW_SESSION):
     """Returns run_ids of DAG execution"""
-    last_dagrun = dag.get_last_dagrun(include_externally_triggered=True)
-    current_dagrun = dag.get_dagrun(run_id=run_id)
+    last_dagrun = dag.get_last_dagrun(include_externally_triggered=True, session=session)
+    current_dagrun = dag.get_dagrun(run_id=run_id, session=session)
first_dagrun = (
session.query(DagRun)
.filter(DagRun.dag_id == dag.dag_id)
@@ -336,7 +336,7 @@ def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASess
     if not dag.timetable.can_run:
         # If the DAG never schedules, need to look at existing DagRun if the user wants future or
         # past runs.
-        dag_runs = dag.get_dagruns_between(start_date=start_date, end_date=end_date)
+        dag_runs = dag.get_dagruns_between(start_date=start_date, end_date=end_date, session=session)
         run_ids = sorted({d.run_id for d in dag_runs})
     elif not dag.timetable.periodic:
         run_ids = [run_id]
elif not dag.timetable.periodic:
run_ids = [run_id]
@@ -344,7 +344,7 @@ def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASess
         dates = [
             info.logical_date for info in dag.iter_dagrun_infos_between(start_date, end_date, align=False)
         ]
-        run_ids = [dr.run_id for dr in DagRun.find(dag_id=dag.dag_id, execution_date=dates)]
+        run_ids = [dr.run_id for dr in DagRun.find(dag_id=dag.dag_id, execution_date=dates, session=session)]
     return run_ids
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 212c126c67..a96c24ca29 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1649,25 +1649,14 @@ class DAG(LoggingMixin):
         if not exactly_one(execution_date, run_id):
             raise ValueError("Exactly one of execution_date or run_id must be provided")
-        if execution_date is None:
-            dag_run = (
-                session.query(DagRun).filter(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id).one()
-            )  # Raises an error if not found
-            resolve_execution_date = dag_run.execution_date
-        else:
-            resolve_execution_date = execution_date
-
         task = self.get_task(task_id)
         task.dag = self
         tasks_to_set_state: List[Union[Operator, Tuple[Operator, int]]]
-        task_ids_to_exclude_from_clear: Set[Union[str, Tuple[str, int]]]
         if map_indexes is None:
             tasks_to_set_state = [task]
-            task_ids_to_exclude_from_clear = {task_id}
         else:
             tasks_to_set_state = [(task, map_index) for map_index in map_indexes]
-            task_ids_to_exclude_from_clear = {(task_id, map_index) for map_index in map_indexes}
         altered = set_state(
             tasks=tasks_to_set_state,
@@ -1694,6 +1683,14 @@ class DAG(LoggingMixin):
include_upstream=False,
)
+ if execution_date is None:
+ dag_run = (
+ session.query(DagRun).filter(DagRun.run_id == run_id,
DagRun.dag_id == self.dag_id).one()
+ ) # Raises an error if not found
+ resolve_execution_date = dag_run.execution_date
+ else:
+ resolve_execution_date = execution_date
+
end_date = resolve_execution_date if not future else None
start_date = resolve_execution_date if not past else None
@@ -1705,7 +1702,7 @@ class DAG(LoggingMixin):
only_failed=True,
session=session,
# Exclude the task itself from being cleared
- exclude_task_ids=task_ids_to_exclude_from_clear,
+ exclude_task_ids={task_id},
)
return altered
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index a4073bf402..9e3c46a602 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -58,7 +58,7 @@ from airflow.timetables.simple import NullTimetable, OnceTimetable
from airflow.utils import timezone
from airflow.utils.file import list_py_file_paths
from airflow.utils.session import create_session, provide_session
-from airflow.utils.state import DagRunState, State
+from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.timezone import datetime as datetime_tz
from airflow.utils.types import DagRunType
from airflow.utils.weight_rule import WeightRule
@@ -2310,6 +2310,77 @@ def test_set_task_instance_state(run_id, execution_date, session, dag_maker):
     assert {t.key for t in altered} == {('test_set_task_instance_state', 'task_1', dagrun.run_id, 1, -1)}
+def test_set_task_instance_state_mapped(dag_maker, session):
+    """Test that when setting an individual mapped TI that the other TIs are not affected"""
+ task_id = 't1'
+
+ with dag_maker(session=session) as dag:
+
+ @dag.task
+ def make_arg_lists():
+ return [[1], [2], [{'a': 'b'}]]
+
+ def consumer(value):
+ print(value)
+
+        mapped = PythonOperator.partial(task_id=task_id, dag=dag, python_callable=consumer).expand(
+            op_args=make_arg_lists()
+        )
+
+ mapped >> BaseOperator(task_id='downstream')
+
+ dr1 = dag_maker.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ state=DagRunState.FAILED,
+ )
+    expand_mapped_task(mapped, dr1.run_id, "make_arg_lists", length=2, session=session)
+
+ # set_state(future=True) only applies to scheduled runs
+ dr2 = dag_maker.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ state=DagRunState.FAILED,
+ execution_date=DEFAULT_DATE + datetime.timedelta(days=1),
+ )
+    expand_mapped_task(mapped, dr2.run_id, "make_arg_lists", length=2, session=session)
+
+    session.query(TI).filter_by(dag_id=dag.dag_id).update({'state': TaskInstanceState.FAILED})
+
+ ti_query = (
+ session.query(TI.task_id, TI.map_index, TI.run_id, TI.state)
+        .filter(TI.dag_id == dag.dag_id, TI.task_id.in_([task_id, 'downstream']))
+ .order_by(TI.run_id, TI.task_id, TI.map_index)
+ )
+
+ # Check pre-conditions
+ assert ti_query.all() == [
+ ('downstream', -1, dr1.run_id, TaskInstanceState.FAILED),
+ (task_id, 0, dr1.run_id, TaskInstanceState.FAILED),
+ (task_id, 1, dr1.run_id, TaskInstanceState.FAILED),
+ ('downstream', -1, dr2.run_id, TaskInstanceState.FAILED),
+ (task_id, 0, dr2.run_id, TaskInstanceState.FAILED),
+ (task_id, 1, dr2.run_id, TaskInstanceState.FAILED),
+ ]
+
+ dag.set_task_instance_state(
+ task_id=task_id,
+ map_indexes=[1],
+ future=True,
+ run_id=dr1.run_id,
+ state=TaskInstanceState.SUCCESS,
+ session=session,
+ )
+ assert dr1 in session, "Check session is passed down all the way"
+
+ assert ti_query.all() == [
+ ('downstream', -1, dr1.run_id, None),
+ (task_id, 0, dr1.run_id, TaskInstanceState.FAILED),
+ (task_id, 1, dr1.run_id, TaskInstanceState.SUCCESS),
+ ('downstream', -1, dr2.run_id, None),
+ (task_id, 0, dr2.run_id, TaskInstanceState.FAILED),
+ (task_id, 1, dr2.run_id, TaskInstanceState.SUCCESS),
+ ]
+
+
@pytest.mark.parametrize(
"start_date, expected_infos",
[
diff --git a/tests/test_utils/mapping.py b/tests/test_utils/mapping.py
index c7a27eb76d..e70c874b42 100644
--- a/tests/test_utils/mapping.py
+++ b/tests/test_utils/mapping.py
@@ -41,3 +41,4 @@ def expand_mapped_task(
     session.flush()
     mapped.expand_mapped_task(run_id, session=session)
+    mapped.run_time_mapped_ti_count.cache_clear()  # type: ignore[attr-defined]