This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e23d8cbcd95 Remove findings from positional session check in Core
API+Callback (#67770)
e23d8cbcd95 is described below
commit e23d8cbcd959149c9da817eb59c54ed7497cdbba
Author: Jens Scheffler <[email protected]>
AuthorDate: Sat May 30 14:03:11 2026 +0200
Remove findings from positional session check in Core API+Callback (#67770)
* Fix positional session usages in airflow-core api that were listed as known exceptions
* Fix positional session usages in airflow-core callback that were listed as known exceptions
* Clean up the pytests as well
---
airflow-core/src/airflow/api/common/delete_dag.py | 2 +-
airflow-core/src/airflow/api/common/mark_tasks.py | 4 +-
.../airflow/callbacks/database_callback_sink.py | 2 +-
.../unit/api_fastapi/common/test_exceptions.py | 12 ++--
.../core_api/routes/public/test_assets.py | 70 +++++++++++-----------
.../core_api/routes/public/test_connections.py | 4 +-
.../core_api/routes/public/test_dag_run.py | 2 +-
.../core_api/routes/public/test_dag_tags.py | 2 +-
.../core_api/routes/public/test_dag_warning.py | 5 +-
.../core_api/routes/public/test_event_logs.py | 5 +-
.../core_api/routes/public/test_import_error.py | 10 +++-
.../api_fastapi/core_api/routes/public/test_job.py | 19 +++---
.../core_api/routes/public/test_monitor.py | 5 +-
.../core_api/routes/public/test_pools.py | 7 ++-
.../core_api/routes/public/test_variables.py | 7 ++-
.../core_api/routes/public/test_xcom.py | 17 ++++--
.../core_api/routes/ui/test_calendar.py | 7 ++-
.../api_fastapi/core_api/routes/ui/test_dags.py | 4 +-
.../api_fastapi/core_api/routes/ui/test_gantt.py | 5 +-
.../api_fastapi/core_api/routes/ui/test_grid.py | 5 +-
.../ci/prek/known_provide_session_positional.txt | 20 -------
21 files changed, 112 insertions(+), 102 deletions(-)
diff --git a/airflow-core/src/airflow/api/common/delete_dag.py
b/airflow-core/src/airflow/api/common/delete_dag.py
index 91cd387e3ac..0e9979762d8 100644
--- a/airflow-core/src/airflow/api/common/delete_dag.py
+++ b/airflow-core/src/airflow/api/common/delete_dag.py
@@ -41,7 +41,7 @@ log = logging.getLogger(__name__)
@provide_session
-def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session
= NEW_SESSION) -> int:
+def delete_dag(dag_id: str, keep_records_in_log: bool = True, *, session:
Session = NEW_SESSION) -> int:
"""
Delete a Dag by a dag_id.
diff --git a/airflow-core/src/airflow/api/common/mark_tasks.py
b/airflow-core/src/airflow/api/common/mark_tasks.py
index 62ce2600aab..915af5952a3 100644
--- a/airflow-core/src/airflow/api/common/mark_tasks.py
+++ b/airflow-core/src/airflow/api/common/mark_tasks.py
@@ -144,7 +144,9 @@ def find_task_relatives(
@provide_session
-def get_run_ids(dag: SerializedDAG, run_id: str, future: bool, past: bool,
session: SASession = NEW_SESSION):
+def get_run_ids(
+ dag: SerializedDAG, run_id: str, future: bool, past: bool, *, session:
SASession = NEW_SESSION
+):
"""Return Dag executions' run_ids."""
current_logical_date = session.scalar(
select(DagRun.logical_date).where(DagRun.dag_id == dag.dag_id,
DagRun.run_id == run_id)
diff --git a/airflow-core/src/airflow/callbacks/database_callback_sink.py
b/airflow-core/src/airflow/callbacks/database_callback_sink.py
index e7e27cf86b5..d29ec81ead8 100644
--- a/airflow-core/src/airflow/callbacks/database_callback_sink.py
+++ b/airflow-core/src/airflow/callbacks/database_callback_sink.py
@@ -33,7 +33,7 @@ class DatabaseCallbackSink(BaseCallbackSink):
"""Sends callbacks to database."""
@provide_session
- def send(self, callback: CallbackRequest, session: Session = NEW_SESSION)
-> None:
+ def send(self, callback: CallbackRequest, *, session: Session =
NEW_SESSION) -> None:
"""Send callback for execution."""
db_callback = DbCallbackRequest(callback=callback, priority_weight=1)
session.add(db_callback)
diff --git a/airflow-core/tests/unit/api_fastapi/common/test_exceptions.py
b/airflow-core/tests/unit/api_fastapi/common/test_exceptions.py
index c3958f0297a..fb97f0ac323 100644
--- a/airflow-core/tests/unit/api_fastapi/common/test_exceptions.py
+++ b/airflow-core/tests/unit/api_fastapi/common/test_exceptions.py
@@ -154,9 +154,10 @@ class TestUniqueConstraintErrorHandler:
def test_handle_single_column_unique_constraint_error_without_stacktrace(
self,
mock_get_random_string,
- session,
table,
expected_exception,
+ *,
+ session: Session,
) -> None:
# Take Pool and Variable tables as test cases
# Note: SQLA2 uses a more optimized bulk insert strategy when multiple
objects are added to the
@@ -246,9 +247,10 @@ class TestUniqueConstraintErrorHandler:
def test_handle_single_column_unique_constraint_error_with_stacktrace(
self,
mock_get_random_string,
- session,
table,
expected_exception,
+ *,
+ session: Session,
) -> None:
# Take Pool and Variable tables as test cases
# Note: SQLA2 uses a more optimized bulk insert strategy when multiple
objects are added to the
@@ -279,7 +281,8 @@ class TestUniqueConstraintErrorHandler:
def
test_handle_multiple_columns_unique_constraint_error_without_stacktrace(
self,
mock_get_random_string,
- session,
+ *,
+ session: Session,
) -> None:
expected_exception = HTTPException(
status_code=status.HTTP_409_CONFLICT,
@@ -351,9 +354,10 @@ class TestUniqueConstraintErrorHandler:
def test_handle_multiple_columns_unique_constraint_error_with_stacktrace(
self,
mock_get_random_string,
- session,
table,
expected_exception,
+ *,
+ session: Session,
) -> None:
if table == "DagRun":
session.add(
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py
index 298949fce16..1234f25c0d0 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py
@@ -261,45 +261,45 @@ class TestAssets:
clear_db_logs()
@provide_session
- def create_assets(self, session, num: int = 2) -> list[AssetModel]:
+ def create_assets(self, num: int = 2, *, session) -> list[AssetModel]:
return _create_assets(session=session, num=num)
@provide_session
- def create_assets_with_watchers(self, session, num: int = 2) ->
list[AssetModel]:
+ def create_assets_with_watchers(self, num: int = 2, *, session) ->
list[AssetModel]:
return _create_assets_with_watchers(session=session, num=num)
@provide_session
- def create_assets_with_sensitive_extra(self, session, num: int = 2):
+ def create_assets_with_sensitive_extra(self, num: int = 2, *, session):
_create_assets_with_sensitive_extra(session=session, num=num)
@provide_session
- def create_provided_asset(self, session, asset: AssetModel):
+ def create_provided_asset(self, asset: AssetModel, *, session):
_create_provided_asset(session=session, asset=asset)
@provide_session
- def create_assets_events(self, session, num: int = 2, varying_timestamps:
bool = False):
+ def create_assets_events(self, num: int = 2, varying_timestamps: bool =
False, *, session):
_create_assets_events(session=session, num=num,
varying_timestamps=varying_timestamps)
@provide_session
- def create_assets_events_with_sensitive_extra(self, session, num: int = 2):
+ def create_assets_events_with_sensitive_extra(self, num: int = 2, *,
session):
_create_assets_events_with_sensitive_extra(session=session, num=num)
@provide_session
- def create_provided_asset_event(self, session, asset_event: AssetEvent):
+ def create_provided_asset_event(self, asset_event: AssetEvent, *, session):
_create_provided_asset_event(session=session, asset_event=asset_event)
@provide_session
- def create_dag_run(self, session, num: int = 2):
+ def create_dag_run(self, num: int = 2, *, session):
_create_dag_run(num=num, session=session)
@provide_session
- def create_asset_dag_run(self, session, num: int = 2):
+ def create_asset_dag_run(self, num: int = 2, *, session):
_create_asset_dag_run(num=num, session=session)
class TestGetAssets(TestAssets):
def test_should_respond_200(self, test_client, session):
- assets1, asset2 = self.create_assets(session)
+ assets1, asset2 = self.create_assets(session=session)
session.add(AssetModel("inactive", "inactive"))
session.commit()
@@ -351,7 +351,7 @@ class TestGetAssets(TestAssets):
def test_should_respond_200_with_watchers(self, test_client, session):
"""Test that assets with watchers return the watcher information in
the API response."""
- asset1, asset2 = self.create_assets_with_watchers(session, num=2)
+ asset1, asset2 = self.create_assets_with_watchers(session=session,
num=2)
response = test_client.get("/assets")
assert response.status_code == 200
@@ -407,7 +407,7 @@ class TestGetAssets(TestAssets):
}
def test_should_show_inactive(self, test_client, session):
- asset1, asset2 = self.create_assets(session)
+ asset1, asset2 = self.create_assets(session=session)
session.add(
asset3 := AssetModel(
name="simple3",
@@ -527,7 +527,7 @@ class TestGetAssets(TestAssets):
],
)
@provide_session
- def test_filter_assets_by_name_pattern_works(self, test_client, params,
expected_assets, session):
+ def test_filter_assets_by_name_pattern_works(self, test_client, params,
expected_assets, *, session):
asset1 = AssetModel("s3-folder-key", "s3://folder/key")
asset2 = AssetModel("gcp-bucket-key", "gcp://bucket/key")
asset3 = AssetModel("some-asset-key", "somescheme://asset/key")
@@ -576,7 +576,7 @@ class TestGetAssets(TestAssets):
],
)
@provide_session
- def test_filter_assets_by_uri_pattern_works(self, test_client, params,
expected_assets, session):
+ def test_filter_assets_by_uri_pattern_works(self, test_client, params,
expected_assets, *, session):
asset1 = AssetModel("s3://folder/key")
asset2 = AssetModel("gcp://bucket/key")
asset3 = AssetModel("somescheme://asset/key")
@@ -594,7 +594,7 @@ class TestGetAssets(TestAssets):
@pytest.mark.parametrize(("dag_ids", "expected_num"), [("dag1,dag2", 2),
("dag3", 1), ("dag2,dag3", 2)])
@provide_session
def test_filter_assets_by_dag_ids_works(
- self, test_client, dag_ids, expected_num, testing_dag_bundle, session
+ self, test_client, dag_ids, expected_num, testing_dag_bundle, *,
session
):
session.execute(delete(DagModel))
session.commit()
@@ -633,7 +633,7 @@ class TestGetAssets(TestAssets):
)
@provide_session
def test_filter_assets_by_dag_ids_and_uri_pattern_works(
- self, test_client, dag_ids, uri_pattern, expected_num,
testing_dag_bundle, session
+ self, test_client, dag_ids, uri_pattern, expected_num,
testing_dag_bundle, *, session
):
session.execute(delete(DagModel))
session.commit()
@@ -719,12 +719,12 @@ class TestAssetAliases:
_create_asset_aliases(num=num, session=session)
@provide_session
- def create_provided_asset_alias(self, asset_alias: AssetAliasModel,
session):
+ def create_provided_asset_alias(self, asset_alias: AssetAliasModel, *,
session):
_create_provided_asset_alias(session=session, asset_alias=asset_alias)
class TestGetAssetAliases(TestAssetAliases):
- def test_should_respond_200(self, test_client, session):
+ def test_should_respond_200(self, test_client, *, session):
self.create_asset_aliases()
asset_aliases = session.scalars(select(AssetAliasModel)).all()
assert len(asset_aliases) == 2
@@ -761,7 +761,9 @@ class TestGetAssetAliases(TestAssetAliases):
],
)
@provide_session
- def test_filter_assets_by_name_pattern_works(self, test_client, params,
expected_asset_aliases, session):
+ def test_filter_assets_by_name_pattern_works(
+ self, test_client, params, expected_asset_aliases, *, session
+ ):
asset_alias1 = AssetAliasModel(name="foo1")
asset_alias2 = AssetAliasModel(name="bar12")
asset_alias3 = AssetAliasModel(name="bar2")
@@ -810,10 +812,10 @@ class
TestGetAssetAliasesEndpointPagination(TestAssetAliases):
class TestGetAssetEvents(TestAssets):
def test_should_respond_200(self, test_client, session):
- asset1, asset2 = self.create_assets(session)
- self.create_assets_events(session)
- self.create_dag_run(session)
- self.create_asset_dag_run(session)
+ asset1, asset2 = self.create_assets(session=session)
+ self.create_assets_events(session=session)
+ self.create_dag_run(session=session)
+ self.create_asset_dag_run(session=session)
assets = session.scalars(select(AssetEvent)).all()
session.commit()
assert len(assets) == 2
@@ -910,7 +912,7 @@ class TestGetAssetEvents(TestAssets):
],
)
@provide_session
- def test_filtering(self, test_client, params, total_entries, session):
+ def test_filtering(self, test_client, params, total_entries, *, session):
self.create_assets()
self.create_assets_events()
self.create_dag_run()
@@ -1076,7 +1078,7 @@ class TestGetAssetEvents(TestAssets):
class TestGetAssetEndpoint(TestAssets):
@provide_session
- def test_should_respond_200(self, test_client, session):
+ def test_should_respond_200(self, test_client, *, session):
self.create_assets(num=1)
assert session.scalars(select(func.count(AssetModel.id))).one() == 1
tz_datetime_format = from_datetime_to_zulu_without_ms(DEFAULT_DATE)
@@ -1100,9 +1102,9 @@ class TestGetAssetEndpoint(TestAssets):
}
@provide_session
- def test_should_respond_200_with_watchers(self, test_client, session):
+ def test_should_respond_200_with_watchers(self, test_client, *, session):
"""Test that single asset endpoint returns watcher information."""
- assets = self.create_assets_with_watchers(session, num=1)
+ assets = self.create_assets_with_watchers(num=1, session=session)
asset = assets[0]
response = test_client.get(f"/assets/{asset.id}")
@@ -1171,7 +1173,7 @@ class TestGetAssetEndpoint(TestAssets):
class TestGetAssetAliasEndpoint(TestAssetAliases):
@provide_session
- def test_should_respond_200(self, test_client, session):
+ def test_should_respond_200(self, test_client, *, session):
self.create_asset_aliases(num=1)
assert session.scalars(select(func.count(AssetAliasModel.id))).one()
== 1
with assert_queries_count(6):
@@ -1296,7 +1298,7 @@ class
TestDeleteDagDatasetQueuedEvents(TestQueuedEventEndpoint):
class TestPostAssetEvents(TestAssets):
@pytest.mark.usefixtures("time_freezer")
def test_should_respond_200(self, test_client, session):
- (asset,) = self.create_assets(session, num=1)
+ (asset,) = self.create_assets(num=1, session=session)
event_payload = {"asset_id": asset.id, "extra": {"foo": "bar"}}
response = test_client.post("/assets/events", json=event_payload)
assert response.status_code == 200
@@ -1326,7 +1328,7 @@ class TestPostAssetEvents(TestAssets):
assert response.status_code == 403
def test_invalid_attr_not_allowed(self, test_client, session):
- self.create_assets(session)
+ self.create_assets(session=session)
event_invalid_payload = {"asset_uri": "s3://bucket/key/1", "extra":
{"foo": "bar"}, "fake": {}}
response = test_client.post("/assets/events",
json=event_invalid_payload)
@@ -1335,7 +1337,7 @@ class TestPostAssetEvents(TestAssets):
@pytest.mark.usefixtures("time_freezer")
@pytest.mark.enable_redact
def test_should_mask_sensitive_extra(self, test_client, session):
- (asset,) = self.create_assets(session, num=1)
+ (asset,) = self.create_assets(num=1, session=session)
event_payload = {"asset_id": asset.id, "extra": {"password": "bar"}}
response = test_client.post("/assets/events", json=event_payload)
assert response.status_code == 200
@@ -1357,7 +1359,7 @@ class TestPostAssetEvents(TestAssets):
def test_should_update_asset_endpoint(self, test_client, session):
"""Test for a single Asset."""
- (asset,) = self.create_assets(session, num=1)
+ (asset,) = self.create_assets(num=1, session=session)
event_payload = {"asset_id": asset.id, "extra": {"foo": "bar"}}
asset_event_response = test_client.post("/assets/events",
json=event_payload)
asset_response = test_client.get(f"/assets/{asset.id}")
@@ -1369,7 +1371,7 @@ class TestPostAssetEvents(TestAssets):
def test_should_update_assets_endpoint(self, test_client, session):
"""Test for multiple Assets."""
- asset1, asset2 = self.create_assets(session, num=2)
+ asset1, asset2 = self.create_assets(num=2, session=session)
# Now, only make a POST to the /assets/events endpoint for one of the
Assets
for _ in range(2):
@@ -1423,7 +1425,7 @@ class TestPostAssetEventsTeamResolution(TestAssets):
],
)
def test_team_resolution(self, test_client, session, multi_team,
expected_teams):
- (asset,) = self.create_assets(session, num=1)
+ (asset,) = self.create_assets(num=1, session=session)
mock_auth_mgr = mock.MagicMock()
mock_auth_mgr.get_authorized_teams.return_value = {"team_a", "team_b"}
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
index 668931bb81b..52343da10b4 100644
---
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
+++
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
@@ -63,7 +63,7 @@ TEST_CONN_TYPE_3 = "test_type_3"
@provide_session
-def _create_connection(team_name: str | None = None, session: Session =
NEW_SESSION) -> None:
+def _create_connection(team_name: str | None = None, *, session: Session =
NEW_SESSION) -> None:
connection_model = Connection(
conn_id=TEST_CONN_ID,
conn_type=TEST_CONN_TYPE,
@@ -77,7 +77,7 @@ def _create_connection(team_name: str | None = None, session:
Session = NEW_SESS
@provide_session
-def _create_connections(session: Session = NEW_SESSION) -> None:
+def _create_connections(*, session: Session = NEW_SESSION) -> None:
_create_connection(session=session)
connection_model_2 = Connection(
conn_id=TEST_CONN_ID_2,
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py
index 4ffcc12bc76..318544014e3 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py
@@ -131,7 +131,7 @@ DAG_RUNS_LIST = [DAG1_RUN1_ID, DAG1_RUN2_ID, DAG2_RUN1_ID,
DAG2_RUN2_ID]
@pytest.fixture(autouse=True)
@provide_session
-def setup(request, dag_maker, session=None):
+def setup(request, dag_maker, *, session=None):
clear_db_connections()
clear_db_runs()
clear_db_dags()
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_tags.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_tags.py
index c0302532ab6..82f3d11ed4f 100644
---
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_tags.py
+++
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_tags.py
@@ -105,7 +105,7 @@ class TestDagEndpoint:
@pytest.fixture(autouse=True)
@provide_session
- def setup(self, dag_maker, session=None) -> None:
+ def setup(self, dag_maker, *, session=None) -> None:
self._clear_db()
with dag_maker(
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_warning.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_warning.py
index 68fb9473d4d..91efee48df7 100644
---
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_warning.py
+++
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_warning.py
@@ -18,10 +18,11 @@
from __future__ import annotations
import pytest
+from sqlalchemy.orm import Session
from airflow.models.dag import DagModel
from airflow.models.dagwarning import DagWarning
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from tests_common.test_utils.asserts import assert_queries_count
from tests_common.test_utils.db import clear_db_dag_warnings, clear_db_dags
@@ -48,7 +49,7 @@ expected_display_names = {
@pytest.fixture(autouse=True)
@provide_session
-def setup(dag_maker, testing_dag_bundle, session=None) -> None:
+def setup(dag_maker, testing_dag_bundle, *, session: Session = NEW_SESSION) ->
None:
clear_db_dags()
clear_db_dag_warnings()
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_event_logs.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_event_logs.py
index 2812143b186..00d6918673d 100644
---
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_event_logs.py
+++
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_event_logs.py
@@ -20,10 +20,11 @@ from datetime import datetime, timezone
from unittest import mock
import pytest
+from sqlalchemy.orm import Session
from airflow.api_fastapi.auth.managers.models.resource_details import
DagAccessEntity, DagDetails
from airflow.models.log import Log
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from tests_common.test_utils.asserts import assert_queries_count
from tests_common.test_utils.db import clear_db_logs, clear_db_runs
@@ -62,7 +63,7 @@ class TestEventLogsEndpoint:
@pytest.fixture(autouse=True)
@provide_session
- def setup(self, create_task_instance, session=None) -> dict[str, Log]:
+ def setup(self, create_task_instance, *, session: Session = NEW_SESSION)
-> dict[str, Log]:
"""
Setup event logs for testing.
:return: Dictionary with event log keys and their corresponding IDs.
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py
index 886b7266021..ef5427acb69 100644
---
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py
+++
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py
@@ -54,7 +54,7 @@ BUNDLE_NAME = "testing"
@pytest.fixture
@provide_session
-def permitted_dag_model_all(testing_dag_bundle, session: Session =
NEW_SESSION) -> set[str]:
+def permitted_dag_model_all(testing_dag_bundle, *, session: Session =
NEW_SESSION) -> set[str]:
dag_model1 = DagModel(
fileloc=FILENAME1,
relative_fileloc=FILENAME1,
@@ -85,7 +85,7 @@ def permitted_dag_model_all(testing_dag_bundle, session:
Session = NEW_SESSION)
@pytest.fixture
@provide_session
-def not_permitted_dag_model(testing_dag_bundle, session: Session =
NEW_SESSION) -> DagModel:
+def not_permitted_dag_model(testing_dag_bundle, *, session: Session =
NEW_SESSION) -> DagModel:
dag_model = DagModel(
fileloc=FILENAME1,
bundle_name=BUNDLE_NAME,
@@ -113,7 +113,7 @@ def clear_db():
@pytest.fixture(autouse=True)
@provide_session
-def import_errors(session: Session = NEW_SESSION) -> list[ParseImportError]:
+def import_errors(*, session: Session = NEW_SESSION) -> list[ParseImportError]:
_import_errors = [
ParseImportError(
bundle_name=bundle,
@@ -546,6 +546,7 @@ class TestImportErrorFileAuthorization:
def absolute_vs_relative_fileloc_dag(
self,
testing_dag_bundle,
+ *,
session: Session = NEW_SESSION,
) -> DagModel:
"""DagModel whose ``fileloc`` is absolute and ``relative_fileloc`` is
@@ -572,6 +573,7 @@ class TestImportErrorFileAuthorization:
def mixed_file_dags(
self,
testing_dag_bundle,
+ *,
session: Session = NEW_SESSION,
) -> tuple[DagModel, DagModel]:
"""Two DagModels pointing at the same ``(relative_fileloc,
@@ -599,6 +601,7 @@ class TestImportErrorFileAuthorization:
@provide_session
def lonely_file_import_error(
self,
+ *,
session: Session = NEW_SESSION,
) -> ParseImportError:
error = ParseImportError(
@@ -615,6 +618,7 @@ class TestImportErrorFileAuthorization:
@provide_session
def mixed_file_import_error(
self,
+ *,
session: Session = NEW_SESSION,
) -> ParseImportError:
error = ParseImportError(
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_job.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_job.py
index c5874c815d8..a350585dc3a 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_job.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_job.py
@@ -19,10 +19,11 @@ from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
+from sqlalchemy.orm import Session
from airflow.jobs.job import Job, JobState
from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State
from tests_common.test_utils.asserts import assert_queries_count
@@ -53,10 +54,10 @@ TESTCASE_MULTIPLE_RUNNER =
"should_report_success_for_multiple_runners"
class TestJobEndpoint:
"""Common class for /jobs related unit tests."""
- scheduler_jobs: list[Job] | None = None
- job_runners: list[SchedulerJobRunner] | None = None
+ scheduler_jobs: list[Job]
+ job_runners: list[SchedulerJobRunner]
- def _setup_should_report_success_for_one_working_scheduler(self,
session=None):
+ def _setup_should_report_success_for_one_working_scheduler(self, session:
Session):
scheduler_job = Job()
job_runner = SchedulerJobRunner(job=scheduler_job)
scheduler_job.state = State.RUNNING
@@ -66,7 +67,7 @@ class TestJobEndpoint:
self.job_runners.append(job_runner)
scheduler_job.heartbeat(heartbeat_callback=job_runner.heartbeat_callback)
- def
_setup_should_report_success_for_one_working_scheduler_with_hostname(self,
session=None):
+ def
_setup_should_report_success_for_one_working_scheduler_with_hostname(self,
session: Session):
scheduler_job = Job()
job_runner = SchedulerJobRunner(job=scheduler_job)
scheduler_job.state = State.RUNNING
@@ -77,7 +78,7 @@ class TestJobEndpoint:
session.commit()
scheduler_job.heartbeat(heartbeat_callback=job_runner.heartbeat_callback)
- def _setup_should_report_success_for_ha_schedulers(self, session=None):
+ def _setup_should_report_success_for_ha_schedulers(self, session: Session):
for _ in range(3):
scheduler_job = Job()
job_runner = SchedulerJobRunner(job=scheduler_job)
@@ -88,7 +89,7 @@ class TestJobEndpoint:
session.commit()
scheduler_job.heartbeat(heartbeat_callback=job_runner.heartbeat_callback)
- def _setup_should_ignore_not_running_jobs(self, session=None):
+ def _setup_should_ignore_not_running_jobs(self, session: Session):
for _ in range(3):
scheduler_job = Job()
job_runner = SchedulerJobRunner(job=scheduler_job)
@@ -98,7 +99,7 @@ class TestJobEndpoint:
self.job_runners.append(job_runner)
session.commit()
- def _setup_should_raise_exception_for_multiple_scheduler_on_one_host(self,
session=None):
+ def _setup_should_raise_exception_for_multiple_scheduler_on_one_host(self,
session: Session):
for _ in range(3):
scheduler_job = Job()
job_runner = SchedulerJobRunner(job=scheduler_job)
@@ -112,7 +113,7 @@ class TestJobEndpoint:
scheduler_job.heartbeat(heartbeat_callback=job_runner.heartbeat_callback)
@provide_session
- def setup(self, testcase: TestCase, session=None) -> None:
+ def setup(self, testcase: TestCase, *, session: Session = NEW_SESSION) ->
None:
"""
Setup testcase at runtime based on the `testcase` provided by
`pytest.mark.parametrize`.
"""
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_monitor.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_monitor.py
index 5175278d777..d4b3d97f932 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_monitor.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_monitor.py
@@ -20,6 +20,7 @@ from datetime import timedelta
from unittest import mock
import pytest
+from sqlalchemy.orm import Session
from airflow._shared.timezones import timezone
from airflow.jobs.job import Job
@@ -46,7 +47,7 @@ class TestMonitorEndpoint:
class TestGetHealth(TestMonitorEndpoint):
@provide_session
- def test_healthy_scheduler_status(self, test_client, session):
+ def test_healthy_scheduler_status(self, test_client, *, session: Session):
last_scheduler_heartbeat_for_testing_1 = timezone.utcnow()
job = Job(state=State.RUNNING,
latest_heartbeat=last_scheduler_heartbeat_for_testing_1)
SchedulerJobRunner(job=job)
@@ -65,7 +66,7 @@ class TestGetHealth(TestMonitorEndpoint):
)
@provide_session
- def test_unhealthy_scheduler_is_slow(self, test_client, session):
+ def test_unhealthy_scheduler_is_slow(self, test_client, *, session:
Session):
last_scheduler_heartbeat_for_testing_2 = timezone.utcnow() -
timedelta(minutes=1)
job = Job(state=State.RUNNING,
latest_heartbeat=last_scheduler_heartbeat_for_testing_2)
SchedulerJobRunner(job=job)
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py
index 4cd953d9518..6e17598f07e 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py
@@ -20,10 +20,11 @@ from unittest import mock
import pytest
from sqlalchemy import func, select
+from sqlalchemy.orm import Session
from airflow.models.pool import Pool
from airflow.models.team import Team
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from tests_common.test_utils.asserts import count_queries
from tests_common.test_utils.config import conf_vars
@@ -50,7 +51,7 @@ POOL3_DESCRIPTION = "Some Description"
@provide_session
-def _create_pools(session) -> None:
+def _create_pools(*, session: Session = NEW_SESSION) -> None:
pool1 = Pool(pool=POOL1_NAME, slots=POOL1_SLOT,
include_deferred=POOL1_INCLUDE_DEFERRED, team_name="test")
pool2 = Pool(pool=POOL2_NAME, slots=POOL2_SLOT,
include_deferred=POOL2_INCLUDE_DEFERRED)
pool3 = Pool(
@@ -63,7 +64,7 @@ def _create_pools(session) -> None:
@provide_session
-def _create_team(session) -> None:
+def _create_team(*, session: Session = NEW_SESSION) -> None:
session.add(Team(name="test"))
session.commit()
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py
index b75bf091cba..cd3ecbb33ee 100644
---
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py
+++
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py
@@ -23,10 +23,11 @@ from unittest.mock import ANY
import pytest
from sqlalchemy import select
+from sqlalchemy.orm import Session
from airflow.models.team import Team
from airflow.models.variable import Variable
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from tests_common.test_utils.asserts import assert_queries_count
from tests_common.test_utils.config import conf_vars
@@ -73,7 +74,7 @@ def create_file_upload(content: dict) -> BytesIO:
@provide_session
-def _create_variables(session) -> None:
+def _create_variables(*, session: Session = NEW_SESSION) -> None:
team = session.scalars(select(Team).where(Team.name == "test")).one()
Variable.set(
@@ -128,7 +129,7 @@ def _create_variables(session) -> None:
@provide_session
-def _create_team(session) -> None:
+def _create_team(*, session: Session = NEW_SESSION) -> None:
session.add(Team(name="test"))
session.commit()
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py
index 07dae5ef6bb..81723e05a8e 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py
@@ -17,9 +17,11 @@
from __future__ import annotations
import json
+from typing import TYPE_CHECKING
from unittest import mock
import pytest
+from sqlalchemy.orm import Session
from airflow._shared.timezones import timezone
from airflow.api_fastapi.core_api.datamodels.xcom import XComCreateBody
@@ -31,7 +33,7 @@ from airflow.providers.standard.operators.empty import
EmptyOperator
from airflow.sdk import DAG, AssetAlias
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.execution_time.xcom import resolve_xcom_backend
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.types import DagRunType
from tests_common.test_utils.asserts import assert_queries_count
@@ -42,6 +44,9 @@ from tests_common.test_utils.logs import check_last_log
from tests_common.test_utils.mock_operators import MockOperator
from tests_common.test_utils.taskinstance import create_task_instance
+if TYPE_CHECKING:
+ from airflow.sdk.types import MappedOperator
+
pytestmark = pytest.mark.db_test
TEST_XCOM_KEY = "test_xcom_key"
@@ -70,7 +75,7 @@ run_id = DagRun.generate_run_id(
@provide_session
-def _create_xcom(key, value, backend, session=None) -> None:
+def _create_xcom(key, value, backend, *, session: Session = NEW_SESSION) ->
None:
XComModel.set(
key=key,
value=value,
@@ -82,7 +87,7 @@ def _create_xcom(key, value, backend, session=None) -> None:
@provide_session
-def _create_dag_run(dag_maker, session=None):
+def _create_dag_run(dag_maker, *, session: Session = NEW_SESSION):
with dag_maker(TEST_DAG_ID, schedule=None, start_date=logical_date_parsed):
EmptyOperator(task_id=TEST_TASK_ID)
dag_maker.create_dagrun(
@@ -443,13 +448,16 @@ class TestGetXComEntries(TestXComEndpoint):
}
@provide_session
- def _create_xcom_entries(self, dag_id, run_id, logical_date, task_id,
mapped_ti=False, session=None):
+ def _create_xcom_entries(
+ self, dag_id, run_id, logical_date, task_id, mapped_ti=False, *,
session: Session = NEW_SESSION
+ ):
bundle_name = "testing"
orm_dag_bundle = DagBundleModel(name=bundle_name)
session.merge(orm_dag_bundle)
session.flush()
with DAG(dag_id=dag_id) as dag:
+ task: EmptyOperator | MappedOperator
if mapped_ti:
task = MockOperator.partial(task_id=task_id).expand(arg1=[0,
1])
else:
@@ -464,6 +472,7 @@ class TestGetXComEntries(TestXComEndpoint):
)
session.add(dagrun)
dag_version = DagVersion.get_latest_version(dag.dag_id)
+ assert dag_version
if mapped_ti:
for i in [0, 1]:
ti = create_task_instance(task, run_id=run_id, map_index=i,
dag_version_id=dag_version.id)
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_calendar.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_calendar.py
index bcf3e4aa441..7bd0d69153c 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_calendar.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_calendar.py
@@ -21,10 +21,11 @@ from datetime import datetime
import pendulum
import pytest
+from sqlalchemy.orm import Session
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import CronPartitionTimetable
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState
from tests_common.test_utils.asserts import assert_queries_count
@@ -38,7 +39,7 @@ class TestCalendar:
@pytest.fixture(autouse=True)
@provide_session
- def setup_dag_runs(self, dag_maker, session=None) -> None:
+ def setup_dag_runs(self, dag_maker, *, session: Session = NEW_SESSION) ->
None:
clear_db_runs()
clear_db_dags()
with dag_maker(
@@ -192,7 +193,7 @@ class TestPartitionedCalendar:
@pytest.fixture(autouse=True)
@provide_session
- def setup_dag_runs(self, dag_maker, session=None) -> None:
+ def setup_dag_runs(self, dag_maker, *, session: Session = NEW_SESSION) ->
None:
clear_db_runs()
clear_db_dags()
with dag_maker(
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py
index 47fac78d629..bb94168bd57 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py
@@ -31,7 +31,7 @@ from airflow.models.dag import DagModel, DagTag
from airflow.models.dag_favorite import DagFavorite
from airflow.models.hitl import HITLDetail
from airflow.sdk.timezone import utcnow
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -54,7 +54,7 @@ pytestmark = pytest.mark.db_test
class TestGetDagRuns(TestPublicDagEndpoint):
@pytest.fixture(autouse=True)
@provide_session
- def setup_dag_runs(self, session=None) -> None:
+ def setup_dag_runs(self, *, session: Session = NEW_SESSION) -> None:
# Create DAG Runs
for dag_id in [DAG1_ID, DAG2_ID, DAG3_ID, DAG4_ID, DAG5_ID]:
dag_runs_count = 5 if dag_id in [DAG1_ID, DAG2_ID] else 2
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_gantt.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_gantt.py
index 0e2be9e277c..8e7e192bc49 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_gantt.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_gantt.py
@@ -21,11 +21,12 @@ from operator import attrgetter
import pendulum
import pytest
+from sqlalchemy.orm import Session
from airflow._shared.timezones import timezone
from airflow.models.dagbag import DBDagBag
from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -93,7 +94,7 @@ def examples_dag_bag():
@pytest.fixture(autouse=True)
@provide_session
-def setup(dag_maker, session=None):
+def setup(dag_maker, *, session: Session = NEW_SESSION):
clear_db_runs()
clear_db_dags()
clear_db_serialized_dags()
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py
index b0dab9012a5..0c8f4c17449 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py
@@ -24,6 +24,7 @@ from operator import attrgetter
import pendulum
import pytest
from sqlalchemy import select
+from sqlalchemy.orm import Session
from airflow._shared.timezones import timezone
from airflow.models.dag import DagModel
@@ -33,7 +34,7 @@ from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import task_group
from airflow.sdk.definitions.taskgroup import TaskGroup
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -151,7 +152,7 @@ def examples_dag_bag():
@pytest.fixture(autouse=True)
@provide_session
-def setup(dag_maker, session=None):
+def setup(dag_maker, *, session: Session = NEW_SESSION):
clear_db_runs()
clear_db_dags()
clear_db_serialized_dags()
diff --git a/scripts/ci/prek/known_provide_session_positional.txt b/scripts/ci/prek/known_provide_session_positional.txt
index 1a773d36c0d..6fa731efdb0 100644
--- a/scripts/ci/prek/known_provide_session_positional.txt
+++ b/scripts/ci/prek/known_provide_session_positional.txt
@@ -1,6 +1,3 @@
-airflow-core/src/airflow/api/common/delete_dag.py::1
-airflow-core/src/airflow/api/common/mark_tasks.py::1
-airflow-core/src/airflow/callbacks/database_callback_sink.py::1
airflow-core/src/airflow/cli/commands/dag_command.py::8
airflow-core/src/airflow/cli/commands/jobs_command.py::1
airflow-core/src/airflow/cli/commands/task_command.py::1
@@ -47,23 +44,6 @@ airflow-core/src/airflow/utils/cli_action_loggers.py::1
airflow-core/src/airflow/utils/db.py::7
airflow-core/src/airflow/utils/db_cleanup.py::2
airflow-core/src/airflow/utils/log/file_task_handler.py::1
-airflow-core/tests/unit/api_fastapi/common/test_exceptions.py::4
-airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py::19
-airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py::2
-airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py::1
-airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_tags.py::1
-airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_warning.py::1
-airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_event_logs.py::1
-airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py::7
-airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_job.py::1
-airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_monitor.py::2
-airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py::2
-airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py::2
-airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py::3
-airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_calendar.py::2
-airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py::1
-airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_gantt.py::1
-airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py::1
airflow-core/tests/unit/cli/commands/test_rotate_fernet_key_command.py::2
airflow-core/tests/unit/jobs/test_scheduler_job.py::1
airflow-core/tests/unit/listeners/test_listeners.py::7