This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9cf5f6f084 Refactor _manage_executor_state by refreshing TIs in batch 
(#36502)
9cf5f6f084 is described below

commit 9cf5f6f08483ff141df51c07daa91a0aa34906ec
Author: Hussein Awala <huss...@awala.fr>
AuthorDate: Sat Dec 30 23:08:52 2023 +0100

    Refactor _manage_executor_state by refreshing TIs in batch (#36502)
    
    Revert "Revert "Refactor _manage_executor_state by refreshing TIs in batch (#36418)"
    (#36500)"
    
    Handle Microsoft SQL Server
---
 airflow/jobs/backfill_job_runner.py | 42 +++++++++++++++++++++++++++++++------
 1 file changed, 36 insertions(+), 6 deletions(-)

diff --git a/airflow/jobs/backfill_job_runner.py 
b/airflow/jobs/backfill_job_runner.py
index 92ade82ff9..af8ebd1c4e 100644
--- a/airflow/jobs/backfill_job_runner.py
+++ b/airflow/jobs/backfill_job_runner.py
@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Iterator, 
Mapping, Sequence
 
 import attr
 import pendulum
-from sqlalchemy import select, update
+from sqlalchemy import select, tuple_, update
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm.session import make_transient
 from tabulate import tabulate
@@ -264,16 +264,46 @@ class BackfillJobRunner(BaseJobRunner, LoggingMixin):
         :return: An iterable of expanded TaskInstance per MappedTask
         """
         executor = self.job.executor
+        # list of tuples (dag_id, task_id, execution_date, map_index) of 
running tasks in executor
+        buffered_events = list(executor.get_event_buffer().items())
+        if session.get_bind().dialect.name == "mssql":
+            # SQL Server doesn't support multiple column subqueries
+            # TODO: Remove this once we drop support for SQL Server (#35868)
+            need_refresh = True
+            running_dict = {(ti.dag_id, ti.task_id, ti.run_id, ti.map_index): 
ti for ti in running.values()}
+        else:
+            running_tis_ids = [
+                (key.dag_id, key.task_id, key.run_id, key.map_index)
+                for key, _ in buffered_events
+                if key in running
+            ]
+            # list of TaskInstance of running tasks in executor (refreshed 
from db in batch)
+            refreshed_running_tis = session.scalars(
+                select(TaskInstance).where(
+                    tuple_(
+                        TaskInstance.dag_id,
+                        TaskInstance.task_id,
+                        TaskInstance.run_id,
+                        TaskInstance.map_index,
+                    ).in_(running_tis_ids)
+                )
+            ).all()
+            # dict of refreshed TaskInstance by key to easily find them
+            running_dict = {
+                (ti.dag_id, ti.task_id, ti.run_id, ti.map_index): ti for ti in 
refreshed_running_tis
+            }
+            need_refresh = False
 
-        # TODO: query all instead of refresh from db
-        for key, value in list(executor.get_event_buffer().items()):
+        for key, value in buffered_events:
             state, info = value
-            if key not in running:
+            ti_key = (key.dag_id, key.task_id, key.run_id, key.map_index)
+            if ti_key not in running_dict:
                 self.log.warning("%s state %s not in running=%s", key, state, 
running.values())
                 continue
 
-            ti = running[key]
-            ti.refresh_from_db()
+            ti = running_dict[ti_key]
+            if need_refresh:
+                ti.refresh_from_db(session=session)
 
             self.log.debug("Executor state: %s task %s", state, ti)
 

Reply via email to the mailing list.