This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 9cf5f6f084 Refactor _manage_executor_state by refreshing TIs in batch (#36502) 9cf5f6f084 is described below commit 9cf5f6f08483ff141df51c07daa91a0aa34906ec Author: Hussein Awala <huss...@awala.fr> AuthorDate: Sat Dec 30 23:08:52 2023 +0100 Refactor _manage_executor_state by refreshing TIs in batch (#36502) Revert "Revert "Refactor _manage_executor_state by refreshing TIs in batch (#36418)" (#36500)" Handle Microsoft SQL Server --- airflow/jobs/backfill_job_runner.py | 42 +++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/airflow/jobs/backfill_job_runner.py b/airflow/jobs/backfill_job_runner.py index 92ade82ff9..af8ebd1c4e 100644 --- a/airflow/jobs/backfill_job_runner.py +++ b/airflow/jobs/backfill_job_runner.py @@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, Sequence import attr import pendulum -from sqlalchemy import select, update +from sqlalchemy import select, tuple_, update from sqlalchemy.exc import OperationalError from sqlalchemy.orm.session import make_transient from tabulate import tabulate @@ -264,16 +264,46 @@ class BackfillJobRunner(BaseJobRunner, LoggingMixin): :return: An iterable of expanded TaskInstance per MappedTask """ executor = self.job.executor + # list of tuples (dag_id, task_id, execution_date, map_index) of running tasks in executor + buffered_events = list(executor.get_event_buffer().items()) + if session.get_bind().dialect.name == "mssql": + # SQL Server doesn't support multiple column subqueries + # TODO: Remove this once we drop support for SQL Server (#35868) + need_refresh = True + running_dict = {(ti.dag_id, ti.task_id, ti.run_id, ti.map_index): ti for ti in running.values()} + else: + running_tis_ids = [ + (key.dag_id, key.task_id, key.run_id, key.map_index) + for key, _ in buffered_events + if key in running + ] + # list of TaskInstance of running tasks in executor (refreshed from db in batch) + refreshed_running_tis = session.scalars( + select(TaskInstance).where( + tuple_( + TaskInstance.dag_id, + TaskInstance.task_id, + TaskInstance.run_id, + TaskInstance.map_index, + ).in_(running_tis_ids) + ) + ).all() + # dict of refreshed TaskInstance by key to easily find them + running_dict = { + (ti.dag_id, ti.task_id, ti.run_id, ti.map_index): ti for ti in refreshed_running_tis + } + need_refresh = False - # TODO: query all instead of refresh from db - for key, value in list(executor.get_event_buffer().items()): + for key, value in buffered_events: state, info = value - if key not in running: + ti_key = (key.dag_id, key.task_id, key.run_id, key.map_index) + if ti_key not in running_dict: self.log.warning("%s state %s not in running=%s", key, state, running.values()) continue - ti = running[key] - ti.refresh_from_db() + ti = running_dict[ti_key] + if need_refresh: + ti.refresh_from_db(session=session) self.log.debug("Executor state: %s task %s", state, ti)