ashb commented on code in PR #66878:
URL: https://github.com/apache/airflow/pull/66878#discussion_r3460125917
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -899,6 +836,126 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
stats.gauge("scheduler.tasks.starving", num_starving_tasks_total)
stats.gauge("scheduler.tasks.executable", len(executable_tis))
+ return self._mark_task_instances_queued(executable_tis, session)
+
+ def _build_schedulable_tis_query(
+ self,
+ starved_pools: set[str],
+ starved_dags: set[str],
+ starved_tasks: set[tuple[str, str]],
+ starved_tasks_task_dagrun_concurrency: set[tuple[str, str, str]],
+ max_tis: int,
+ ) -> Select[tuple[TI]]:
+ """
+ Build a query that fetches SCHEDULED TIs eligible for execution this cycle.
+
+ Applies current starvation exclusions so that saturated pools, DAGs, or tasks
+ don't re-appear in the candidate set. Row-number windowing enforces
+ ``max_active_tasks`` per DagRun. The returned query is ready to be wrapped
+ with ``with_row_locks`` and executed by the caller; no session is required here.
+
+ This behaves the same as calling ``concurrency_map.load()`` followed by
+ ``_get_current_dr_task_concurrency``, with the difference that the subquery
+ object is built here and executed as part of the main query, so any state
+ changes between construction and execution are naturally ignored.
+ """
+ dr_task_concurrency_subquery = _get_current_dr_task_concurrency(states=EXECUTION_STATES)
+
+ query = (
+ select(TI)
+ .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
+ .join(TI.dag_run)
+ .where(DR.state == DagRunState.RUNNING)
+ .join(TI.dag_model)
+ .where(~DM.is_paused)
+ .where(TI.state == TaskInstanceState.SCHEDULED)
+ .where(DM.bundle_name.is_not(None))
+ .join(
+ dr_task_concurrency_subquery,
+ and_(
+ TI.dag_id == dr_task_concurrency_subquery.c.dag_id,
+ TI.run_id == dr_task_concurrency_subquery.c.run_id,
+ ),
+ isouter=True,
+ )
+ .where(func.coalesce(dr_task_concurrency_subquery.c.task_per_dr_count, 0) < DM.max_active_tasks)
+ .order_by(-TI.priority_weight, DR.logical_date, TI.map_index)
+ )
+
+ # Starvation filters should be applied before computing the row_num based on the
+ # max_active_tasks limit. That way, starved dags and tasks that shouldn't run,
+ # won't occupy a slot.
+ if starved_pools:
+ query = query.where(TI.pool.not_in(starved_pools))
+
+ if starved_dags:
+ query = query.where(TI.dag_id.not_in(starved_dags))
+
+ if starved_tasks:
+ query = query.where(tuple_(TI.dag_id, TI.task_id).not_in(starved_tasks))
+
+ if starved_tasks_task_dagrun_concurrency:
+ query = query.where(
+ tuple_(TI.dag_id, TI.run_id, TI.task_id).not_in(starved_tasks_task_dagrun_concurrency)
+ )
+
+ # Create a subquery with row numbers partitioned by dag_id and run_id.
+ # Different dags can have the same run_id but
+ # the dag_id combined with the run_id uniquely identify a run.
+ ranked_query = (
+ query.add_columns(
+ func.row_number()
+ .over(
+ partition_by=[TI.dag_id, TI.run_id],
+ order_by=[-TI.priority_weight, DR.logical_date, TI.map_index],
+ )
+ .label("row_num"),
+ DM.max_active_tasks.label("dr_max_active_tasks"),
+ # Create columns for the order_by checks here for sqlite.
+ TI.priority_weight.label("priority_weight_for_ordering"),
+ DR.logical_date.label("logical_date_for_ordering"),
+ TI.map_index.label("map_index_for_ordering"),
+ )
+ ).subquery()
+
+ # Select only rows where row_number <= max_active_tasks.
+ return (
+ select(TI)
+ .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
+ .select_from(ranked_query)
+ .join(
+ TI,
+ (TI.dag_id == ranked_query.c.dag_id)
+ & (TI.task_id == ranked_query.c.task_id)
+ & (TI.run_id == ranked_query.c.run_id)
+ & (TI.map_index == ranked_query.c.map_index),
+ )
+ .where(ranked_query.c.row_num <= ranked_query.c.dr_max_active_tasks)
+ # Add the order_by columns from the ranked query for sqlite.
+ .order_by(
+ -ranked_query.c.priority_weight_for_ordering,
+ ranked_query.c.logical_date_for_ordering,
+ ranked_query.c.map_index_for_ordering,
+ )
+ .options(selectinload(TI.dag_model))
+ .limit(max_tis)
+ )
+
+ def _mark_task_instances_queued(self, executable_tis: list[TI], session: Session) -> list[TI]:
Review Comment:
```suggestion
def _mark_task_instances_queued(self, executable_tis: list[TI], *, session: Session) -> list[TI]:
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]