This is an automated email from the ASF dual-hosted git repository. bbovenzi pushed a commit to branch mapped-instance-actions in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 03676ff07035340e10e8e19ead1be7c9a84ae3b3 Author: Tzu-ping Chung <[email protected]> AuthorDate: Tue Apr 19 09:16:04 2022 +0800 Accept multiple map_index param from front end This allows setting multiple instances of the same task to SUCCESS or FAILED in one request. This is translated to multiple task specifier tuples (task_id, map_index) when passed to set_state(). Also made some drive-through improvements adding types and clean some formatting up. --- airflow/api/common/mark_tasks.py | 4 +- airflow/models/dag.py | 12 ++--- airflow/www/views.py | 105 ++++++++++++++++++++++----------------- 3 files changed, 68 insertions(+), 53 deletions(-) diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py index 594423305c..349b935e82 100644 --- a/airflow/api/common/mark_tasks.py +++ b/airflow/api/common/mark_tasks.py @@ -18,7 +18,7 @@ """Marks tasks APIs.""" from datetime import datetime -from typing import TYPE_CHECKING, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union +from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union from sqlalchemy import or_, tuple_ from sqlalchemy.orm import contains_eager @@ -78,7 +78,7 @@ def _create_dagruns( @provide_session def set_state( *, - tasks: Union[Iterable[Operator], Iterable[Tuple[Operator, int]]], + tasks: Union[Collection[Operator], Collection[Tuple[Operator, int]]], run_id: Optional[str] = None, execution_date: Optional[datetime] = None, upstream: bool = False, diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 755505b5d0..9c93bcef13 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1620,7 +1620,7 @@ class DAG(LoggingMixin): self, *, task_id: str, - map_index: Optional[int] = None, + map_indexes: Optional[Collection[int]] = None, execution_date: Optional[datetime] = None, run_id: Optional[str] = None, state: TaskInstanceState, @@ -1636,8 +1636,8 @@ class DAG(LoggingMixin): in failed or upstream_failed state. :param task_id: Task ID of the TaskInstance - :param map_index: The TaskInstance map_index, if None, would set state for all mapped - TaskInstances of the task + :param map_indexes: Only set TaskInstance if its map_index matches. + If None (default), all mapped TaskInstances of the task are set. :param execution_date: Execution date of the TaskInstance :param run_id: The run_id of the TaskInstance :param state: State to set the TaskInstance to @@ -1665,12 +1665,12 @@ class DAG(LoggingMixin): tasks_to_set_state: Union[List[Operator], List[Tuple[Operator, int]]] task_ids_to_exclude_from_clear: Union[Set[str], Set[Tuple[str, int]]] - if map_index is None: + if map_indexes is None: tasks_to_set_state = [task] task_ids_to_exclude_from_clear = {task_id} else: - tasks_to_set_state = [(task, map_index)] - task_ids_to_exclude_from_clear = {(task_id, map_index)} + tasks_to_set_state = [(task, map_index) for map_index in map_indexes] + task_ids_to_exclude_from_clear = {(task_id, map_index) for map_index in map_indexes} altered = set_state( tasks=tasks_to_set_state, diff --git a/airflow/www/views.py b/airflow/www/views.py index 437a60cca0..aac30f64ff 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -95,6 +95,7 @@ from airflow.api.common.mark_tasks import ( set_dag_run_state_to_failed, set_dag_run_state_to_queued, set_dag_run_state_to_success, + set_state, ) from airflow.compat.functools import cached_property from airflow.configuration import AIRFLOW_CONFIG, conf @@ -107,6 +108,7 @@ from airflow.models import DAG, Connection, DagModel, DagTag, Log, SlaMiss, Task from airflow.models.abstractoperator import AbstractOperator from airflow.models.dagcode import DagCode from airflow.models.dagrun import DagRun, DagRunType +from airflow.models.operator import Operator from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.providers_manager import ProvidersManager @@ -2284,28 +2286,28 @@ class Airflow(AirflowBaseView): def _mark_task_instance_state( self, - dag_id, - task_id, - origin, - dag_run_id, - upstream, - downstream, - future, - past, - state, - map_index=None, + *, + dag_id: str, + run_id: str, + task_id: str, + map_indexes: Optional[List[int]], + origin: str, + upstream: bool, + downstream: bool, + future: bool, + past: bool, + state: TaskInstanceState, ): dag = current_app.dag_bag.get_dag(dag_id) - latest_execution_date = dag.get_latest_execution_date() - if not latest_execution_date: - flash(f"Cannot mark tasks as {state}, seem that dag {dag_id} has never run", "error") + if not run_id: + flash(f"Cannot mark tasks as {state}, seem that DAG {dag_id} has never run", "error") return redirect(origin) altered = dag.set_task_instance_state( task_id=task_id, - map_index=map_index, - run_id=dag_run_id, + map_indexes=map_indexes, + run_id=run_id, state=state, upstream=upstream, downstream=downstream, @@ -2332,7 +2334,11 @@ class Airflow(AirflowBaseView): dag_run_id = args.get('dag_run_id') state = args.get('state') origin = args.get('origin') - map_index = args.get('map_index') + + if 'map_index' not in args: + map_indexes: Optional[List[int]] = None + else: + map_indexes = args.getlist('map_index', type=int) upstream = to_boolean(args.get('upstream')) downstream = to_boolean(args.get('downstream')) @@ -2365,9 +2371,10 @@ class Airflow(AirflowBaseView): msg = f"Cannot mark tasks as {state}, seem that dag {dag_id} has never run" return redirect_or_json(origin, msg, status='error') - from airflow.api.common.mark_tasks import set_state - - tasks = [(task, map_index)] if map_index else [task] + if map_indexes is None: + tasks: Union[List[Operator], List[Tuple[Operator, int]]] = [task] + else: + tasks = [(task, map_index) for map_index in map_indexes] to_be_altered = set_state( tasks=tasks, @@ -2408,26 +2415,30 @@ class Airflow(AirflowBaseView): args = request.form dag_id = args.get('dag_id') task_id = args.get('task_id') - origin = get_safe_url(args.get('origin')) - dag_run_id = args.get('dag_run_id') - map_index = args.get('map_index') + run_id = args.get('dag_run_id') + if 'map_index' not in args: + map_indexes: Optional[List[int]] = None + else: + map_indexes = args.getlist('map_index', type=int) + + origin = get_safe_url(args.get('origin')) upstream = to_boolean(args.get('upstream')) downstream = to_boolean(args.get('downstream')) future = to_boolean(args.get('future')) past = to_boolean(args.get('past')) return self._mark_task_instance_state( - dag_id, - task_id, - origin, - dag_run_id, - upstream, - downstream, - future, - past, - State.FAILED, - map_index=map_index, + dag_id=dag_id, + run_id=run_id, + task_id=task_id, + map_indexes=map_indexes, + origin=origin, + upstream=upstream, + downstream=downstream, + future=future, + past=past, + state=TaskInstanceState.FAILED, ) @expose('/success', methods=['POST']) @@ -2443,26 +2454,30 @@ class Airflow(AirflowBaseView): args = request.form dag_id = args.get('dag_id') task_id = args.get('task_id') - origin = get_safe_url(args.get('origin')) - dag_run_id = args.get('dag_run_id') - map_index = args.get('map_index') + run_id = args.get('dag_run_id') + + if 'map_index' not in args: + map_indexes: Optional[List[int]] = None + else: + map_indexes = args.getlist('map_index', type=int) + origin = get_safe_url(args.get('origin')) upstream = to_boolean(args.get('upstream')) downstream = to_boolean(args.get('downstream')) future = to_boolean(args.get('future')) past = to_boolean(args.get('past')) return self._mark_task_instance_state( - dag_id, - task_id, - origin, - dag_run_id, - upstream, - downstream, - future, - past, - State.SUCCESS, - map_index=map_index, + dag_id=dag_id, + run_id=run_id, + task_id=task_id, + map_indexes=map_indexes, + origin=origin, + upstream=upstream, + downstream=downstream, + future=future, + past=past, + state=TaskInstanceState.SUCCESS, ) @expose('/dags/<string:dag_id>')
