This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b99d1cd5d3 Respect max_active_runs for dataset-triggered dags (#26348)
b99d1cd5d3 is described below

commit b99d1cd5d32aea5721c512d6052b6b7b3e0dfefb
Author: Daniel Standish <15932138+dstand...@users.noreply.github.com>
AuthorDate: Wed Sep 14 05:28:30 2022 -0700

    Respect max_active_runs for dataset-triggered dags (#26348)
    
    Co-authored-by: Ash Berlin-Taylor <ash_git...@firemirror.com>
    Co-authored-by: Jed Cunningham <66968678+jedcunning...@users.noreply.github.com>
---
 airflow/models/dag.py    | 31 +++++++++++++++++++++------
 airflow/models/dagrun.py |  6 ++++++
 airflow/www/views.py     |  1 +
 tests/models/test_dag.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++-
 4 files changed, 85 insertions(+), 8 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index b5542414fa..60660ce0fb 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -54,6 +54,7 @@ import pendulum
 from dateutil.relativedelta import relativedelta
 from pendulum.tz.timezone import Timezone
 from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, and_, case, func, not_, or_
+from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import backref, joinedload, relationship
 from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
@@ -3066,6 +3067,7 @@ class DagModel(Base):
         "DagScheduleDatasetReference",
         cascade='all, delete, delete-orphan',
     )
+    schedule_datasets = association_proxy('schedule_dataset_references', 'dataset')
     task_outlet_dataset_references = relationship(
         "TaskOutletDatasetReference",
         cascade='all, delete, delete-orphan',
@@ -3235,7 +3237,7 @@ class DagModel(Base):
         transaction is committed it will be unlocked.
         """
         # these dag ids are triggered by datasets, and they are ready to go.
-        dataset_triggered_dag_info_list = {
+        dataset_triggered_dag_info = {
             x.dag_id: (x.first_queued_time, x.last_queued_time)
             for x in session.query(
                 DagScheduleDatasetReference.dag_id,
@@ -3247,12 +3249,27 @@ class DagModel(Base):
             .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
             .all()
         }
-        dataset_triggered_dag_ids = list(dataset_triggered_dag_info_list.keys())
-
-        # TODO[HA]: Bake this query, it is run _A lot_
-        # We limit so that _one_ scheduler doesn't try to do all the creation
-        # of dag runs
+        dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
+        if dataset_triggered_dag_ids:
+            exclusion_list = {
+                x.dag_id
+                for x in (
+                    session.query(DagModel.dag_id)
+                    .join(DagRun.dag_model)
+                    .filter(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING)))
+                    .filter(DagModel.dag_id.in_(dataset_triggered_dag_ids))
+                    .group_by(DagModel.dag_id)
+                    .having(func.count() >= func.max(DagModel.max_active_runs))
+                    .all()
+                )
+            }
+            if exclusion_list:
+                dataset_triggered_dag_ids -= exclusion_list
+                dataset_triggered_dag_info = {
+                    k: v for k, v in dataset_triggered_dag_info.items() if k not in exclusion_list
+                }
 
+        # We limit so that _one_ scheduler doesn't try to do all the creation of dag runs
         query = (
             session.query(cls)
             .filter(
@@ -3270,7 +3287,7 @@ class DagModel(Base):
 
         return (
             with_row_locks(query, of=cls, session=session, **skip_locked(session=session)),
-            dataset_triggered_dag_info_list,
+            dataset_triggered_dag_info,
         )
 
     def calculate_dagrun_date_fields(
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 4e8fe00dee..99d969ab19 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -163,6 +163,12 @@ class DagRun(Base, LoggingMixin):
     task_instances = relationship(
         TI, back_populates="dag_run", cascade='save-update, merge, delete, delete-orphan'
     )
+    dag_model = relationship(
+        "DagModel",
+        primaryjoin="foreign(DagRun.dag_id) == DagModel.dag_id",
+        uselist=False,
+        viewonly=True,
+    )
 
     DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint(
         'scheduler',
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 76dea597b8..21d8bdcaab 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1145,6 +1145,7 @@ class Airflow(AirflowBaseView):
         owner_links = session.query(DagOwnerAttributes).filter_by(dag_id=dag_id).all()
 
         attrs_to_avoid = [
+            "schedule_datasets",
             "schedule_dataset_references",
             "task_outlet_dataset_references",
             "NUM_DAGS_PER_DAGRUN_QUERY",
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 1954377ef5..62f057c376 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -65,7 +65,7 @@ from airflow.utils.types import DagRunType
 from airflow.utils.weight_rule import WeightRule
 from tests.models import DEFAULT_DATE
 from tests.test_utils.asserts import assert_queries_count
-from tests.test_utils.db import clear_db_dags, clear_db_runs
+from tests.test_utils.db import clear_db_dags, clear_db_datasets, clear_db_runs
 from tests.test_utils.mapping import expand_mapped_task
 from tests.test_utils.timetables import cron_timetable, delta_timetable
 
@@ -2102,6 +2102,17 @@ class TestDag:
 
 
 class TestDagModel:
+    def _clean(self):
+        clear_db_dags()
+        clear_db_datasets()
+        clear_db_runs()
+
+    def setup_method(self):
+        self._clean()
+
+    def teardown_method(self):
+        self._clean()
+
     def test_dags_needing_dagruns_not_too_early(self):
         dag = DAG(dag_id='far_future_dag', start_date=timezone.datetime(2038, 1, 1))
         EmptyOperator(task_id='dummy', dag=dag, owner='airflow')
@@ -2125,6 +2136,48 @@ class TestDagModel:
         session.rollback()
         session.close()
 
+    def test_dags_needing_dagruns_datasets(self, dag_maker, session):
+        dataset = Dataset(uri='hello')
+        with dag_maker(
+            session=session,
+            dag_id='my_dag',
+            max_active_runs=1,
+            schedule=[dataset],
+            start_date=pendulum.now().add(days=-2),
+        ) as dag:
+            EmptyOperator(task_id='dummy')
+
+        # there's no queue record yet, so no runs needed at this time.
+        query, _ = DagModel.dags_needing_dagruns(session)
+        dag_models = query.all()
+        assert dag_models == []
+
+        # add queue records so we'll need a run
+        dag_model = session.query(DagModel).filter(DagModel.dag_id == dag.dag_id).one()
+        dataset_model: DatasetModel = dag_model.schedule_datasets[0]
+        session.add(DatasetDagRunQueue(dataset_id=dataset_model.id, target_dag_id=dag_model.dag_id))
+        session.flush()
+        query, _ = DagModel.dags_needing_dagruns(session)
+        dag_models = query.all()
+        assert dag_models == [dag_model]
+
+        # create run so we don't need a run anymore (due to max active runs)
+        dag_maker.create_dagrun(
+            run_type=DagRunType.DATASET_TRIGGERED,
+            state=DagRunState.QUEUED,
+            execution_date=pendulum.now('UTC'),
+        )
+        query, _ = DagModel.dags_needing_dagruns(session)
+        dag_models = query.all()
+        assert dag_models == []
+
+        # increase max active runs and we should now need another run
+        dag_maker.dag_model.max_active_runs = 2
+        session.flush()
+        query, _ = DagModel.dags_needing_dagruns(session)
+        dag_models = query.all()
+        assert dag_models == [dag_model]
+
     def test_max_active_runs_not_none(self):
         dag = DAG(dag_id='test_max_active_runs_not_none', start_date=timezone.datetime(2038, 1, 1))
         EmptyOperator(task_id='dummy', dag=dag, owner='airflow')

Reply via email to