This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2500dcf20d2 Move FAB session table creation to FAB provider (#47969)
2500dcf20d2 is described below
commit 2500dcf20d2782d16da53ee857c0aab21bfdfbf2
Author: Jed Cunningham <[email protected]>
AuthorDate: Wed Mar 19 15:41:37 2025 -0600
Move FAB session table creation to FAB provider (#47969)
We need to create the `session` table in the provider db manager, not in
the core db utils.
Co-authored-by: vincbeck <[email protected]>
---
airflow/utils/db.py | 21 -----------------
.../providers/fab/auth_manager/models/db.py | 27 ++++++++++++++++++----
.../tests/unit/fab/auth_manager/models/test_db.py | 4 +++-
3 files changed, 25 insertions(+), 27 deletions(-)
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 70e6955c39b..7055a1c7601 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -729,34 +729,15 @@ def create_default_connections(session: Session = NEW_SESSION):
)
-def _get_flask_db(sql_database_uri):
- from flask import Flask
- from flask_sqlalchemy import SQLAlchemy
-
- from airflow.providers.fab.www.session import AirflowDatabaseSessionInterface
-
- flask_app = Flask(__name__)
- flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri
- flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
- db = SQLAlchemy(flask_app)
- AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
- return db
-
-
def _create_db_from_orm(session):
log.info("Creating Airflow database tables from the ORM")
from alembic import command
from airflow.models.base import Base
- def _create_flask_session_tbl(sql_database_uri):
- db = _get_flask_db(sql_database_uri)
- db.create_all()
-
with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
engine = session.get_bind().engine
Base.metadata.create_all(engine)
- _create_flask_session_tbl(engine.url)
# stamp the migration head
config = _get_alembic_config()
command.stamp(config, "head")
@@ -1254,8 +1235,6 @@ def drop_airflow_models(connection):
from airflow.models.base import Base
Base.metadata.drop_all(connection)
- db = _get_flask_db(connection.engine.url)
- db.drop_all()
# alembic adds significant import time, so we import it lazily
from alembic.migration import MigrationContext
diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py b/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py
index ce0efef55a1..5e2c5397745 100644
--- a/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py
+++ b/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py
@@ -31,6 +31,20 @@ _REVISION_HEADS_MAP: dict[str, str] = {
}
+def _get_flask_db(sql_database_uri):
+ from flask import Flask
+ from flask_sqlalchemy import SQLAlchemy
+
+ from airflow.providers.fab.www.session import AirflowDatabaseSessionInterface
+
+ flask_app = Flask(__name__)
+ flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri
+ flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
+ db = SQLAlchemy(flask_app)
+ AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
+ return db
+
+
class FABDBManager(BaseDBManager):
"""Manages FAB database."""
@@ -40,6 +54,10 @@ class FABDBManager(BaseDBManager):
alembic_file = (PACKAGE_DIR / "alembic.ini").as_posix()
supports_table_dropping = True
+ def _create_db_from_orm(self):
+ super()._create_db_from_orm()
+ _get_flask_db(settings.SQL_ALCHEMY_CONN).create_all()
+
def upgradedb(self, to_revision=None, from_revision=None, show_sql_only=False):
"""Upgrade the database."""
if from_revision and not show_sql_only:
@@ -68,11 +86,6 @@ class FABDBManager(BaseDBManager):
_offline_migration(command.upgrade, config, f"{from_revision}:{to_revision}")
return # only running sql; our job is done
- if not self.get_current_revision():
- # New DB; initialize and exit
- self.initdb()
- return
-
command.upgrade(config, revision=to_revision or "heads")
def downgrade(self, to_revision, from_revision=None, show_sql_only=False):
@@ -104,3 +117,7 @@ class FABDBManager(BaseDBManager):
else:
self.log.info("Applying FAB downgrade migrations.")
command.downgrade(config, revision=to_revision, sql=show_sql_only)
+
+ def drop_tables(self, connection):
+ super().drop_tables(connection)
+ _get_flask_db(settings.SQL_ALCHEMY_CONN).drop_all()
diff --git a/providers/fab/tests/unit/fab/auth_manager/models/test_db.py b/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
index f0920ebb151..50eaf9450e9 100644
--- a/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
+++ b/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
@@ -110,10 +110,12 @@ try:
@mock.patch("airflow.utils.db_manager.inspect")
@mock.patch.object(FABDBManager, "metadata")
- def test_drop_tables(self, mock_metadata, mock_inspect, session):
+ @mock.patch("airflow.providers.fab.auth_manager.models.db._get_flask_db")
+ def test_drop_tables(self, mock__get_flask_db, mock_metadata, mock_inspect, session):
manager = FABDBManager(session)
connection = mock.MagicMock()
manager.drop_tables(connection)
+ mock__get_flask_db.return_value.drop_all.assert_called_once_with()
mock_metadata.drop_all.assert_called_once_with(connection)
@pytest.mark.parametrize("skip_init", [True, False])