This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2950fd7685 [Models] [Postgres] Check if the dynamically-added index is in the table schema before adding (#32731)
2950fd7685 is described below

commit 2950fd768541fc902d8f7218e4243e8d83414c51
Author: Dan Hansen <[email protected]>
AuthorDate: Mon Aug 14 03:51:18 2023 -0700

    [Models] [Postgres] Check if the dynamically-added index is in the table schema before adding (#32731)
    
    * Check if the index is in the table schema before adding
    
    * add pre-condition assertion
    
    * static checks
    
    * Update test_models.py
    
    * integrate upstream auth manager changes
---
 airflow/auth/managers/fab/models/__init__.py |  8 +++-
 tests/auth/managers/fab/test_models.py       | 62 ++++++++++++++++++++++++++++
 2 files changed, 68 insertions(+), 2 deletions(-)

diff --git a/airflow/auth/managers/fab/models/__init__.py b/airflow/auth/managers/fab/models/__init__.py
index cb11e8fb06..0bc26adb7e 100644
--- a/airflow/auth/managers/fab/models/__init__.py
+++ b/airflow/auth/managers/fab/models/__init__.py
@@ -255,11 +255,15 @@ class RegisterUser(Model):
 def add_index_on_ab_user_username_postgres(table, conn, **kw):
     if conn.dialect.name != "postgresql":
         return
-    table.indexes.add(Index("idx_ab_user_username", func.lower(table.c.username), unique=True))
+    index_name = "idx_ab_user_username"
+    if not any(table_index.name == index_name for table_index in table.indexes):
+        table.indexes.add(Index(index_name, func.lower(table.c.username), unique=True))
 
 
 @event.listens_for(RegisterUser.__table__, "before_create")
 def add_index_on_ab_register_user_username_postgres(table, conn, **kw):
     if conn.dialect.name != "postgresql":
         return
-    table.indexes.add(Index("idx_ab_register_user_username", func.lower(table.c.username), unique=True))
+    index_name = "idx_ab_register_user_username"
+    if not any(table_index.name == index_name for table_index in table.indexes):
+        table.indexes.add(Index(index_name, func.lower(table.c.username), unique=True))
diff --git a/tests/auth/managers/fab/test_models.py b/tests/auth/managers/fab/test_models.py
new file mode 100644
index 0000000000..f2703e8d66
--- /dev/null
+++ b/tests/auth/managers/fab/test_models.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+from sqlalchemy import Column, MetaData, String, Table
+
+from airflow.auth.managers.fab.models import (
+    add_index_on_ab_register_user_username_postgres,
+    add_index_on_ab_user_username_postgres,
+)
+
+_mock_conn = mock.MagicMock()
+_mock_conn.dialect = mock.MagicMock()
+_mock_conn.dialect.name = "postgresql"
+
+
+def test_add_index_on_ab_user_username_postgres():
+    table = Table("test_table", MetaData(), Column("username", String))
+
+    assert len(table.indexes) == 0
+
+    add_index_on_ab_user_username_postgres(table, _mock_conn)
+
+    # Assert that the index was added to the table
+    assert len(table.indexes) == 1
+
+    add_index_on_ab_user_username_postgres(table, _mock_conn)
+
+    # Assert that index is not re-added when the schema is recreated
+    assert len(table.indexes) == 1
+
+
+def test_add_index_on_ab_register_user_username_postgres():
+    table = Table("test_table", MetaData(), Column("username", String))
+
+    assert len(table.indexes) == 0
+
+    add_index_on_ab_register_user_username_postgres(table, _mock_conn)
+
+    # Assert that the index was added to the table
+    assert len(table.indexes) == 1
+
+    add_index_on_ab_register_user_username_postgres(table, _mock_conn)
+
+    # Assert that index is not re-added when the schema is recreated
+    assert len(table.indexes) == 1

Reply via email to