This is an automated email from the ASF dual-hosted git repository. bbovenzi pushed a commit to branch mapped-instance-actions in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 5620df94692a8802202789b21a05104999f8494c Author: Ephraim Anierobi <[email protected]> AuthorDate: Tue Apr 12 16:56:11 2022 +0100 fixup! Allow marking/clearing mapped taskinstances from the UI --- airflow/www/views.py | 168 ++++++++++++++++++++++++--------------------------- 1 file changed, 80 insertions(+), 88 deletions(-) diff --git a/airflow/www/views.py b/airflow/www/views.py index ae0186e493..7be1289144 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -95,7 +95,6 @@ from airflow.api.common.mark_tasks import ( set_dag_run_state_to_failed, set_dag_run_state_to_queued, set_dag_run_state_to_success, - set_state, ) from airflow.compat.functools import cached_property from airflow.configuration import AIRFLOW_CONFIG, conf @@ -108,7 +107,6 @@ from airflow.models import DAG, Connection, DagModel, DagTag, Log, SlaMiss, Task from airflow.models.abstractoperator import AbstractOperator from airflow.models.dagcode import DagCode from airflow.models.dagrun import DagRun, DagRunType -from airflow.models.operator import Operator from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.providers_manager import ProvidersManager @@ -1960,11 +1958,11 @@ class Airflow(AirflowBaseView): def _clear_dag_tis( self, - dag: DAG, + dag, start_date, end_date, origin, - task_ids=None, + map_indexes=None, recursive=False, confirmed=False, only_failed=False, @@ -1973,7 +1971,7 @@ class Airflow(AirflowBaseView): count = dag.clear( start_date=start_date, end_date=end_date, - task_ids=task_ids, + map_indexes=map_indexes, include_subdags=recursive, include_parentdag=recursive, only_failed=only_failed, @@ -1986,7 +1984,7 @@ class Airflow(AirflowBaseView): tis = dag.clear( start_date=start_date, end_date=end_date, - task_ids=task_ids, + map_indexes=map_indexes, include_subdags=recursive, include_parentdag=recursive, only_failed=only_failed, @@ -1995,19 +1993,24 @@ class Airflow(AirflowBaseView): except AirflowException as ex: return redirect_or_json(origin, msg=str(ex), status="error") - assert isinstance(tis, collections.abc.Iterable) - details = [str(t) for t in tis] - - if not details: - return redirect_or_json(origin, "No task instances to clear", status="error") + if not tis: + msg = "No task instances to clear" + return redirect_or_json(origin, msg, status="error") elif request.headers.get('Accept') == 'application/json': + details = [str(t) for t in tis] + return htmlsafe_json_dumps(details, separators=(',', ':')) - return self.render_template( - 'airflow/confirm.html', - endpoint=None, - message="Task instances you are about to clear:", - details="\n".join(details), - ) + else: + details = "\n".join(str(t) for t in tis) + + response = self.render_template( + 'airflow/confirm.html', + endpoint=None, + message="Task instances you are about to clear:", + details=details, + ) + + return response @expose('/clear', methods=['POST']) @auth.has_access( @@ -2023,11 +2026,9 @@ class Airflow(AirflowBaseView): task_id = request.form.get('task_id') origin = get_safe_url(request.form.get('origin')) dag = current_app.dag_bag.get_dag(dag_id) - - if 'map_index' not in request.form: - map_indexes: Optional[List[int]] = None - else: - map_indexes = request.form.getlist('map_index', type=int) + map_indexes = request.form.get('map_indexes') + if map_indexes and not isinstance(map_indexes, list): + map_indexes = list(map_indexes) execution_date = request.form.get('execution_date') execution_date = timezone.parse(execution_date) @@ -2047,17 +2048,12 @@ class Airflow(AirflowBaseView): end_date = execution_date if not future else None start_date = execution_date if not past else None - if map_indexes is None: - task_ids: Union[List[str], List[Tuple[str, int]]] = [task_id] - else: - task_ids = [(task_id, map_index) for map_index in map_indexes] - return self._clear_dag_tis( dag, start_date, end_date, origin, - task_ids=task_ids, + map_indexes=map_indexes, recursive=recursive, confirmed=confirmed, only_failed=only_failed, @@ -2076,6 +2072,9 @@ class Airflow(AirflowBaseView): dag_id = request.form.get('dag_id') dag_run_id = request.form.get('dag_run_id') confirmed = request.form.get('confirmed') == "true" + map_indexes = request.form.get('map_indexes') + if map_indexes and not isinstance(map_indexes, list): + map_indexes = list(map_indexes) dag = current_app.dag_bag.get_dag(dag_id) dr = dag.get_dagrun(run_id=dag_run_id) @@ -2086,6 +2085,7 @@ class Airflow(AirflowBaseView): dag, start_date, end_date, + map_indexes=map_indexes, origin=None, recursive=True, confirmed=confirmed, @@ -2290,28 +2290,28 @@ class Airflow(AirflowBaseView): def _mark_task_instance_state( self, - *, - dag_id: str, - run_id: str, - task_id: str, - map_indexes: Optional[List[int]], - origin: str, - upstream: bool, - downstream: bool, - future: bool, - past: bool, - state: TaskInstanceState, + dag_id, + task_id, + map_indexes, + origin, + dag_run_id, + upstream, + downstream, + future, + past, + state, ): - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) + latest_execution_date = dag.get_latest_execution_date() - if not run_id: - flash(f"Cannot mark tasks as {state}, seem that DAG {dag_id} has never run", "error") + if not latest_execution_date: + flash(f"Cannot mark tasks as {state}, seem that dag {dag_id} has never run", "error") return redirect(origin) altered = dag.set_task_instance_state( task_id=task_id, - map_indexes=map_indexes, - run_id=run_id, + map_index=map_indexes, + run_id=dag_run_id, state=state, upstream=upstream, downstream=downstream, @@ -2338,11 +2338,9 @@ class Airflow(AirflowBaseView): dag_run_id = args.get('dag_run_id') state = args.get('state') origin = args.get('origin') - - if 'map_index' not in args: - map_indexes: Optional[List[int]] = None - else: - map_indexes = args.getlist('map_index', type=int) + map_indexes = args.get('map_indexes') + if map_indexes and not isinstance(map_indexes, list): + map_indexes = list(map_indexes) upstream = to_boolean(args.get('upstream')) downstream = to_boolean(args.get('downstream')) @@ -2375,13 +2373,11 @@ class Airflow(AirflowBaseView): msg = f"Cannot mark tasks as {state}, seem that dag {dag_id} has never run" return redirect_or_json(origin, msg, status='error') - if map_indexes is None: - tasks: Union[List[Operator], List[Tuple[Operator, int]]] = [task] - else: - tasks = [(task, map_index) for map_index in map_indexes] + from airflow.api.common.mark_tasks import set_state to_be_altered = set_state( - tasks=tasks, + tasks=[task], + map_indexes=map_indexes, run_id=dag_run_id, upstream=upstream, downstream=downstream, @@ -2419,30 +2415,28 @@ class Airflow(AirflowBaseView): args = request.form dag_id = args.get('dag_id') task_id = args.get('task_id') - run_id = args.get('dag_run_id') - - if 'map_index' not in args: - map_indexes: Optional[List[int]] = None - else: - map_indexes = args.getlist('map_index', type=int) - origin = get_safe_url(args.get('origin')) + dag_run_id = args.get('dag_run_id') + map_indexes = args.get('map_indexes') + if map_indexes and not isinstance(map_indexes, list): + map_indexes = list(map_indexes) + upstream = to_boolean(args.get('upstream')) downstream = to_boolean(args.get('downstream')) future = to_boolean(args.get('future')) past = to_boolean(args.get('past')) return self._mark_task_instance_state( - dag_id=dag_id, - run_id=run_id, - task_id=task_id, - map_indexes=map_indexes, - origin=origin, - upstream=upstream, - downstream=downstream, - future=future, - past=past, - state=TaskInstanceState.FAILED, + dag_id, + task_id, + map_indexes, + origin, + dag_run_id, + upstream, + downstream, + future, + past, + State.FAILED, ) @expose('/success', methods=['POST']) @@ -2458,30 +2452,28 @@ class Airflow(AirflowBaseView): args = request.form dag_id = args.get('dag_id') task_id = args.get('task_id') - run_id = args.get('dag_run_id') - - if 'map_index' not in args: - map_indexes: Optional[List[int]] = None - else: - map_indexes = args.getlist('map_index', type=int) - origin = get_safe_url(args.get('origin')) + dag_run_id = args.get('dag_run_id') + map_indexes = args.get('map_indexes') + if map_indexes and not isinstance(map_indexes, list): + map_indexes = list(map_indexes) + upstream = to_boolean(args.get('upstream')) downstream = to_boolean(args.get('downstream')) future = to_boolean(args.get('future')) past = to_boolean(args.get('past')) return self._mark_task_instance_state( - dag_id=dag_id, - run_id=run_id, - task_id=task_id, - map_indexes=map_indexes, - origin=origin, - upstream=upstream, - downstream=downstream, - future=future, - past=past, - state=TaskInstanceState.SUCCESS, + dag_id, + task_id, + map_indexes, + origin, + dag_run_id, + upstream, + downstream, + future, + past, + State.SUCCESS, ) @expose('/dags/<string:dag_id>')
