This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-7-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 54f23231be9a518ff6071868c19034b49e18f1b2
Author: Aleksandr Artemenkov <[email protected]>
AuthorDate: Tue Oct 3 10:33:08 2023 +0300

    Fixed rows count in the migration script (#34348)
    
    * Fixed row count for SQLAlchemy 1.4+
    
    * Updated newsfragments
    
    * Fixed typo
    
    * Added newline
    
    * Added test for `check_bad_references`
    
    (cherry picked from commit f349fda125c2251ac4129c2c28fbf6f7dbb69294)
---
 airflow/utils/db.py            |  2 +-
 newsfragments/34348.bugfix.rst |  1 +
 tests/utils/test_db.py         | 86 +++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 86 insertions(+), 3 deletions(-)

diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 34cd1c1bb4..80dd788688 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -1446,7 +1446,7 @@ def check_bad_references(session: Session) -> Iterable[str]:
 
         dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version, "dangling")
         if dangling_table_name in existing_table_names:
-            invalid_row_count = bad_rows_query.count()
+            invalid_row_count = get_query_count(bad_rows_query, session=session)
             if invalid_row_count:
                 yield _format_dangling_error(
                     source_table=source_table.name,
diff --git a/newsfragments/34348.bugfix.rst b/newsfragments/34348.bugfix.rst
new file mode 100644
index 0000000000..c9f27e42f2
--- /dev/null
+++ b/newsfragments/34348.bugfix.rst
@@ -0,0 +1 @@
+Fixed ``AttributeError: 'Select' object has no attribute 'count'`` during the ``airflow db migrate`` command
diff --git a/tests/utils/test_db.py b/tests/utils/test_db.py
index 915e4b5238..aa4ad1d89d 100644
--- a/tests/utils/test_db.py
+++ b/tests/utils/test_db.py
@@ -31,13 +31,15 @@ from alembic.config import Config
 from alembic.migration import MigrationContext
 from alembic.runtime.environment import EnvironmentContext
 from alembic.script import ScriptDirectory
-from sqlalchemy import MetaData
+from sqlalchemy import MetaData, Table
+from sqlalchemy.sql import Select
 
 from airflow.exceptions import AirflowException
 from airflow.models import Base as airflow_base
 from airflow.settings import engine
 from airflow.utils.db import (
     _get_alembic_config,
+    check_bad_references,
     check_migrations,
     compare_server_default,
     compare_type,
@@ -49,6 +51,7 @@ from airflow.utils.db import (
     resetdb,
     upgradedb,
 )
+from airflow.utils.session import NEW_SESSION
 
 
 class TestDb:
@@ -57,7 +60,7 @@ class TestDb:
 
         airflow.models.import_all_models()
         all_meta_data = MetaData()
-        for (table_name, table) in airflow_base.metadata.tables.items():
+        for table_name, table in airflow_base.metadata.tables.items():
             all_meta_data._add_table(table_name, table.schema, table)
 
         # create diff between database schema and SQLAlchemy model
@@ -251,3 +254,82 @@ class TestDb:
         import airflow
 
         assert config.config_file_name == os.path.join(os.path.dirname(airflow.__file__), "alembic.ini")
+
+    @mock.patch("airflow.utils.db._move_dangling_data_to_new_table")
+    @mock.patch("airflow.utils.db.get_query_count")
+    @mock.patch("airflow.utils.db._dangling_against_task_instance")
+    @mock.patch("airflow.utils.db._dangling_against_dag_run")
+    @mock.patch("airflow.utils.db.reflect_tables")
+    @mock.patch("airflow.utils.db.inspect")
+    def test_check_bad_references(
+        self,
+        mock_inspect: MagicMock,
+        mock_reflect_tables: MagicMock,
+        mock_dangling_against_dag_run: MagicMock,
+        mock_dangling_against_task_instance: MagicMock,
+        mock_get_query_count: MagicMock,
+        mock_move_dangling_data_to_new_table: MagicMock,
+    ):
+        from airflow.models.dagrun import DagRun
+        from airflow.models.renderedtifields import RenderedTaskInstanceFields
+        from airflow.models.taskfail import TaskFail
+        from airflow.models.taskinstance import TaskInstance
+        from airflow.models.taskreschedule import TaskReschedule
+        from airflow.models.xcom import XCom
+
+        mock_session = MagicMock(spec=NEW_SESSION)
+        mock_bind = MagicMock()
+        mock_session.get_bind.return_value = mock_bind
+        task_instance_table = MagicMock(spec=Table)
+        task_instance_table.name = TaskInstance.__tablename__
+        dag_run_table = MagicMock(spec=Table)
+        task_fail_table = MagicMock(spec=Table)
+        task_fail_table.name = TaskFail.__tablename__
+
+        mock_reflect_tables.return_value = MagicMock(
+            tables={
+                DagRun.__tablename__: dag_run_table,
+                TaskInstance.__tablename__: task_instance_table,
+                TaskFail.__tablename__: task_fail_table,
+            }
+        )
+
+        # Simulate that there is a moved `task_instance` table from the
+        # previous run, but no moved `task_fail` table
+        dangling_task_instance_table_name = f"_airflow_moved__2_2__dangling__{task_instance_table.name}"
+        dangling_task_fail_table_name = f"_airflow_moved__2_3__dangling__{task_fail_table.name}"
+        mock_get_table_names = MagicMock(
+            return_value=[
+                TaskInstance.__tablename__,
+                DagRun.__tablename__,
+                TaskFail.__tablename__,
+                dangling_task_instance_table_name,
+            ]
+        )
+        mock_inspect.return_value = MagicMock(
+            get_table_names=mock_get_table_names,
+        )
+        mock_select = MagicMock(spec=Select)
+        mock_dangling_against_dag_run.return_value = mock_select
+        mock_dangling_against_task_instance.return_value = mock_select
+        mock_get_query_count.return_value = 1
+
+        # Should return a single error related to the dangling `task_instance` table
+        errs = list(check_bad_references(session=mock_session))
+        assert len(errs) == 1
+        assert dangling_task_instance_table_name in errs[0]
+
+        mock_reflect_tables.assert_called_once_with(
+            [TaskInstance, TaskReschedule, RenderedTaskInstanceFields, TaskFail, XCom, DagRun, TaskInstance],
+            mock_session,
+        )
+        mock_inspect.assert_called_once_with(mock_bind)
+        mock_get_table_names.assert_called_once()
+        mock_dangling_against_dag_run.assert_called_once_with(
+            mock_session, task_instance_table, dag_run=dag_run_table
+        )
+        mock_get_query_count.assert_called_once_with(mock_select, session=mock_session)
+        mock_move_dangling_data_to_new_table.assert_called_once_with(
+            mock_session, task_fail_table, mock_select, dangling_task_fail_table_name
+        )
+        mock_session.rollback.assert_called_once()

Reply via email to