This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 6c4788e21dd49593cd61c4d84e6e09dc6d89bb0b
Author: Ash Berlin-Taylor <a...@apache.org>
AuthorDate: Fri May 6 17:02:27 2022 +0100

    Change approach to finding bad rows to LEFT OUTER JOIN. (#23528)
    
    Rather than sub-selects (two for count, or one for the CREATE TABLE).
    
    For a _large_ database (27m TaskInstances, 2m DagRuns) this takes the
    time from around 10 minutes per table (we have 3) down to around 3
    minutes per table. (All times on Postgres.)
    
    Before:
    
    ```sql
    CREATE TABLE _airflow_moved__2_3__dangling__rendered_task_instance_fields AS
    SELECT
      rendered_task_instance_fields.dag_id AS dag_id,
      rendered_task_instance_fields.task_id AS task_id,
      rendered_task_instance_fields.execution_date AS execution_date,
      rendered_task_instance_fields.rendered_fields AS rendered_fields,
      rendered_task_instance_fields.k8s_pod_yaml AS k8s_pod_yaml
    FROM
      rendered_task_instance_fields
    WHERE
      NOT (
        EXISTS (
          SELECT
            1
          FROM
            task_instance
            JOIN dag_run ON dag_run.dag_id = task_instance.dag_id
            AND dag_run.run_id = task_instance.run_id
          WHERE
            rendered_task_instance_fields.dag_id = task_instance.dag_id
            AND rendered_task_instance_fields.task_id = task_instance.task_id
            AND rendered_task_instance_fields.execution_date = 
dag_run.execution_date
        )
      )
    ```
    
    After:
    
    ```sql
    CREATE TABLE _airflow_moved__2_3__dangling__rendered_task_instance_fields AS
    SELECT
      rendered_task_instance_fields.dag_id AS dag_id,
      rendered_task_instance_fields.task_id AS task_id,
      rendered_task_instance_fields.execution_date AS execution_date,
      rendered_task_instance_fields.rendered_fields AS rendered_fields,
      rendered_task_instance_fields.k8s_pod_yaml AS k8s_pod_yaml
    FROM
      rendered_task_instance_fields
      LEFT OUTER JOIN dag_run ON rendered_task_instance_fields.dag_id = 
dag_run.dag_id
      AND rendered_task_instance_fields.execution_date = dag_run.execution_date
      LEFT OUTER JOIN task_instance ON dag_run.dag_id = task_instance.dag_id
      AND dag_run.run_id = task_instance.run_id
      AND rendered_task_instance_fields.task_id = task_instance.task_id
    WHERE
      task_instance.dag_id IS NULL
      OR dag_run.dag_id IS NULL
    ;
    ```
    
    (cherry picked from commit 22a9293ff8f48411d39074d9bc88af35abe9850f)
---
 airflow/utils/db.py | 73 +++++++++++++++++++++++++++++------------------------
 1 file changed, 40 insertions(+), 33 deletions(-)

diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 7325c0e243..730757694d 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -26,7 +26,7 @@ from dataclasses import dataclass
 from tempfile import gettempdir
 from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, 
Union
 
-from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select, 
table, text
+from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select, 
table, text, tuple_
 from sqlalchemy.orm.session import Session
 
 import airflow
@@ -1047,7 +1047,7 @@ def _create_table_as(
 
 
 def _move_dangling_data_to_new_table(
-    session, source_table: "Table", source_query: "Query", exists_subquery, 
target_table_name: str
+    session, source_table: "Table", source_query: "Query", target_table_name: 
str
 ):
 
     bind = session.get_bind()
@@ -1072,11 +1072,16 @@ def _move_dangling_data_to_new_table(
 
     if not first_moved_row:
         log.debug("no rows moved; dropping %s", target_table_name)
+        # no bad rows were found; drop moved rows table.
         target_table.drop(bind=session.get_bind(), checkfirst=True)
     else:
         log.debug("rows moved; purging from %s", source_table.name)
         if dialect_name == 'sqlite':
-            delete = source_table.delete().where(~exists_subquery.exists())
+            pk_cols = source_table.primary_key.columns
+
+            delete = source_table.delete().where(
+                
tuple_(*pk_cols).in_(session.query(*target_table.primary_key.columns).subquery())
+            )
         else:
             delete = source_table.delete().where(
                 and_(col == target_table.c[col.name] for col in 
source_table.primary_key.columns)
@@ -1088,7 +1093,7 @@ def _move_dangling_data_to_new_table(
     log.debug("exiting move function")
 
 
-def _dag_run_exists(session, source_table, dag_run):
+def _dangling_against_dag_run(session, source_table, dag_run):
     """
     Given a source table, we generate a subquery that will return 1 for every 
row that
     has a dagrun.
@@ -1097,11 +1102,14 @@ def _dag_run_exists(session, source_table, dag_run):
         source_table.c.dag_id == dag_run.c.dag_id,
         source_table.c.execution_date == dag_run.c.execution_date,
     )
-    exists_subquery = 
session.query(text('1')).select_from(dag_run).filter(source_to_dag_run_join_cond)
-    return exists_subquery
+    return (
+        session.query(*[c.label(c.name) for c in source_table.c])
+        .join(dag_run, source_to_dag_run_join_cond, isouter=True)
+        .filter(dag_run.c.dag_id.is_(None))
+    )
 
 
-def _task_instance_exists(session, source_table, dag_run, task_instance):
+def _dangling_against_task_instance(session, source_table, dag_run, 
task_instance):
     """
     Given a source table, we generate a subquery that will return 1 for every 
row that
     has a valid task instance (and associated dagrun).
@@ -1114,32 +1122,33 @@ def _task_instance_exists(session, source_table, 
dag_run, task_instance):
     """
     if 'run_id' not in task_instance.c:
         # db is < 2.2.0
-        where_clause = and_(
-            source_table.c.dag_id == task_instance.c.dag_id,
-            source_table.c.task_id == task_instance.c.task_id,
-            source_table.c.execution_date == task_instance.c.execution_date,
+        dr_join_cond = and_(
+            source_table.c.dag_id == dag_run.c.dag_id,
+            source_table.c.execution_date == dag_run.c.execution_date,
         )
-        ti_to_dr_join_cond = and_(
+        ti_join_cond = and_(
             dag_run.c.dag_id == task_instance.c.dag_id,
             dag_run.c.execution_date == task_instance.c.execution_date,
+            source_table.c.task_id == task_instance.c.task_id,
         )
     else:
         # db is 2.2.0 <= version < 2.3.0
-        where_clause = and_(
-            source_table.c.dag_id == task_instance.c.dag_id,
-            source_table.c.task_id == task_instance.c.task_id,
+        dr_join_cond = and_(
+            source_table.c.dag_id == dag_run.c.dag_id,
             source_table.c.execution_date == dag_run.c.execution_date,
         )
-        ti_to_dr_join_cond = and_(
+        ti_join_cond = and_(
             dag_run.c.dag_id == task_instance.c.dag_id,
             dag_run.c.run_id == task_instance.c.run_id,
+            source_table.c.task_id == task_instance.c.task_id,
         )
-    exists_subquery = (
-        session.query(text('1'))
-        .select_from(task_instance.join(dag_run, onclause=ti_to_dr_join_cond))
-        .filter(where_clause)
+
+    return (
+        session.query(*[c.label(c.name) for c in source_table.c])
+        .join(dag_run, dr_join_cond, isouter=True)
+        .join(task_instance, ti_join_cond, isouter=True)
+        .filter(or_(task_instance.c.dag_id.is_(None), 
dag_run.c.dag_id.is_(None)))
     )
-    return exists_subquery
 
 
 def _move_duplicate_data_to_new_table(
@@ -1207,23 +1216,23 @@ def check_bad_references(session: Session) -> 
Iterable[str]:
     @dataclass
     class BadReferenceConfig:
         """
-        :param exists_func: function that returns subquery which determines 
whether bad rows exist
+        :param bad_rows_func: function that returns subquery which determines 
whether bad rows exist
         :param join_tables: table objects referenced in subquery
         :param ref_table: information-only identifier for categorizing the 
missing ref
         """
 
-        exists_func: Callable
+        bad_rows_func: Callable
         join_tables: List[str]
         ref_table: str
 
     missing_dag_run_config = BadReferenceConfig(
-        exists_func=_dag_run_exists,
+        bad_rows_func=_dangling_against_dag_run,
         join_tables=['dag_run'],
         ref_table='dag_run',
     )
 
     missing_ti_config = BadReferenceConfig(
-        exists_func=_task_instance_exists,
+        bad_rows_func=_dangling_against_task_instance,
         join_tables=['dag_run', 'task_instance'],
         ref_table='task_instance',
     )
@@ -1238,7 +1247,8 @@ def check_bad_references(session: Session) -> 
Iterable[str]:
     metadata = reflect_tables([*[x[0] for x in models_list], DagRun, 
TaskInstance], session)
 
     if (
-        metadata.tables.get(DagRun.__tablename__) is None
+        not metadata.tables
+        or metadata.tables.get(DagRun.__tablename__) is None
         or metadata.tables.get(TaskInstance.__tablename__) is None
     ):
         # Key table doesn't exist -- likely empty DB.
@@ -1251,7 +1261,6 @@ def check_bad_references(session: Session) -> 
Iterable[str]:
         log.debug("checking model %s", model.__tablename__)
         # We can't use the model here since it may differ from the db state 
due to
         # this function is run prior to migration. Use the reflected table 
instead.
-        exists_func_kwargs = {x: metadata.tables[x] for x in 
bad_ref_cfg.join_tables}
         source_table = metadata.tables.get(model.__tablename__)  # type: ignore
         if source_table is None:
             continue
@@ -1260,13 +1269,12 @@ def check_bad_references(session: Session) -> 
Iterable[str]:
         if "run_id" in source_table.columns:
             continue
 
-        bad_rows_subquery = bad_ref_cfg.exists_func(session, source_table, 
**exists_func_kwargs)
-        select_list = [x.label(x.name) for x in source_table.c]
-        invalid_rows_query = 
session.query(*select_list).filter(~bad_rows_subquery.exists())
+        func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables}
+        bad_rows_query = bad_ref_cfg.bad_rows_func(session, source_table, 
**func_kwargs)
 
         dangling_table_name = 
_format_airflow_moved_table_name(source_table.name, change_version, 'dangling')
         if dangling_table_name in existing_table_names:
-            invalid_row_count = invalid_rows_query.count()
+            invalid_row_count = bad_rows_query.count()
             if invalid_row_count <= 0:
                 continue
             else:
@@ -1283,8 +1291,7 @@ def check_bad_references(session: Session) -> 
Iterable[str]:
         _move_dangling_data_to_new_table(
             session,
             source_table,
-            invalid_rows_query,
-            bad_rows_subquery,
+            bad_rows_query,
             dangling_table_name,
         )
 

Reply via email to