This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 72f43fcc83 Revert "Refactor _manage_executor_state by refreshing TIs in batch (#36418)" (#36500)
72f43fcc83 is described below

commit 72f43fcc838afc1c95b85dcb27af6519483ef64b
Author: Jarek Potiuk <ja...@potiuk.com>
AuthorDate: Sat Dec 30 19:12:21 2023 +0100

    Revert "Refactor _manage_executor_state by refreshing TIs in batch (#36418)" (#36500)
    
    This reverts commit 9d45db9e2cca2ad04db72f7e0712c478e5a8e1f1.
    
    t#
---
 airflow/jobs/backfill_job_runner.py | 33 ++++++---------------------------
 1 file changed, 6 insertions(+), 27 deletions(-)

diff --git a/airflow/jobs/backfill_job_runner.py b/airflow/jobs/backfill_job_runner.py
index 4a1643e4c4..92ade82ff9 100644
--- a/airflow/jobs/backfill_job_runner.py
+++ b/airflow/jobs/backfill_job_runner.py
@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, Sequence
 
 import attr
 import pendulum
-from sqlalchemy import select, tuple_, update
+from sqlalchemy import select, update
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm.session import make_transient
 from tabulate import tabulate
@@ -264,37 +264,16 @@ class BackfillJobRunner(BaseJobRunner, LoggingMixin):
         :return: An iterable of expanded TaskInstance per MappedTask
         """
         executor = self.job.executor
-        # list of tuples (dag_id, task_id, execution_date, map_index) of running tasks in executor
-        buffered_events = list(executor.get_event_buffer().items())
-        running_tis_ids = [
-            (key.dag_id, key.task_id, key.run_id, key.map_index)
-            for key, _ in buffered_events
-            if key in running
-        ]
-        # list of TaskInstance of running tasks in executor (refreshed from db in batch)
-        refreshed_running_tis = session.scalars(
-            select(TaskInstance).where(
-                tuple_(
-                    TaskInstance.dag_id,
-                    TaskInstance.task_id,
-                    TaskInstance.run_id,
-                    TaskInstance.map_index,
-                ).in_(running_tis_ids)
-            )
-        ).all()
-        # dict of refreshed TaskInstance by key to easily find them
-        refreshed_running_tis_dict = {
-            (ti.dag_id, ti.task_id, ti.run_id, ti.map_index): ti for ti in refreshed_running_tis
-        }
 
-        for key, value in buffered_events:
+        # TODO: query all instead of refresh from db
+        for key, value in list(executor.get_event_buffer().items()):
             state, info = value
-            ti_key = (key.dag_id, key.task_id, key.run_id, key.map_index)
-            if ti_key not in refreshed_running_tis_dict:
+            if key not in running:
                 self.log.warning("%s state %s not in running=%s", key, state, running.values())
                 continue
 
-            ti = refreshed_running_tis_dict[ti_key]
+            ti = running[key]
+            ti.refresh_from_db()
 
             self.log.debug("Executor state: %s task %s", state, ti)
 

Reply via email to