This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1c7c97630e Replace `unittests` test cases by pure `pytest` [Wave-1]
(#26831)
1c7c97630e is described below
commit 1c7c97630ee226e2dcb3ca90da25dc6a8fbb5b21
Author: Andrey Anshin <[email protected]>
AuthorDate: Mon Oct 10 10:43:11 2022 +0400
Replace `unittests` test cases by pure `pytest` [Wave-1] (#26831)
---
tests/always/test_secrets.py | 9 +-
tests/always/test_secrets_backends.py | 21 ++--
tests/always/test_secrets_local_filesystem.py | 138 +++++++++++----------
tests/api/auth/test_client.py | 3 +-
tests/api/client/test_local_client.py | 12 +-
tests/api/common/test_trigger_dag.py | 15 ++-
tests/api_connexion/schemas/test_common_schema.py | 9 +-
.../schemas/test_connection_schema.py | 21 ++--
tests/api_connexion/schemas/test_dag_run_schema.py | 14 +--
tests/api_connexion/schemas/test_error_schema.py | 8 +-
tests/api_connexion/schemas/test_health_schema.py | 6 +-
tests/api_connexion/schemas/test_plugin_schema.py | 6 +-
tests/api_connexion/schemas/test_pool_schemas.py | 14 +--
.../schemas/test_task_instance_schema.py | 108 +++++++---------
tests/api_connexion/schemas/test_version_schema.py | 13 +-
tests/api_connexion/test_parameters.py | 13 +-
tests/cli/commands/test_celery_command.py | 23 ++--
tests/cli/commands/test_cheat_sheet_command.py | 5 +-
tests/cli/commands/test_config_command.py | 9 +-
tests/cli/commands/test_dag_command.py | 7 +-
tests/cli/commands/test_dag_processor_command.py | 5 +-
tests/cli/commands/test_info_command.py | 12 +-
tests/cli/commands/test_jobs_command.py | 15 ++-
tests/cli/commands/test_kerberos_command.py | 5 +-
tests/cli/commands/test_kubernetes_command.py | 9 +-
tests/cli/commands/test_legacy_commands.py | 5 +-
tests/cli/commands/test_plugins_command.py | 7 +-
.../cli/commands/test_rotate_fernet_key_command.py | 9 +-
tests/cli/commands/test_scheduler_command.py | 44 +++----
tests/cli/commands/test_sync_perm_command.py | 5 +-
tests/cli/commands/test_triggerer_command.py | 5 +-
tests/cli/commands/test_variable_command.py | 9 +-
tests/cli/commands/test_version_command.py | 5 +-
tests/cli/commands/test_webserver_command.py | 13 +-
tests/cli/test_cli_parser.py | 21 ++--
tests/core/test_config_templates.py | 17 +--
tests/core/test_logging_config.py | 13 +-
tests/core/test_settings.py | 24 ++--
tests/core/test_sqlalchemy_config.py | 7 +-
tests/core/test_stats.py | 22 ++--
tests/dag_processing/test_manager.py | 7 +-
tests/executors/test_dask_executor.py | 17 ++-
tests/executors/test_executor_loader.py | 24 ++--
tests/executors/test_kubernetes_executor.py | 16 +--
tests/executors/test_local_executor.py | 3 +-
tests/executors/test_sequential_executor.py | 3 +-
tests/hooks/test_subprocess.py | 24 ++--
tests/kubernetes/models/test_secret.py | 3 +-
tests/kubernetes/test_client.py | 9 +-
tests/macros/test_hive.py | 3 +-
tests/models/test_dagcode.py | 7 +-
tests/models/test_param.py | 4 +-
tests/operators/test_branch_operator.py | 13 +-
tests/operators/test_trigger_dagrun.py | 8 +-
tests/operators/test_weekday.py | 57 ++++-----
tests/plugins/test_plugin_ignore.py | 7 +-
tests/sensors/test_bash.py | 5 +-
tests/sensors/test_filesystem.py | 5 +-
tests/sensors/test_time_delta.py | 5 +-
tests/sensors/test_timeout_sensor.py | 5 +-
tests/sensors/test_weekday_sensor.py | 44 ++++---
tests/task/task_runner/test_cgroup_task_runner.py | 3 +-
tests/task/task_runner/test_task_runner.py | 7 +-
.../deps/test_dag_ti_slots_available_dep.py | 3 +-
tests/ti_deps/deps/test_dag_unpaused_dep.py | 3 +-
tests/ti_deps/deps/test_dagrun_exists_dep.py | 3 +-
tests/ti_deps/deps/test_dagrun_id_dep.py | 3 +-
tests/ti_deps/deps/test_not_in_retry_period_dep.py | 3 +-
.../ti_deps/deps/test_pool_slots_available_dep.py | 7 +-
tests/ti_deps/deps/test_ready_to_reschedule_dep.py | 3 +-
tests/ti_deps/deps/test_task_concurrency.py | 3 +-
tests/ti_deps/deps/test_task_not_running_dep.py | 3 +-
tests/ti_deps/deps/test_valid_state_dep.py | 3 +-
tests/utils/log/test_file_processor_handler.py | 8 +-
tests/utils/log/test_json_formatter.py | 3 +-
tests/utils/test_dag_cycle.py | 4 +-
tests/utils/test_dates.py | 5 +-
tests/utils/test_docs.py | 10 +-
tests/utils/test_email.py | 5 +-
tests/utils/test_event_scheduler.py | 3 +-
tests/utils/test_file.py | 5 +-
tests/utils/test_json.py | 7 +-
tests/utils/test_logging_mixin.py | 7 +-
tests/utils/test_module_loading.py | 4 +-
tests/utils/test_net.py | 3 +-
tests/utils/test_operator_helpers.py | 6 +-
tests/utils/test_operator_resources.py | 4 +-
...test_preexisting_python_virtualenv_decorator.py | 4 +-
tests/utils/test_python_virtualenv.py | 3 +-
tests/utils/test_sqlalchemy.py | 23 ++--
tests/utils/test_timezone.py | 3 +-
tests/utils/test_trigger_rule.py | 4 +-
tests/utils/test_weekday.py | 50 ++++----
tests/utils/test_weight_rule.py | 4 +-
tests/www/test_app.py | 5 +-
tests/www/test_init_views.py | 3 +-
tests/www/test_utils.py | 13 +-
tests/www/test_validators.py | 11 +-
98 files changed, 547 insertions(+), 669 deletions(-)
diff --git a/tests/always/test_secrets.py b/tests/always/test_secrets.py
index df076eead3..4c67a46fe8 100644
--- a/tests/always/test_secrets.py
+++ b/tests/always/test_secrets.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
from airflow.configuration import ensure_secrets_loaded,
initialize_secrets_backends
@@ -26,7 +25,7 @@ from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_variables
-class TestConnectionsFromSecrets(unittest.TestCase):
+class TestConnectionsFromSecrets:
@mock.patch("airflow.secrets.metastore.MetastoreBackend.get_connection")
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection")
def test_get_connection_second_try(self, mock_env_get, mock_meta_get):
@@ -112,11 +111,11 @@ class TestConnectionsFromSecrets(unittest.TestCase):
assert 'mysql://airflow:airflow@host:5432/airflow' == conn.get_uri()
-class TestVariableFromSecrets(unittest.TestCase):
- def setUp(self) -> None:
+class TestVariableFromSecrets:
+ def setup_method(self) -> None:
clear_db_variables()
- def tearDown(self) -> None:
+ def teardown_method(self) -> None:
clear_db_variables()
@mock.patch("airflow.secrets.metastore.MetastoreBackend.get_variable")
diff --git a/tests/always/test_secrets_backends.py
b/tests/always/test_secrets_backends.py
index 947053848c..212a578310 100644
--- a/tests/always/test_secrets_backends.py
+++ b/tests/always/test_secrets_backends.py
@@ -18,10 +18,9 @@
from __future__ import annotations
import os
-import unittest
from unittest import mock
-from parameterized import parameterized
+import pytest
from airflow.models.connection import Connection
from airflow.models.variable import Variable
@@ -41,21 +40,23 @@ class SampleConn:
self.conn = Connection(conn_id=self.conn_id, uri=self.conn_uri)
-class TestBaseSecretsBackend(unittest.TestCase):
- def setUp(self) -> None:
+class TestBaseSecretsBackend:
+ def setup_method(self) -> None:
clear_db_variables()
- def tearDown(self) -> None:
+ def teardown_method(self) -> None:
clear_db_connections()
clear_db_variables()
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "kwargs, output",
[
- ('default', {"path_prefix": "PREFIX", "secret_id": "ID"},
"PREFIX/ID"),
- ('with_sep', {"path_prefix": "PREFIX", "secret_id": "ID", "sep":
"-"}, "PREFIX-ID"),
- ]
+ ({"path_prefix": "PREFIX", "secret_id": "ID"}, "PREFIX/ID"),
+ ({"path_prefix": "PREFIX", "secret_id": "ID", "sep": "-"},
"PREFIX-ID"),
+ ],
+ ids=["default", "with_sep"],
)
- def test_build_path(self, _, kwargs, output):
+ def test_build_path(self, kwargs, output):
build_path = BaseSecretsBackend.build_path
assert build_path(**kwargs) == output
diff --git a/tests/always/test_secrets_local_filesystem.py
b/tests/always/test_secrets_local_filesystem.py
index 2d2a49a1a6..bb2c40abf7 100644
--- a/tests/always/test_secrets_local_filesystem.py
+++ b/tests/always/test_secrets_local_filesystem.py
@@ -18,13 +18,11 @@ from __future__ import annotations
import json
import re
-import unittest
from contextlib import contextmanager
from tempfile import NamedTemporaryFile
from unittest import mock
import pytest
-from parameterized import parameterized
from airflow.configuration import ensure_secrets_loaded
from airflow.exceptions import AirflowException, AirflowFileParseException,
ConnectionNotUnique
@@ -42,24 +40,26 @@ def mock_local_file(content):
yield file_mock
-class FileParsers(unittest.TestCase):
- @parameterized.expand(
- (
+class TestFileParsers:
+ @pytest.mark.parametrize(
+ "content, expected_message",
+ [
("AA", 'Invalid line format. The line should contain at least one
equal sign ("=")'),
("=", "Invalid line format. Key is empty."),
- )
+ ],
)
def test_env_file_invalid_format(self, content, expected_message):
with mock_local_file(content):
with pytest.raises(AirflowFileParseException,
match=re.escape(expected_message)):
local_filesystem.load_variables("a.env")
- @parameterized.expand(
- (
+ @pytest.mark.parametrize(
+ "content, expected_message",
+ [
("[]", "The file should contain the object."),
("{AAAAA}", "Expecting property name enclosed in double quotes"),
("", "The file is empty."),
- )
+ ],
)
def test_json_file_invalid_format(self, content, expected_message):
with mock_local_file(content):
@@ -67,34 +67,41 @@ class FileParsers(unittest.TestCase):
local_filesystem.load_variables("a.json")
-class TestLoadVariables(unittest.TestCase):
- @parameterized.expand(
- (
+class TestLoadVariables:
+ @pytest.mark.parametrize(
+ "file_content, expected_variables",
+ [
("", {}),
("KEY=AAA", {"KEY": "AAA"}),
("KEY_A=AAA\nKEY_B=BBB", {"KEY_A": "AAA", "KEY_B": "BBB"}),
("KEY_A=AAA\n # AAAA\nKEY_B=BBB", {"KEY_A": "AAA", "KEY_B":
"BBB"}),
("\n\n\n\nKEY_A=AAA\n\n\n\n\nKEY_B=BBB\n\n\n", {"KEY_A": "AAA",
"KEY_B": "BBB"}),
- )
+ ],
)
def test_env_file_should_load_variables(self, file_content,
expected_variables):
with mock_local_file(file_content):
variables = local_filesystem.load_variables("a.env")
assert expected_variables == variables
- @parameterized.expand((("AA=A\nAA=B", "The \"a.env\" file contains
multiple values for keys: ['AA']"),))
+ @pytest.mark.parametrize(
+ "content, expected_message",
+ [
+ ("AA=A\nAA=B", "The \"a.env\" file contains multiple values for
keys: ['AA']"),
+ ],
+ )
def test_env_file_invalid_logic(self, content, expected_message):
with mock_local_file(content):
with pytest.raises(AirflowException,
match=re.escape(expected_message)):
local_filesystem.load_variables("a.env")
- @parameterized.expand(
- (
+ @pytest.mark.parametrize(
+ "file_content, expected_variables",
+ [
({}, {}),
({"KEY": "AAA"}, {"KEY": "AAA"}),
({"KEY_A": "AAA", "KEY_B": "BBB"}, {"KEY_A": "AAA", "KEY_B":
"BBB"}),
({"KEY_A": "AAA", "KEY_B": "BBB"}, {"KEY_A": "AAA", "KEY_B":
"BBB"}),
- )
+ ],
)
def test_json_file_should_load_variables(self, file_content,
expected_variables):
with mock_local_file(json.dumps(file_content)):
@@ -109,8 +116,9 @@ class TestLoadVariables(unittest.TestCase):
):
local_filesystem.load_variables("a.json")
- @parameterized.expand(
- (
+ @pytest.mark.parametrize(
+ "file_content, expected_variables",
+ [
("KEY: AAA", {"KEY": "AAA"}),
(
"""
@@ -119,7 +127,7 @@ class TestLoadVariables(unittest.TestCase):
""",
{"KEY_A": "AAA", "KEY_B": "BBB"},
),
- )
+ ],
)
def test_yaml_file_should_load_variables(self, file_content,
expected_variables):
with mock_local_file(file_content):
@@ -128,9 +136,10 @@ class TestLoadVariables(unittest.TestCase):
assert expected_variables == vars_yaml == vars_yml
-class TestLoadConnection(unittest.TestCase):
- @parameterized.expand(
- (
+class TestLoadConnection:
+ @pytest.mark.parametrize(
+ "file_content, expected_connection_uris",
+ [
("CONN_ID=mysql://host_1/", {"CONN_ID": "mysql://host_1"}),
(
"CONN_ID1=mysql://host_1/\nCONN_ID2=mysql://host_2/",
@@ -144,7 +153,7 @@ class TestLoadConnection(unittest.TestCase):
"\n\n\n\nCONN_ID1=mysql://host_1/\n\n\n\n\nCONN_ID2=mysql://host_2/\n\n\n",
{"CONN_ID1": "mysql://host_1", "CONN_ID2": "mysql://host_2"},
),
- )
+ ],
)
def test_env_file_should_load_connection(self, file_content,
expected_connection_uris):
with mock_local_file(file_content):
@@ -155,13 +164,14 @@ class TestLoadConnection(unittest.TestCase):
assert expected_connection_uris == connection_uris_by_conn_id
- @parameterized.expand(
- (
+ @pytest.mark.parametrize(
+ "content, expected_connection_uris",
+ [
(
"CONN_ID=mysql://host_1/?param1=val1&param2=val2",
{"CONN_ID": "mysql://host_1/?param1=val1&param2=val2"},
),
- )
+ ],
)
def test_parsing_with_params(self, content, expected_connection_uris):
with mock_local_file(content):
@@ -172,24 +182,26 @@ class TestLoadConnection(unittest.TestCase):
assert expected_connection_uris == connection_uris_by_conn_id
- @parameterized.expand(
- (
+ @pytest.mark.parametrize(
+ "content, expected_message",
+ [
("AA", 'Invalid line format. The line should contain at least one
equal sign ("=")'),
("=", "Invalid line format. Key is empty."),
- )
+ ],
)
def test_env_file_invalid_format(self, content, expected_message):
with mock_local_file(content):
with pytest.raises(AirflowFileParseException,
match=re.escape(expected_message)):
local_filesystem.load_connections_dict("a.env")
- @parameterized.expand(
- (
+ @pytest.mark.parametrize(
+ "file_content, expected_connection_uris",
+ [
({"CONN_ID": "mysql://host_1"}, {"CONN_ID": "mysql://host_1"}),
({"CONN_ID": ["mysql://host_1"]}, {"CONN_ID": "mysql://host_1"}),
({"CONN_ID": {"uri": "mysql://host_1"}}, {"CONN_ID":
"mysql://host_1"}),
({"CONN_ID": [{"uri": "mysql://host_1"}]}, {"CONN_ID":
"mysql://host_1"}),
- )
+ ],
)
def test_json_file_should_load_connection(self, file_content,
expected_connection_uris):
with mock_local_file(json.dumps(file_content)):
@@ -200,8 +212,9 @@ class TestLoadConnection(unittest.TestCase):
assert expected_connection_uris == connection_uris_by_conn_id
- @parameterized.expand(
- (
+ @pytest.mark.parametrize(
+ "file_content, expected_connection_uris",
+ [
({"CONN_ID": None}, "Unexpected value type: <class 'NoneType'>."),
({"CONN_ID": 1}, "Unexpected value type: <class 'int'>."),
({"CONN_ID": [2]}, "Unexpected value type: <class 'int'>."),
@@ -209,7 +222,7 @@ class TestLoadConnection(unittest.TestCase):
({"CONN_ID": {"AAA": "mysql://host_1"}}, "The object have illegal
keys: AAA."),
({"CONN_ID": {"conn_id": "BBBB"}}, "Mismatch conn_id."),
({"CONN_ID": ["mysql://", "mysql://"]}, "Found multiple values for
CONN_ID in a.json."),
- )
+ ],
)
def test_env_file_invalid_input(self, file_content,
expected_connection_uris):
with mock_local_file(json.dumps(file_content)):
@@ -224,8 +237,9 @@ class TestLoadConnection(unittest.TestCase):
):
local_filesystem.load_connections_dict("a.json")
- @parameterized.expand(
- (
+ @pytest.mark.parametrize(
+ "file_content, expected_attrs_dict",
+ [
(
"""CONN_A: 'mysql://host_a'""",
{"CONN_A": {'conn_type': 'mysql', 'host': 'host_a'}},
@@ -262,7 +276,7 @@ class TestLoadConnection(unittest.TestCase):
},
},
),
- )
+ ],
)
def test_yaml_file_should_load_connection(self, file_content,
expected_attrs_dict):
with mock_local_file(file_content):
@@ -272,8 +286,9 @@ class TestLoadConnection(unittest.TestCase):
actual_attrs = {k: getattr(connection, k) for k in
expected_attrs.keys()}
assert actual_attrs == expected_attrs
- @parameterized.expand(
- (
+ @pytest.mark.parametrize(
+ "file_content, expected_extras",
+ [
(
"""
conn_c:
@@ -323,7 +338,7 @@ class TestLoadConnection(unittest.TestCase):
""",
{"conn_d": {"extra__google_cloud_platform__keyfile_dict":
{"a": "b"}}},
),
- )
+ ],
)
def test_yaml_file_should_load_connection_extras(self, file_content,
expected_extras):
with mock_local_file(file_content):
@@ -333,8 +348,9 @@ class TestLoadConnection(unittest.TestCase):
}
assert expected_extras == connection_uris_by_conn_id
- @parameterized.expand(
- (
+ @pytest.mark.parametrize(
+ "file_content, expected_message",
+ [
(
"""conn_c:
conn_type: scheme
@@ -351,50 +367,46 @@ class TestLoadConnection(unittest.TestCase):
""",
"The extra and extra_dejson parameters are mutually
exclusive.",
),
- )
+ ],
)
def test_yaml_invalid_extra(self, file_content, expected_message):
with mock_local_file(file_content):
with pytest.raises(AirflowException,
match=re.escape(expected_message)):
local_filesystem.load_connections_dict("a.yaml")
- @parameterized.expand(
- ("CONN_ID=mysql://host_1/\nCONN_ID=mysql://host_2/",),
- )
+ @pytest.mark.parametrize("file_content",
["CONN_ID=mysql://host_1/\nCONN_ID=mysql://host_2/"])
def test_ensure_unique_connection_env(self, file_content):
with mock_local_file(file_content):
with pytest.raises(ConnectionNotUnique):
local_filesystem.load_connections_dict("a.env")
- @parameterized.expand(
- (
- ({"CONN_ID": ["mysql://host_1", "mysql://host_2"]},),
- ({"CONN_ID": [{"uri": "mysql://host_1"}, {"uri":
"mysql://host_2"}]},),
- )
+ @pytest.mark.parametrize(
+ "file_content",
+ [
+ {"CONN_ID": ["mysql://host_1", "mysql://host_2"]},
+ {"CONN_ID": [{"uri": "mysql://host_1"}, {"uri":
"mysql://host_2"}]},
+ ],
)
def test_ensure_unique_connection_json(self, file_content):
with mock_local_file(json.dumps(file_content)):
with pytest.raises(ConnectionNotUnique):
local_filesystem.load_connections_dict("a.json")
- @parameterized.expand(
- (
- (
- """
+ @pytest.mark.parametrize(
+ "file_content",
+ [
+ """
conn_a:
- mysql://hosta
- mysql://hostb"""
- ),
- ),
+ ],
)
def test_ensure_unique_connection_yaml(self, file_content):
with mock_local_file(file_content):
with pytest.raises(ConnectionNotUnique):
local_filesystem.load_connections_dict("a.yaml")
- @parameterized.expand(
- (("conn_a: mysql://hosta"),),
- )
+ @pytest.mark.parametrize("file_content", ["conn_a: mysql://hosta"])
def test_yaml_extension_parsers_return_same_result(self, file_content):
with mock_local_file(file_content):
conn_uri_by_conn_id_yaml = {
@@ -408,7 +420,7 @@ class TestLoadConnection(unittest.TestCase):
assert conn_uri_by_conn_id_yaml == conn_uri_by_conn_id_yml
-class TestLocalFileBackend(unittest.TestCase):
+class TestLocalFileBackend:
def test_should_read_variable(self):
with NamedTemporaryFile(suffix="var.env") as tmp_file:
tmp_file.write(b"KEY_A=VAL_A")
diff --git a/tests/api/auth/test_client.py b/tests/api/auth/test_client.py
index d0454d623f..0f6758ad2b 100644
--- a/tests/api/auth/test_client.py
+++ b/tests/api/auth/test_client.py
@@ -16,14 +16,13 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
from airflow.api.client import get_current_api_client
from tests.test_utils.config import conf_vars
-class TestGetCurrentApiClient(unittest.TestCase):
+class TestGetCurrentApiClient:
@mock.patch("airflow.api.client.json_client.Client")
@mock.patch("airflow.api.auth.backend.default.CLIENT_AUTH", "CLIENT_AUTH")
@conf_vars(
diff --git a/tests/api/client/test_local_client.py
b/tests/api/client/test_local_client.py
index 2210bc2f3c..3fecedabb6 100644
--- a/tests/api/client/test_local_client.py
+++ b/tests/api/client/test_local_client.py
@@ -20,7 +20,6 @@ from __future__ import annotations
import json
import random
import string
-import unittest
from unittest.mock import patch
import pendulum
@@ -42,20 +41,17 @@ EXECDATE_NOFRACTIONS = EXECDATE.replace(microsecond=0)
EXECDATE_ISO = EXECDATE_NOFRACTIONS.isoformat()
-class TestLocalClient(unittest.TestCase):
+class TestLocalClient:
@classmethod
- def setUpClass(cls):
- super().setUpClass()
+ def setup_class(cls):
DagBag(example_bash_operator.__file__).get_dag("example_bash_operator").sync_to_db()
- def setUp(self):
- super().setUp()
+ def setup_method(self):
clear_db_pools()
self.client = Client(api_base_url=None, auth=None)
- def tearDown(self):
+ def teardown_method(self):
clear_db_pools()
- super().tearDown()
@patch.object(DAG, 'create_dagrun')
def test_trigger_dag(self, mock):
diff --git a/tests/api/common/test_trigger_dag.py
b/tests/api/common/test_trigger_dag.py
index 5d410d47a9..43b6eb2f15 100644
--- a/tests/api/common/test_trigger_dag.py
+++ b/tests/api/common/test_trigger_dag.py
@@ -17,11 +17,9 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
import pytest
-from parameterized import parameterized
from airflow.api.common.trigger_dag import _trigger_dag
from airflow.exceptions import AirflowException
@@ -30,11 +28,11 @@ from airflow.utils import timezone
from tests.test_utils import db
-class TestTriggerDag(unittest.TestCase):
- def setUp(self) -> None:
+class TestTriggerDag:
+ def setup_method(self) -> None:
db.clear_db_runs()
- def tearDown(self) -> None:
+ def teardown_method(self) -> None:
db.clear_db_runs()
@mock.patch('airflow.models.DagBag')
@@ -108,15 +106,16 @@ class TestTriggerDag(unittest.TestCase):
assert len(triggers) == 1
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "conf, expected_conf",
[
(None, {}),
({"foo": "bar"}, {"foo": "bar"}),
('{"foo": "bar"}', {"foo": "bar"}),
- ]
+ ],
)
@mock.patch('airflow.models.DagBag')
- def test_trigger_dag_with_conf(self, conf, expected_conf, dag_bag_mock):
+ def test_trigger_dag_with_conf(self, dag_bag_mock, conf, expected_conf):
dag_id = "trigger_dag_with_conf"
dag = DAG(dag_id)
dag_bag_mock.dags = [dag_id]
diff --git a/tests/api_connexion/schemas/test_common_schema.py
b/tests/api_connexion/schemas/test_common_schema.py
index 00a7f05d1e..759859c9a8 100644
--- a/tests/api_connexion/schemas/test_common_schema.py
+++ b/tests/api_connexion/schemas/test_common_schema.py
@@ -17,7 +17,6 @@
from __future__ import annotations
import datetime
-import unittest
import pytest
from dateutil import relativedelta
@@ -31,7 +30,7 @@ from airflow.api_connexion.schemas.common_schema import (
)
-class TestTimeDeltaSchema(unittest.TestCase):
+class TestTimeDeltaSchema:
def test_should_serialize(self):
instance = datetime.timedelta(days=12)
schema_instance = TimeDeltaSchema()
@@ -46,7 +45,7 @@ class TestTimeDeltaSchema(unittest.TestCase):
assert expected_instance == result
-class TestRelativeDeltaSchema(unittest.TestCase):
+class TestRelativeDeltaSchema:
def test_should_serialize(self):
instance = relativedelta.relativedelta(days=+12)
schema_instance = RelativeDeltaSchema()
@@ -78,7 +77,7 @@ class TestRelativeDeltaSchema(unittest.TestCase):
assert expected_instance == result
-class TestCronExpressionSchema(unittest.TestCase):
+class TestCronExpressionSchema:
def test_should_deserialize(self):
instance = {"__type": "CronExpression", "value": "5 4 * * *"}
schema_instance = CronExpressionSchema()
@@ -87,7 +86,7 @@ class TestCronExpressionSchema(unittest.TestCase):
assert expected_instance == result
-class TestScheduleIntervalSchema(unittest.TestCase):
+class TestScheduleIntervalSchema:
def test_should_serialize_timedelta(self):
instance = datetime.timedelta(days=12)
schema_instance = ScheduleIntervalSchema()
diff --git a/tests/api_connexion/schemas/test_connection_schema.py
b/tests/api_connexion/schemas/test_connection_schema.py
index 531c86ee05..1383a68e09 100644
--- a/tests/api_connexion/schemas/test_connection_schema.py
+++ b/tests/api_connexion/schemas/test_connection_schema.py
@@ -17,7 +17,6 @@
from __future__ import annotations
import re
-import unittest
import marshmallow
import pytest
@@ -34,12 +33,12 @@ from airflow.utils.session import create_session,
provide_session
from tests.test_utils.db import clear_db_connections
-class TestConnectionCollectionItemSchema(unittest.TestCase):
- def setUp(self) -> None:
+class TestConnectionCollectionItemSchema:
+ def setup_method(self) -> None:
with create_session() as session:
session.query(Connection).delete()
- def tearDown(self) -> None:
+ def teardown_method(self) -> None:
clear_db_connections()
@provide_session
@@ -106,12 +105,12 @@ class
TestConnectionCollectionItemSchema(unittest.TestCase):
connection_collection_item_schema.load(connection_dump_1)
-class TestConnectionCollectionSchema(unittest.TestCase):
- def setUp(self) -> None:
+class TestConnectionCollectionSchema:
+ def setup_method(self) -> None:
with create_session() as session:
session.query(Connection).delete()
- def tearDown(self) -> None:
+ def teardown_method(self) -> None:
clear_db_connections()
@provide_session
@@ -148,12 +147,12 @@ class TestConnectionCollectionSchema(unittest.TestCase):
}
-class TestConnectionSchema(unittest.TestCase):
- def setUp(self) -> None:
+class TestConnectionSchema:
+ def setup_method(self) -> None:
with create_session() as session:
session.query(Connection).delete()
- def tearDown(self) -> None:
+ def teardown_method(self) -> None:
clear_db_connections()
@provide_session
@@ -205,7 +204,7 @@ class TestConnectionSchema(unittest.TestCase):
}
-class TestConnectionTestSchema(unittest.TestCase):
+class TestConnectionTestSchema:
def test_response(self):
data = {
'status': True,
diff --git a/tests/api_connexion/schemas/test_dag_run_schema.py
b/tests/api_connexion/schemas/test_dag_run_schema.py
index ae10ba3b6f..395ce1874e 100644
--- a/tests/api_connexion/schemas/test_dag_run_schema.py
+++ b/tests/api_connexion/schemas/test_dag_run_schema.py
@@ -16,11 +16,8 @@
# under the License.
from __future__ import annotations
-import unittest
-
import pytest
from dateutil.parser import parse
-from parameterized import parameterized
from airflow.api_connexion.exceptions import BadRequest
from airflow.api_connexion.schemas.dag_run_schema import (
@@ -39,13 +36,13 @@ DEFAULT_TIME = "2020-06-09T13:59:56.336000+00:00"
SECOND_TIME = "2020-06-10T13:59:56.336000+00:00"
-class TestDAGRunBase(unittest.TestCase):
- def setUp(self) -> None:
+class TestDAGRunBase:
+ def setup_method(self) -> None:
clear_db_runs()
self.default_time = DEFAULT_TIME
self.second_time = SECOND_TIME
- def tearDown(self) -> None:
+ def teardown_method(self) -> None:
clear_db_runs()
@@ -82,7 +79,8 @@ class TestDAGRunSchema(TestDAGRunBase):
"run_type": "manual",
}
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "serialized_dagrun, expected_result",
[
( # Conf not provided
{"dag_run_id": "my-dag-run", "execution_date": DEFAULT_TIME},
@@ -112,7 +110,7 @@ class TestDAGRunSchema(TestDAGRunBase):
"conf": {"start": "stop"},
},
),
- ]
+ ],
)
def test_deserialize(self, serialized_dagrun, expected_result):
result = dagrun_schema.load(serialized_dagrun)
diff --git a/tests/api_connexion/schemas/test_error_schema.py
b/tests/api_connexion/schemas/test_error_schema.py
index 38d5762bf7..ca150ac6f2 100644
--- a/tests/api_connexion/schemas/test_error_schema.py
+++ b/tests/api_connexion/schemas/test_error_schema.py
@@ -16,8 +16,6 @@
# under the License.
from __future__ import annotations
-import unittest
-
from airflow.api_connexion.schemas.error_schema import (
ImportErrorCollection,
import_error_collection_schema,
@@ -29,12 +27,12 @@ from airflow.utils.session import provide_session
from tests.test_utils.db import clear_db_import_errors
-class TestErrorSchemaBase(unittest.TestCase):
- def setUp(self) -> None:
+class TestErrorSchemaBase:
+ def setup_method(self) -> None:
clear_db_import_errors()
self.timestamp = "2020-06-10T12:02:44"
- def tearDown(self) -> None:
+ def teardown_method(self) -> None:
clear_db_import_errors()
diff --git a/tests/api_connexion/schemas/test_health_schema.py
b/tests/api_connexion/schemas/test_health_schema.py
index f1920e95c5..fe0c83e261 100644
--- a/tests/api_connexion/schemas/test_health_schema.py
+++ b/tests/api_connexion/schemas/test_health_schema.py
@@ -16,13 +16,11 @@
# under the License.
from __future__ import annotations
-import unittest
-
from airflow.api_connexion.schemas.health_schema import health_schema
-class TestHealthSchema(unittest.TestCase):
- def setUp(self):
+class TestHealthSchema:
+ def setup_method(self):
self.default_datetime = "2020-06-10T12:02:44+00:00"
def test_serialize(self):
diff --git a/tests/api_connexion/schemas/test_plugin_schema.py
b/tests/api_connexion/schemas/test_plugin_schema.py
index a0cdd00254..e4f8fe1388 100644
--- a/tests/api_connexion/schemas/test_plugin_schema.py
+++ b/tests/api_connexion/schemas/test_plugin_schema.py
@@ -16,8 +16,6 @@
# under the License.
from __future__ import annotations
-import unittest
-
from airflow.api_connexion.schemas.plugin_schema import (
PluginCollection,
plugin_collection_schema,
@@ -26,8 +24,8 @@ from airflow.api_connexion.schemas.plugin_schema import (
from airflow.plugins_manager import AirflowPlugin
-class TestPluginBase(unittest.TestCase):
- def setUp(self) -> None:
+class TestPluginBase:
+ def setup_method(self) -> None:
self.mock_plugin = AirflowPlugin()
self.mock_plugin.name = "test_plugin"
diff --git a/tests/api_connexion/schemas/test_pool_schemas.py
b/tests/api_connexion/schemas/test_pool_schemas.py
index 48c0f8ee15..f0eb0c0a49 100644
--- a/tests/api_connexion/schemas/test_pool_schemas.py
+++ b/tests/api_connexion/schemas/test_pool_schemas.py
@@ -16,19 +16,17 @@
# under the License.
from __future__ import annotations
-import unittest
-
from airflow.api_connexion.schemas.pool_schema import PoolCollection,
pool_collection_schema, pool_schema
from airflow.models.pool import Pool
from airflow.utils.session import provide_session
from tests.test_utils.db import clear_db_pools
-class TestPoolSchema(unittest.TestCase):
- def setUp(self) -> None:
+class TestPoolSchema:
+ def setup_method(self) -> None:
clear_db_pools()
- def tearDown(self) -> None:
+ def teardown_method(self) -> None:
clear_db_pools()
@provide_session
@@ -56,11 +54,11 @@ class TestPoolSchema(unittest.TestCase):
assert not isinstance(deserialized_pool, Pool) # Checks if
load_instance is set to True
-class TestPoolCollectionSchema(unittest.TestCase):
- def setUp(self) -> None:
+class TestPoolCollectionSchema:
+ def setup_method(self) -> None:
clear_db_pools()
- def tearDown(self) -> None:
+ def teardown_method(self) -> None:
clear_db_pools()
def test_serialize(self):
diff --git a/tests/api_connexion/schemas/test_task_instance_schema.py
b/tests/api_connexion/schemas/test_task_instance_schema.py
index 4e01ca57c8..321664b4bc 100644
--- a/tests/api_connexion/schemas/test_task_instance_schema.py
+++ b/tests/api_connexion/schemas/test_task_instance_schema.py
@@ -17,11 +17,9 @@
from __future__ import annotations
import datetime as dt
-import unittest
import pytest
from marshmallow import ValidationError
-from parameterized import parameterized
from airflow.api_connexion.schemas.task_instance_schema import (
clear_task_instance_form,
@@ -150,70 +148,59 @@ class TestTaskInstanceSchema:
assert serialized_ti == expected_json
-class TestClearTaskInstanceFormSchema(unittest.TestCase):
- @parameterized.expand(
+class TestClearTaskInstanceFormSchema:
+ @pytest.mark.parametrize(
+ "payload",
[
(
- [
- {
- "dry_run": False,
- "reset_dag_runs": True,
- "only_failed": True,
- "only_running": True,
- }
- ]
+ {
+ "dry_run": False,
+ "reset_dag_runs": True,
+ "only_failed": True,
+ "only_running": True,
+ }
),
(
- [
- {
- "dry_run": False,
- "reset_dag_runs": True,
- "end_date": "2020-01-01T00:00:00+00:00",
- "start_date": "2020-01-02T00:00:00+00:00",
- }
- ]
+ {
+ "dry_run": False,
+ "reset_dag_runs": True,
+ "end_date": "2020-01-01T00:00:00+00:00",
+ "start_date": "2020-01-02T00:00:00+00:00",
+ }
),
(
- [
- {
- "dry_run": False,
- "reset_dag_runs": True,
- "task_ids": [],
- }
- ]
+ {
+ "dry_run": False,
+ "reset_dag_runs": True,
+ "task_ids": [],
+ }
),
(
- [
- {
- "dry_run": False,
- "reset_dag_runs": True,
- "dag_run_id": "scheduled__2022-06-19T00:00:00+00:00",
- "start_date": "2022-08-03T00:00:00+00:00",
- }
- ]
+ {
+ "dry_run": False,
+ "reset_dag_runs": True,
+ "dag_run_id": "scheduled__2022-06-19T00:00:00+00:00",
+ "start_date": "2022-08-03T00:00:00+00:00",
+ }
),
(
- [
- {
- "dry_run": False,
- "reset_dag_runs": True,
- "dag_run_id": "scheduled__2022-06-19T00:00:00+00:00",
- "end_date": "2022-08-03T00:00:00+00:00",
- }
- ]
+ {
+ "dry_run": False,
+ "reset_dag_runs": True,
+ "dag_run_id": "scheduled__2022-06-19T00:00:00+00:00",
+ "end_date": "2022-08-03T00:00:00+00:00",
+ }
),
(
- [
- {
- "dry_run": False,
- "reset_dag_runs": True,
- "dag_run_id": "scheduled__2022-06-19T00:00:00+00:00",
- "end_date": "2022-08-04T00:00:00+00:00",
- "start_date": "2022-08-03T00:00:00+00:00",
- }
- ]
+ {
+ "dry_run": False,
+ "reset_dag_runs": True,
+ "dag_run_id": "scheduled__2022-06-19T00:00:00+00:00",
+ "end_date": "2022-08-04T00:00:00+00:00",
+ "start_date": "2022-08-03T00:00:00+00:00",
+ }
),
- ]
+ ],
)
def test_validation_error(self, payload):
with pytest.raises(ValidationError):
@@ -246,14 +233,15 @@ class TestSetTaskInstanceStateFormSchema:
}
assert expected_result == result
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "override_data",
[
- ({"task_id": None},),
- ({"include_future": "foo"},),
- ({"execution_date": "NOW"},),
- ({"new_state": "INVALID_STATE"},),
- ({"execution_date": "2020-01-01T00:00:00+00:00", "dag_run_id":
"some-run-id"},),
- ]
+ {"task_id": None},
+ {"include_future": "foo"},
+ {"execution_date": "NOW"},
+ {"new_state": "INVALID_STATE"},
+ {"execution_date": "2020-01-01T00:00:00+00:00", "dag_run_id":
"some-run-id"},
+ ],
)
def test_validation_error(self, override_data):
self.current_input.update(override_data)
diff --git a/tests/api_connexion/schemas/test_version_schema.py
b/tests/api_connexion/schemas/test_version_schema.py
index 16ed1660f7..57db9d3dfa 100644
--- a/tests/api_connexion/schemas/test_version_schema.py
+++ b/tests/api_connexion/schemas/test_version_schema.py
@@ -16,21 +16,14 @@
# under the License.
from __future__ import annotations
-import unittest
-
-from parameterized import parameterized
+import pytest
from airflow.api_connexion.endpoints.version_endpoint import VersionInfo
from airflow.api_connexion.schemas.version_schema import version_info_schema
-class TestVersionInfoSchema(unittest.TestCase):
- @parameterized.expand(
- [
- ("GIT_COMMIT",),
- (None,),
- ]
- )
+class TestVersionInfoSchema:
+ @pytest.mark.parametrize("git_commit", ["GIT_COMMIT", None])
def test_serialize(self, git_commit):
version_info = VersionInfo("VERSION", git_commit)
current_data = version_info_schema.dump(version_info)
diff --git a/tests/api_connexion/test_parameters.py
b/tests/api_connexion/test_parameters.py
index a73a112d1d..a0c3b3a185 100644
--- a/tests/api_connexion/test_parameters.py
+++ b/tests/api_connexion/test_parameters.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
import pytest
@@ -34,8 +33,8 @@ from airflow.utils import timezone
from tests.test_utils.config import conf_vars
-class TestValidateIsTimezone(unittest.TestCase):
- def setUp(self) -> None:
+class TestValidateIsTimezone:
+ def setup_method(self) -> None:
from datetime import datetime
self.naive = datetime.now()
@@ -49,8 +48,8 @@ class TestValidateIsTimezone(unittest.TestCase):
assert validate_istimezone(self.timezoned) is None
-class TestDateTimeParser(unittest.TestCase):
- def setUp(self) -> None:
+class TestDateTimeParser:
+ def setup_method(self) -> None:
self.default_time = '2020-06-13T22:44:00+00:00'
self.default_time_2 = '2020-06-13T22:44:00Z'
@@ -72,7 +71,7 @@ class TestDateTimeParser(unittest.TestCase):
format_datetime(invalid_datetime)
-class TestMaximumPagelimit(unittest.TestCase):
+class TestMaximumPagelimit:
@conf_vars({("api", "maximum_page_limit"): "320"})
def test_maximum_limit_return_val(self):
limit = check_limit(300)
@@ -99,7 +98,7 @@ class TestMaximumPagelimit(unittest.TestCase):
check_limit(-1)
-class TestFormatParameters(unittest.TestCase):
+class TestFormatParameters:
def test_should_works_with_datetime_formatter(self):
decorator = format_parameters({"param_a": format_datetime})
endpoint = mock.MagicMock()
diff --git a/tests/cli/commands/test_celery_command.py
b/tests/cli/commands/test_celery_command.py
index 79f373588f..fee0162ba5 100644
--- a/tests/cli/commands/test_celery_command.py
+++ b/tests/cli/commands/test_celery_command.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from argparse import Namespace
from tempfile import NamedTemporaryFile
from unittest import mock
@@ -32,7 +31,7 @@ from airflow.configuration import conf
from tests.test_utils.config import conf_vars
-class TestWorkerPrecheck(unittest.TestCase):
+class TestWorkerPrecheck:
@mock.patch('airflow.settings.validate_session')
def test_error(self, mock_validate_session):
"""
@@ -65,9 +64,9 @@ class TestWorkerPrecheck(unittest.TestCase):
@pytest.mark.integration("redis")
@pytest.mark.integration("rabbitmq")
@pytest.mark.backend("mysql", "postgres")
-class TestWorkerServeLogs(unittest.TestCase):
+class TestWorkerServeLogs:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
@mock.patch('airflow.cli.commands.celery_command.celery_app')
@@ -94,9 +93,9 @@ class TestWorkerServeLogs(unittest.TestCase):
@pytest.mark.backend("mysql", "postgres")
-class TestCeleryStopCommand(unittest.TestCase):
+class TestCeleryStopCommand:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
@mock.patch("airflow.cli.commands.celery_command.setup_locations")
@@ -176,9 +175,9 @@ class TestCeleryStopCommand(unittest.TestCase):
@pytest.mark.backend("mysql", "postgres")
-class TestWorkerStart(unittest.TestCase):
+class TestWorkerStart:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
@mock.patch("airflow.cli.commands.celery_command.setup_locations")
@@ -237,9 +236,9 @@ class TestWorkerStart(unittest.TestCase):
@pytest.mark.backend("mysql", "postgres")
-class TestWorkerFailure(unittest.TestCase):
+class TestWorkerFailure:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
@mock.patch('airflow.cli.commands.celery_command.Process')
@@ -255,9 +254,9 @@ class TestWorkerFailure(unittest.TestCase):
@pytest.mark.backend("mysql", "postgres")
-class TestFlowerCommand(unittest.TestCase):
+class TestFlowerCommand:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
@mock.patch('airflow.cli.commands.celery_command.celery_app')
diff --git a/tests/cli/commands/test_cheat_sheet_command.py
b/tests/cli/commands/test_cheat_sheet_command.py
index 56e5f7707d..d3afbf03e7 100644
--- a/tests/cli/commands/test_cheat_sheet_command.py
+++ b/tests/cli/commands/test_cheat_sheet_command.py
@@ -18,7 +18,6 @@ from __future__ import annotations
import contextlib
import io
-import unittest
from unittest import mock
from airflow.cli import cli_parser
@@ -89,9 +88,9 @@ airflow cmd_e cmd_g | Help text G
"""
-class TestCheatSheetCommand(unittest.TestCase):
+class TestCheatSheetCommand:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
@mock.patch('airflow.cli.cli_parser.airflow_commands', MOCK_COMMANDS)
diff --git a/tests/cli/commands/test_config_command.py
b/tests/cli/commands/test_config_command.py
index f93889be30..ae7895d0d6 100644
--- a/tests/cli/commands/test_config_command.py
+++ b/tests/cli/commands/test_config_command.py
@@ -18,7 +18,6 @@ from __future__ import annotations
import contextlib
import io
-import unittest
from unittest import mock
import pytest
@@ -28,9 +27,9 @@ from airflow.cli.commands import config_command
from tests.test_utils.config import conf_vars
-class TestCliConfigList(unittest.TestCase):
+class TestCliConfigList:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
@mock.patch("airflow.cli.commands.config_command.io.StringIO")
@@ -47,9 +46,9 @@ class TestCliConfigList(unittest.TestCase):
assert 'testkey = test_value' in temp_stdout.getvalue()
-class TestCliConfigGetValue(unittest.TestCase):
+class TestCliConfigGetValue:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
@conf_vars({('core', 'test_key'): 'test_value'})
diff --git a/tests/cli/commands/test_dag_command.py
b/tests/cli/commands/test_dag_command.py
index e380d7f16a..c748149cc9 100644
--- a/tests/cli/commands/test_dag_command.py
+++ b/tests/cli/commands/test_dag_command.py
@@ -21,7 +21,6 @@ import contextlib
import io
import os
import tempfile
-import unittest
from datetime import datetime, timedelta
from unittest import mock
from unittest.mock import MagicMock
@@ -48,15 +47,15 @@ DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1),
timezone=timezone.utc)
# TODO: Check if tests needs side effects - locally there's missing DAG
-class TestCliDags(unittest.TestCase):
+class TestCliDags:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.dagbag = DagBag(include_examples=True)
cls.dagbag.sync_to_db()
cls.parser = cli_parser.get_parser()
@classmethod
- def tearDownClass(cls) -> None:
+ def teardown_class(cls) -> None:
clear_db_runs()
clear_db_dags()
diff --git a/tests/cli/commands/test_dag_processor_command.py
b/tests/cli/commands/test_dag_processor_command.py
index 23f9980cb6..8c42594d9f 100644
--- a/tests/cli/commands/test_dag_processor_command.py
+++ b/tests/cli/commands/test_dag_processor_command.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
import pytest
@@ -28,13 +27,13 @@ from airflow.configuration import conf
from tests.test_utils.config import conf_vars
-class TestDagProcessorCommand(unittest.TestCase):
+class TestDagProcessorCommand:
"""
Tests the CLI interface and that it correctly calls the DagProcessor
"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
@conf_vars(
diff --git a/tests/cli/commands/test_info_command.py
b/tests/cli/commands/test_info_command.py
index 5812742951..6a638b78f7 100644
--- a/tests/cli/commands/test_info_command.py
+++ b/tests/cli/commands/test_info_command.py
@@ -21,10 +21,8 @@ import importlib
import io
import logging
import os
-import unittest
import pytest
-from parameterized import parameterized
from rich.console import Console
from airflow.cli import cli_parser
@@ -42,15 +40,16 @@ def capture_show_output(instance):
return capture.get()
-class TestPiiAnonymizer(unittest.TestCase):
- def setUp(self) -> None:
+class TestPiiAnonymizer:
+ def setup_method(self) -> None:
self.instance = info_command.PiiAnonymizer()
def test_should_remove_pii_from_path(self):
home_path = os.path.expanduser("~/airflow/config")
assert "${HOME}/airflow/config" ==
self.instance.process_path(home_path)
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "before, after",
[
(
"postgresql+psycopg2://postgres:airflow@postgres/airflow",
@@ -68,7 +67,7 @@ class TestPiiAnonymizer(unittest.TestCase):
"postgresql+psycopg2://postgres/airflow",
"postgresql+psycopg2://postgres/airflow",
),
- ]
+ ],
)
def test_should_remove_pii_from_url(self, before, after):
assert after == self.instance.process_url(before)
@@ -77,7 +76,6 @@ class TestPiiAnonymizer(unittest.TestCase):
class TestAirflowInfo:
@classmethod
def setup_class(cls):
-
cls.parser = cli_parser.get_parser()
@classmethod
diff --git a/tests/cli/commands/test_jobs_command.py
b/tests/cli/commands/test_jobs_command.py
index f35d1fe8ab..3e97ea0f1e 100644
--- a/tests/cli/commands/test_jobs_command.py
+++ b/tests/cli/commands/test_jobs_command.py
@@ -18,7 +18,6 @@ from __future__ import annotations
import contextlib
import io
-import unittest
import pytest
@@ -30,16 +29,16 @@ from airflow.utils.state import State
from tests.test_utils.db import clear_db_jobs
-class TestCliConfigList(unittest.TestCase):
+class TestCliConfigList:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
- def setUp(self) -> None:
+ def setup_method(self) -> None:
clear_db_jobs()
self.scheduler_job = None
- def tearDown(self) -> None:
+ def teardown_method(self) -> None:
if self.scheduler_job and self.scheduler_job.processor_agent:
self.scheduler_job.processor_agent.end()
clear_db_jobs()
@@ -54,7 +53,7 @@ class TestCliConfigList(unittest.TestCase):
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
jobs_command.check(self.parser.parse_args(['jobs', 'check',
'--job-type', 'SchedulerJob']))
- self.assertIn("Found one alive job.", temp_stdout.getvalue())
+ assert "Found one alive job." in temp_stdout.getvalue()
def
test_should_report_success_for_one_working_scheduler_with_hostname(self):
with create_session() as session:
@@ -71,7 +70,7 @@ class TestCliConfigList(unittest.TestCase):
['jobs', 'check', '--job-type', 'SchedulerJob',
'--hostname', 'HOSTNAME']
)
)
- self.assertIn("Found one alive job.", temp_stdout.getvalue())
+ assert "Found one alive job." in temp_stdout.getvalue()
def test_should_report_success_for_ha_schedulers(self):
scheduler_jobs = []
@@ -90,7 +89,7 @@ class TestCliConfigList(unittest.TestCase):
['jobs', 'check', '--job-type', 'SchedulerJob', '--limit',
'100', '--allow-multiple']
)
)
- self.assertIn("Found 3 alive jobs.", temp_stdout.getvalue())
+ assert "Found 3 alive jobs." in temp_stdout.getvalue()
for scheduler_job in scheduler_jobs:
if scheduler_job.processor_agent:
scheduler_job.processor_agent.end()
diff --git a/tests/cli/commands/test_kerberos_command.py
b/tests/cli/commands/test_kerberos_command.py
index b64c6941fc..2acf43dd42 100644
--- a/tests/cli/commands/test_kerberos_command.py
+++ b/tests/cli/commands/test_kerberos_command.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
from airflow.cli import cli_parser
@@ -24,9 +23,9 @@ from airflow.cli.commands import kerberos_command
from tests.test_utils.config import conf_vars
-class TestKerberosCommand(unittest.TestCase):
+class TestKerberosCommand:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
@mock.patch('airflow.cli.commands.kerberos_command.krb')
diff --git a/tests/cli/commands/test_kubernetes_command.py
b/tests/cli/commands/test_kubernetes_command.py
index 8a0045f27d..d30cbb8ce7 100644
--- a/tests/cli/commands/test_kubernetes_command.py
+++ b/tests/cli/commands/test_kubernetes_command.py
@@ -18,7 +18,6 @@ from __future__ import annotations
import os
import tempfile
-import unittest
from unittest import mock
from unittest.mock import MagicMock, call
@@ -29,9 +28,9 @@ from airflow.cli import cli_parser
from airflow.cli.commands import kubernetes_command
-class TestGenerateDagYamlCommand(unittest.TestCase):
+class TestGenerateDagYamlCommand:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
def test_generate_dag_yaml(self):
@@ -56,12 +55,12 @@ class TestGenerateDagYamlCommand(unittest.TestCase):
assert os.stat(out_dir + file_name).st_size > 0
-class TestCleanUpPodsCommand(unittest.TestCase):
+class TestCleanUpPodsCommand:
label_selector = ','.join(['dag_id', 'task_id', 'try_number',
'airflow_version'])
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
@mock.patch('kubernetes.client.CoreV1Api.delete_namespaced_pod')
diff --git a/tests/cli/commands/test_legacy_commands.py
b/tests/cli/commands/test_legacy_commands.py
index 9cf927588b..7db800f856 100644
--- a/tests/cli/commands/test_legacy_commands.py
+++ b/tests/cli/commands/test_legacy_commands.py
@@ -18,7 +18,6 @@ from __future__ import annotations
import contextlib
import io
-import unittest
from argparse import ArgumentError
from unittest.mock import MagicMock
@@ -59,9 +58,9 @@ LEGACY_COMMANDS = [
]
-class TestCliDeprecatedCommandsValue(unittest.TestCase):
+class TestCliDeprecatedCommandsValue:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
def test_should_display_value(self):
diff --git a/tests/cli/commands/test_plugins_command.py
b/tests/cli/commands/test_plugins_command.py
index e3d4dc9e90..6b955eb9e4 100644
--- a/tests/cli/commands/test_plugins_command.py
+++ b/tests/cli/commands/test_plugins_command.py
@@ -19,7 +19,6 @@ from __future__ import annotations
import io
import json
import textwrap
-import unittest
from contextlib import redirect_stdout
from airflow.cli import cli_parser
@@ -40,9 +39,9 @@ class TestPlugin(AirflowPlugin):
hooks = [PluginHook]
-class TestPluginsCommand(unittest.TestCase):
+class TestPluginsCommand:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
@mock_plugin_manager(plugins=[])
@@ -118,4 +117,4 @@ class TestPluginsCommand(unittest.TestCase):
test-plugin-cli |
tests.cli.commands.test_plugins_command.PluginHook
"""
)
- self.assertEqual(stdout, expected_output)
+ assert stdout == expected_output
diff --git a/tests/cli/commands/test_rotate_fernet_key_command.py
b/tests/cli/commands/test_rotate_fernet_key_command.py
index 2423857cda..36da0760f5 100644
--- a/tests/cli/commands/test_rotate_fernet_key_command.py
+++ b/tests/cli/commands/test_rotate_fernet_key_command.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
from cryptography.fernet import Fernet
@@ -30,16 +29,16 @@ from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_connections, clear_db_variables
-class TestRotateFernetKeyCommand(unittest.TestCase):
+class TestRotateFernetKeyCommand:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
- def setUp(self) -> None:
+ def setup_method(self) -> None:
clear_db_connections(add_default_connections_back=False)
clear_db_variables()
- def tearDown(self) -> None:
+ def teardown_method(self) -> None:
clear_db_connections(add_default_connections_back=False)
clear_db_variables()
diff --git a/tests/cli/commands/test_scheduler_command.py
b/tests/cli/commands/test_scheduler_command.py
index a321e857b3..cb820f3583 100644
--- a/tests/cli/commands/test_scheduler_command.py
+++ b/tests/cli/commands/test_scheduler_command.py
@@ -17,12 +17,11 @@
# under the License.
from __future__ import annotations
-import unittest
from http.server import BaseHTTPRequestHandler
from unittest import mock
from unittest.mock import MagicMock
-from parameterized import parameterized
+import pytest
from airflow.cli import cli_parser
from airflow.cli.commands import scheduler_command
@@ -31,27 +30,28 @@ from airflow.utils.serve_logs import serve_logs
from tests.test_utils.config import conf_vars
-class TestSchedulerCommand(unittest.TestCase):
+class TestSchedulerCommand:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "executor, expect_serve_logs",
[
("CeleryExecutor", False),
("LocalExecutor", True),
("SequentialExecutor", True),
("KubernetesExecutor", False),
- ]
+ ],
)
@mock.patch("airflow.cli.commands.scheduler_command.SchedulerJob")
@mock.patch("airflow.cli.commands.scheduler_command.Process")
def test_serve_logs_on_scheduler(
self,
- executor,
- expect_serve_logs,
mock_process,
mock_scheduler_job,
+ executor,
+ expect_serve_logs,
):
args = self.parser.parse_args(['scheduler'])
@@ -60,33 +60,23 @@ class TestSchedulerCommand(unittest.TestCase):
if expect_serve_logs:
mock_process.assert_has_calls([mock.call(target=serve_logs)])
else:
- with self.assertRaises(AssertionError):
+ with pytest.raises(AssertionError):
mock_process.assert_has_calls([mock.call(target=serve_logs)])
- @parameterized.expand(
- [
- ("LocalExecutor",),
- ("SequentialExecutor",),
- ]
- )
@mock.patch("airflow.cli.commands.scheduler_command.SchedulerJob")
@mock.patch("airflow.cli.commands.scheduler_command.Process")
- def test_skip_serve_logs(self, executor, mock_process, mock_scheduler_job):
+ @pytest.mark.parametrize("executor", ["LocalExecutor",
"SequentialExecutor"])
+ def test_skip_serve_logs(self, mock_process, mock_scheduler_job, executor):
args = self.parser.parse_args(['scheduler', '--skip-serve-logs'])
with conf_vars({("core", "executor"): executor}):
scheduler_command.scheduler(args)
- with self.assertRaises(AssertionError):
+ with pytest.raises(AssertionError):
mock_process.assert_has_calls([mock.call(target=serve_logs)])
- @parameterized.expand(
- [
- ("LocalExecutor",),
- ("SequentialExecutor",),
- ]
- )
@mock.patch("airflow.cli.commands.scheduler_command.SchedulerJob")
@mock.patch("airflow.cli.commands.scheduler_command.Process")
- def test_graceful_shutdown(self, executor, mock_process,
mock_scheduler_job):
+ @pytest.mark.parametrize("executor", ["LocalExecutor",
"SequentialExecutor"])
+ def test_graceful_shutdown(self, mock_process, mock_scheduler_job,
executor):
args = self.parser.parse_args(['scheduler'])
with conf_vars({("core", "executor"): executor}):
mock_scheduler_job.run.side_effect = Exception('Mock exception to
trigger runtime error')
@@ -116,7 +106,7 @@ class TestSchedulerCommand(unittest.TestCase):
):
args = self.parser.parse_args(['scheduler'])
scheduler_command.scheduler(args)
- with self.assertRaises(AssertionError):
+ with pytest.raises(AssertionError):
mock_process.assert_has_calls([mock.call(target=serve_health_check)])
@@ -131,8 +121,8 @@ class MockServer(HealthServer):
super().do_GET()
-class TestSchedulerHealthServer(unittest.TestCase):
- def setUp(self) -> None:
+class TestSchedulerHealthServer:
+ def setup_method(self) -> None:
self.mock_server = MockServer()
@mock.patch.object(BaseHTTPRequestHandler, "send_error")
diff --git a/tests/cli/commands/test_sync_perm_command.py
b/tests/cli/commands/test_sync_perm_command.py
index f567a84ea5..bcc7e60d72 100644
--- a/tests/cli/commands/test_sync_perm_command.py
+++ b/tests/cli/commands/test_sync_perm_command.py
@@ -17,16 +17,15 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
from airflow.cli import cli_parser
from airflow.cli.commands import sync_perm_command
-class TestCliSyncPerm(unittest.TestCase):
+class TestCliSyncPerm:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
@mock.patch("airflow.cli.commands.sync_perm_command.cached_app")
diff --git a/tests/cli/commands/test_triggerer_command.py
b/tests/cli/commands/test_triggerer_command.py
index 6edee751da..51086323c4 100644
--- a/tests/cli/commands/test_triggerer_command.py
+++ b/tests/cli/commands/test_triggerer_command.py
@@ -17,20 +17,19 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
from airflow.cli import cli_parser
from airflow.cli.commands import triggerer_command
-class TestTriggererCommand(unittest.TestCase):
+class TestTriggererCommand:
"""
Tests the CLI interface and that it correctly calls the TriggererJob
"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
@mock.patch("airflow.cli.commands.triggerer_command.TriggererJob")
diff --git a/tests/cli/commands/test_variable_command.py
b/tests/cli/commands/test_variable_command.py
index 673323a317..ee61c67e31 100644
--- a/tests/cli/commands/test_variable_command.py
+++ b/tests/cli/commands/test_variable_command.py
@@ -20,7 +20,6 @@ from __future__ import annotations
import io
import os
import tempfile
-import unittest.mock
from contextlib import redirect_stdout
import pytest
@@ -32,16 +31,16 @@ from airflow.models import Variable
from tests.test_utils.db import clear_db_variables
-class TestCliVariables(unittest.TestCase):
+class TestCliVariables:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.dagbag = models.DagBag(include_examples=True)
cls.parser = cli_parser.get_parser()
- def setUp(self):
+ def setup_method(self):
clear_db_variables()
- def tearDown(self):
+ def teardown_method(self):
clear_db_variables()
def test_variables_set(self):
diff --git a/tests/cli/commands/test_version_command.py
b/tests/cli/commands/test_version_command.py
index ed9c655d4b..98a19010d4 100644
--- a/tests/cli/commands/test_version_command.py
+++ b/tests/cli/commands/test_version_command.py
@@ -17,7 +17,6 @@
from __future__ import annotations
import io
-import unittest
from contextlib import redirect_stdout
import airflow.cli.commands.version_command
@@ -25,9 +24,9 @@ from airflow.cli import cli_parser
from airflow.version import version
-class TestCliVersion(unittest.TestCase):
+class TestCliVersion:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
def test_cli_version(self):
diff --git a/tests/cli/commands/test_webserver_command.py
b/tests/cli/commands/test_webserver_command.py
index 9aa8e97885..4a4e70ce65 100644
--- a/tests/cli/commands/test_webserver_command.py
+++ b/tests/cli/commands/test_webserver_command.py
@@ -21,7 +21,6 @@ import subprocess
import sys
import tempfile
import time
-import unittest
from unittest import mock
import psutil
@@ -35,8 +34,8 @@ from airflow.utils.cli import setup_locations
from tests.test_utils.config import conf_vars
-class TestGunicornMonitor(unittest.TestCase):
- def setUp(self) -> None:
+class TestGunicornMonitor:
+ def setup_method(self) -> None:
self.monitor = GunicornMonitor(
gunicorn_master_pid=1,
num_workers_expected=4,
@@ -127,7 +126,7 @@ class TestGunicornMonitor(unittest.TestCase):
assert abs(self.monitor._last_refresh_time - time.monotonic()) < 5
-class TestGunicornMonitorGeneratePluginState(unittest.TestCase):
+class TestGunicornMonitorGeneratePluginState:
@staticmethod
def _prepare_test_file(filepath: str, size: int):
os.makedirs(os.path.dirname(filepath), exist_ok=True)
@@ -184,12 +183,12 @@ class
TestGunicornMonitorGeneratePluginState(unittest.TestCase):
assert 4 == len(state_d)
-class TestCLIGetNumReadyWorkersRunning(unittest.TestCase):
+class TestCLIGetNumReadyWorkersRunning:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parser = cli_parser.get_parser()
- def setUp(self):
+ def setup_method(self):
self.children = mock.MagicMock()
self.child = mock.MagicMock()
self.process = mock.MagicMock()
diff --git a/tests/cli/test_cli_parser.py b/tests/cli/test_cli_parser.py
index a3e8323b2e..26b917d096 100644
--- a/tests/cli/test_cli_parser.py
+++ b/tests/cli/test_cli_parser.py
@@ -23,10 +23,8 @@ import contextlib
import io
import re
from collections import Counter
-from unittest import TestCase
import pytest
-from parameterized import parameterized
from airflow.cli import cli_parser
from tests.test_utils.config import conf_vars
@@ -39,7 +37,7 @@ LEGAL_SHORT_OPTION_PATTERN = re.compile("^-[a-zA-z]$")
cli_args = {k: v for k, v in cli_parser.__dict__.items() if
k.startswith("ARG_")}
-class TestCli(TestCase):
+class TestCli:
def test_arg_option_long_only(self):
"""
Test if the name of cli.args long option valid
@@ -151,11 +149,11 @@ class TestCli(TestCase):
parser = cli_parser.get_parser(dag_parser=True)
with contextlib.redirect_stdout(io.StringIO()) as stdout:
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
parser.parse_args(['--help'])
stdout = stdout.getvalue()
- self.assertIn("Commands", stdout)
- self.assertIn("Groups", stdout)
+ assert "Commands" in stdout
+ assert "Groups" in stdout
def test_should_display_help(self):
parser = cli_parser.get_parser()
@@ -186,7 +184,7 @@ class TestCli(TestCase):
)
]
for cmd_args in all_command_as_args:
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
parser.parse_args([*cmd_args, '--help'])
def test_positive_int(self):
@@ -202,7 +200,7 @@ class TestCli(TestCase):
io.StringIO()
) as stderr:
parser = cli_parser.get_parser()
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
parser.parse_args(['celery'])
stderr = stderr.getvalue()
assert (
@@ -211,18 +209,19 @@ class TestCli(TestCase):
"your current executor: SequentialExecutor, subclassed from:
BaseExecutor, see help above."
) in stderr
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "executor",
[
"CeleryExecutor",
"CeleryKubernetesExecutor",
"custom_executor.CustomCeleryExecutor",
"custom_executor.CustomCeleryKubernetesExecutor",
- ]
+ ],
)
def test_dag_parser_celery_command_accept_celery_executor(self, executor):
with conf_vars({('core', 'executor'): executor}),
contextlib.redirect_stderr(io.StringIO()) as stderr:
parser = cli_parser.get_parser()
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
parser.parse_args(['celery'])
stderr = stderr.getvalue()
assert (
diff --git a/tests/core/test_config_templates.py
b/tests/core/test_config_templates.py
index 1293a24dda..ecef25a4a4 100644
--- a/tests/core/test_config_templates.py
+++ b/tests/core/test_config_templates.py
@@ -18,9 +18,8 @@ from __future__ import annotations
import configparser
import os
-import unittest
-from parameterized import parameterized
+import pytest
from tests.test_utils import AIRFLOW_MAIN_FOLDER
@@ -70,19 +69,15 @@ DEFAULT_TEST_SECTIONS = [
]
-class TestAirflowCfg(unittest.TestCase):
- @parameterized.expand(
- [
- ("default_airflow.cfg",),
- ("default_test.cfg",),
- ]
- )
+class TestAirflowCfg:
+ @pytest.mark.parametrize("filename", ["default_airflow.cfg",
"default_test.cfg"])
def test_should_be_ascii_file(self, filename: str):
with open(os.path.join(CONFIG_TEMPLATES_FOLDER, filename), "rb") as f:
content = f.read().decode("ascii")
assert content
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "filename, expected_sections",
[
(
"default_airflow.cfg",
@@ -92,7 +87,7 @@ class TestAirflowCfg(unittest.TestCase):
"default_test.cfg",
DEFAULT_TEST_SECTIONS,
),
- ]
+ ],
)
def test_should_be_ini_file(self, filename: str, expected_sections):
filepath = os.path.join(CONFIG_TEMPLATES_FOLDER, filename)
diff --git a/tests/core/test_logging_config.py
b/tests/core/test_logging_config.py
index 7cd51e78a8..4d69ac59de 100644
--- a/tests/core/test_logging_config.py
+++ b/tests/core/test_logging_config.py
@@ -24,11 +24,9 @@ import os
import pathlib
import sys
import tempfile
-import unittest
from unittest.mock import patch
import pytest
-from parameterized import parameterized
from airflow.configuration import conf
from tests.test_utils.config import conf_vars
@@ -170,12 +168,12 @@ def settings_context(content, directory=None,
name='LOGGING_CONFIG'):
sys.path.remove(settings_root)
-class TestLoggingSettings(unittest.TestCase):
+class TestLoggingSettings:
# Make sure that the configure_logging is not cached
- def setUp(self):
+ def setup_method(self):
self.old_modules = dict(sys.modules)
- def tearDown(self):
+ def teardown_method(self):
# Remove any new modules imported during the test run. This lets us
# import the same source files for more than one test.
from airflow.config_templates import airflow_local_settings
@@ -281,7 +279,8 @@ class TestLoggingSettings(unittest.TestCase):
logger = logging.getLogger('airflow.task')
assert isinstance(logger.handlers[0], WasbTaskHandler)
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "remote_base_log_folder, log_group_arn",
[
(
'cloudwatch://arn:aws:logs:aaaa:bbbbb:log-group:ccccc',
@@ -295,7 +294,7 @@ class TestLoggingSettings(unittest.TestCase):
'cloudwatch://arn:aws:logs:aaaa:bbbbb:log-group:/aws/ecs/ccccc',
'arn:aws:logs:aaaa:bbbbb:log-group:/aws/ecs/ccccc',
),
- ]
+ ],
)
def test_log_group_arns_remote_logging_with_cloudwatch_handler(
self, remote_base_log_folder, log_group_arn
diff --git a/tests/core/test_settings.py b/tests/core/test_settings.py
index 54ae31eb55..2292795343 100644
--- a/tests/core/test_settings.py
+++ b/tests/core/test_settings.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import os
import sys
import tempfile
-import unittest
+from unittest import mock
from unittest.mock import MagicMock, call
import pytest
@@ -78,32 +78,32 @@ class SettingsContext:
sys.path.remove(self.settings_root)
-class TestLocalSettings(unittest.TestCase):
+class TestLocalSettings:
# Make sure that the configure_logging is not cached
- def setUp(self):
+ def setup_method(self):
self.old_modules = dict(sys.modules)
- def tearDown(self):
+ def teardown_method(self):
# Remove any new modules imported during the test run. This lets us
# import the same source files for more than one test.
for mod in [m for m in sys.modules if m not in self.old_modules]:
del sys.modules[mod]
- @unittest.mock.patch("airflow.settings.import_local_settings")
- @unittest.mock.patch("airflow.settings.prepare_syspath")
+ @mock.patch("airflow.settings.import_local_settings")
+ @mock.patch("airflow.settings.prepare_syspath")
def test_initialize_order(self, prepare_syspath, import_local_settings):
"""
Tests that import_local_settings is called after prepare_classpath
"""
- mock = unittest.mock.Mock()
- mock.attach_mock(prepare_syspath, "prepare_syspath")
- mock.attach_mock(import_local_settings, "import_local_settings")
+ mock_local_settings = mock.Mock()
+ mock_local_settings.attach_mock(prepare_syspath, "prepare_syspath")
+ mock_local_settings.attach_mock(import_local_settings,
"import_local_settings")
import airflow.settings
airflow.settings.initialize()
- mock.assert_has_calls([call.prepare_syspath(),
call.import_local_settings()])
+ mock_local_settings.assert_has_calls([call.prepare_syspath(),
call.import_local_settings()])
def test_import_with_dunder_all_not_specified(self):
"""
@@ -133,7 +133,7 @@ class TestLocalSettings(unittest.TestCase):
assert task_instance.run_as_user == "myself"
- @unittest.mock.patch("airflow.settings.log.debug")
+ @mock.patch("airflow.settings.log.debug")
def test_import_local_settings_without_syspath(self, log_mock):
"""
Tests that an ImportError is raised in import_local_settings
@@ -186,7 +186,7 @@ class TestLocalSettings(unittest.TestCase):
settings.task_must_have_owners(task_instance)
-class TestUpdatedConfigNames(unittest.TestCase):
+class TestUpdatedConfigNames:
@conf_vars(
{("webserver", "session_lifetime_days"): '5', ("webserver",
"session_lifetime_minutes"): '43200'}
)
diff --git a/tests/core/test_sqlalchemy_config.py
b/tests/core/test_sqlalchemy_config.py
index 12e774a0f4..8000edb106 100644
--- a/tests/core/test_sqlalchemy_config.py
+++ b/tests/core/test_sqlalchemy_config.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest.mock import patch
import pytest
@@ -30,14 +29,14 @@ from tests.test_utils.config import conf_vars
SQL_ALCHEMY_CONNECT_ARGS = {'test': 43503, 'dict': {'is': 1, 'supported':
'too'}}
-class TestSqlAlchemySettings(unittest.TestCase):
- def setUp(self):
+class TestSqlAlchemySettings:
+ def setup_method(self):
self.old_engine = settings.engine
self.old_session = settings.Session
self.old_conn = settings.SQL_ALCHEMY_CONN
settings.SQL_ALCHEMY_CONN =
"mysql+foobar://user:pass@host/dbname?inline=param&another=param"
- def tearDown(self):
+ def teardown_method(self):
settings.engine = self.old_engine
settings.Session = self.old_session
settings.SQL_ALCHEMY_CONN = self.old_conn
diff --git a/tests/core/test_stats.py b/tests/core/test_stats.py
index eae5a36e35..3373bbfac7 100644
--- a/tests/core/test_stats.py
+++ b/tests/core/test_stats.py
@@ -19,7 +19,6 @@ from __future__ import annotations
import importlib
import re
-import unittest
from unittest import mock
from unittest.mock import Mock
@@ -46,8 +45,8 @@ class InvalidCustomStatsd:
pass
-class TestStats(unittest.TestCase):
- def setUp(self):
+class TestStats:
+ def setup_method(self):
self.statsd_client = Mock(spec=statsd.StatsClient)
self.stats = SafeStatsdLogger(self.statsd_client)
@@ -128,8 +127,8 @@ class TestStats(unittest.TestCase):
importlib.reload(airflow.stats)
-class TestDogStats(unittest.TestCase):
- def setUp(self):
+class TestDogStats:
+ def setup_method(self):
pytest.importorskip('datadog')
from datadog import DogStatsd
@@ -167,6 +166,7 @@ class TestDogStats(unittest.TestCase):
)
def
test_does_send_stats_using_dogstatsd_when_statsd_and_dogstatsd_both_on(self):
+ # ToDo: Figure out why it is identical to
test_does_send_stats_using_dogstatsd_when_dogstatsd_on
self.dogstatsd.incr("empty_key")
self.dogstatsd_client.increment.assert_called_once_with(
metric='empty_key', sample_rate=1, tags=[], value=1
@@ -222,8 +222,8 @@ class TestDogStats(unittest.TestCase):
importlib.reload(airflow.stats)
-class TestStatsWithAllowList(unittest.TestCase):
- def setUp(self):
+class TestStatsWithAllowList:
+ def setup_method(self):
self.statsd_client = Mock(spec=statsd.StatsClient)
self.stats = SafeStatsdLogger(self.statsd_client,
AllowListValidator("stats_one, stats_two"))
@@ -240,8 +240,8 @@ class TestStatsWithAllowList(unittest.TestCase):
self.statsd_client.assert_not_called()
-class TestDogStatsWithAllowList(unittest.TestCase):
- def setUp(self):
+class TestDogStatsWithAllowList:
+ def setup_method(self):
pytest.importorskip('datadog')
from datadog import DogStatsd
@@ -273,7 +273,7 @@ def always_valid(stat_name):
return stat_name
-class TestCustomStatsName(unittest.TestCase):
+class TestCustomStatsName:
@conf_vars(
{
('metrics', 'statsd_on'): 'True',
@@ -324,6 +324,6 @@ class TestCustomStatsName(unittest.TestCase):
metric='empty_key', sample_rate=1, tags=[], value=1
)
- def tearDown(self) -> None:
+ def teardown_method(self) -> None:
# To avoid side-effect
importlib.reload(airflow.stats)
diff --git a/tests/dag_processing/test_manager.py
b/tests/dag_processing/test_manager.py
index 458ab18f10..b4446a0e60 100644
--- a/tests/dag_processing/test_manager.py
+++ b/tests/dag_processing/test_manager.py
@@ -25,7 +25,6 @@ import random
import socket
import sys
import threading
-import unittest
from datetime import datetime, timedelta
from logging.config import dictConfig
from tempfile import TemporaryDirectory
@@ -1059,12 +1058,12 @@ class TestDagFileProcessorManager:
assert manager._callback_to_execute[dag1_req1.full_filepath] ==
[dag1_req1, dag1_sla1, dag1_req2]
-class TestDagFileProcessorAgent(unittest.TestCase):
- def setUp(self):
+class TestDagFileProcessorAgent:
+ def setup_method(self):
# Make sure that the configure_logging is not cached
self.old_modules = dict(sys.modules)
- def tearDown(self):
+ def teardown_method(self):
# Remove any new modules imported during the test run. This lets us
# import the same source files for more than one test.
remove_list = []
diff --git a/tests/executors/test_dask_executor.py
b/tests/executors/test_dask_executor.py
index 51c1256993..df83e354a6 100644
--- a/tests/executors/test_dask_executor.py
+++ b/tests/executors/test_dask_executor.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from datetime import timedelta
from unittest import mock
@@ -55,7 +54,7 @@ skip_dask_tests = False
@pytest.mark.skipif(skip_dask_tests, reason="The tests are skipped because it
needs testing from Dask team")
-class TestBaseDask(unittest.TestCase):
+class TestBaseDask:
def assert_tasks_on_executor(self, executor, timeout_executor=120):
# start the executor
@@ -87,7 +86,7 @@ class TestBaseDask(unittest.TestCase):
@pytest.mark.skipif(skip_dask_tests, reason="The tests are skipped because it
needs testing from Dask team")
class TestDaskExecutor(TestBaseDask):
- def setUp(self):
+ def setup_method(self):
self.dagbag = DagBag(include_examples=True)
self.cluster = LocalCluster()
@@ -110,7 +109,7 @@ class TestDaskExecutor(TestBaseDask):
)
job.run()
- def tearDown(self):
+ def teardown_method(self):
self.cluster.close(timeout=5)
@@ -118,7 +117,7 @@ class TestDaskExecutor(TestBaseDask):
skip_tls_tests, reason="The tests are skipped because distributed
framework could not be imported"
)
class TestDaskExecutorTLS(TestBaseDask):
- def setUp(self):
+ def setup_method(self):
self.dagbag = DagBag(include_examples=True)
@conf_vars(
@@ -160,13 +159,13 @@ class TestDaskExecutorTLS(TestBaseDask):
@pytest.mark.skipif(skip_dask_tests, reason="The tests are skipped because it
needs testing from Dask team")
-class TestDaskExecutorQueue(unittest.TestCase):
+class TestDaskExecutorQueue:
def test_dask_queues_no_resources(self):
self.cluster = LocalCluster()
executor = DaskExecutor(cluster_address=self.cluster.scheduler_address)
executor.start()
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
executor.execute_async(key='success', command=SUCCESS_COMMAND,
queue='queue1')
def test_dask_queues_not_available(self):
@@ -174,7 +173,7 @@ class TestDaskExecutorQueue(unittest.TestCase):
executor = DaskExecutor(cluster_address=self.cluster.scheduler_address)
executor.start()
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
# resource 'queue2' doesn't exist on cluster
executor.execute_async(key='success', command=SUCCESS_COMMAND,
queue='queue2')
@@ -219,5 +218,5 @@ class TestDaskExecutorQueue(unittest.TestCase):
assert success_future.done()
assert success_future.exception() is None
- def tearDown(self):
+ def teardown_method(self):
self.cluster.close(timeout=5)
diff --git a/tests/executors/test_executor_loader.py
b/tests/executors/test_executor_loader.py
index fec7391807..180e7b961c 100644
--- a/tests/executors/test_executor_loader.py
+++ b/tests/executors/test_executor_loader.py
@@ -16,10 +16,9 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
-from parameterized import parameterized
+import pytest
from airflow import plugins_manager
from airflow.executors.executor_loader import ExecutorLoader
@@ -38,21 +37,22 @@ class FakePlugin(plugins_manager.AirflowPlugin):
executors = [FakeExecutor]
-class TestExecutorLoader(unittest.TestCase):
- def setUp(self) -> None:
+class TestExecutorLoader:
+ def setup_method(self) -> None:
ExecutorLoader._default_executor = None
- def tearDown(self) -> None:
+ def teardown_method(self) -> None:
ExecutorLoader._default_executor = None
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "executor_name",
[
- ("CeleryExecutor",),
- ("CeleryKubernetesExecutor",),
- ("DebugExecutor",),
- ("KubernetesExecutor",),
- ("LocalExecutor",),
- ]
+ "CeleryExecutor",
+ "CeleryKubernetesExecutor",
+ "DebugExecutor",
+ "KubernetesExecutor",
+ "LocalExecutor",
+ ],
)
def test_should_support_executor_from_core(self, executor_name):
with conf_vars({("core", "executor"): executor_name}):
diff --git a/tests/executors/test_kubernetes_executor.py
b/tests/executors/test_kubernetes_executor.py
index 574bd0afad..48843cd098 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -55,7 +55,7 @@ except ImportError:
AirflowKubernetesScheduler = None # type: ignore
-class TestAirflowKubernetesScheduler(unittest.TestCase):
+class TestAirflowKubernetesScheduler:
@staticmethod
def _gen_random_string(seed, str_len):
char_list = []
@@ -895,8 +895,8 @@ class TestKubernetesExecutor:
assert mock_kube_client.list_namespaced_pod.call_count == 0
-class TestKubernetesJobWatcher(unittest.TestCase):
- def setUp(self):
+class TestKubernetesJobWatcher:
+ def setup_method(self):
self.watcher = KubernetesJobWatcher(
namespace="airflow",
multi_namespace_mode=False,
@@ -1009,12 +1009,12 @@ class TestKubernetesJobWatcher(unittest.TestCase):
self.pod.status.phase = 'Pending'
raw_object = {"code": 422, "message": message, "reason": "Test"}
self.events.append({"type": "ERROR", "object": self.pod, "raw_object":
raw_object})
- with self.assertRaises(AirflowException) as e:
- self._run()
- assert str(e.exception) == (
- f"Kubernetes failure for {raw_object['reason']} "
- f"with code {raw_object['code']} and message:
{raw_object['message']}"
+ error_message = (
+ fr"Kubernetes failure for {raw_object['reason']} "
+ fr"with code {raw_object['code']} and message:
{raw_object['message']}"
)
+ with pytest.raises(AirflowException, match=error_message):
+ self._run()
def test_recover_from_resource_too_old(self):
# too old resource
diff --git a/tests/executors/test_local_executor.py
b/tests/executors/test_local_executor.py
index a5e8bbbaec..cf7f37b1dd 100644
--- a/tests/executors/test_local_executor.py
+++ b/tests/executors/test_local_executor.py
@@ -19,7 +19,6 @@ from __future__ import annotations
import datetime
import subprocess
-import unittest
from unittest import mock
from airflow import settings
@@ -28,7 +27,7 @@ from airflow.executors.local_executor import LocalExecutor
from airflow.utils.state import State
-class TestLocalExecutor(unittest.TestCase):
+class TestLocalExecutor:
TEST_SUCCESS_COMMANDS = 5
diff --git a/tests/executors/test_sequential_executor.py
b/tests/executors/test_sequential_executor.py
index 0e016dbf92..e52281ff80 100644
--- a/tests/executors/test_sequential_executor.py
+++ b/tests/executors/test_sequential_executor.py
@@ -17,13 +17,12 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
from airflow.executors.sequential_executor import SequentialExecutor
-class TestSequentialExecutor(unittest.TestCase):
+class TestSequentialExecutor:
@mock.patch('airflow.executors.sequential_executor.SequentialExecutor.sync')
@mock.patch('airflow.executors.base_executor.BaseExecutor.trigger_tasks')
@mock.patch('airflow.executors.base_executor.Stats.gauge')
diff --git a/tests/hooks/test_subprocess.py b/tests/hooks/test_subprocess.py
index 56f010c733..c3df0e5e20 100644
--- a/tests/hooks/test_subprocess.py
+++ b/tests/hooks/test_subprocess.py
@@ -17,14 +17,13 @@
# under the License.
from __future__ import annotations
-import unittest
from pathlib import Path
from subprocess import PIPE, STDOUT
from tempfile import TemporaryDirectory
from unittest import mock
from unittest.mock import MagicMock
-from parameterized import parameterized
+import pytest
from airflow.hooks.subprocess import SubprocessHook
@@ -32,15 +31,17 @@ OS_ENV_KEY = 'SUBPROCESS_ENV_TEST'
OS_ENV_VAL = 'this-is-from-os-environ'
-class TestSubprocessHook(unittest.TestCase):
- @parameterized.expand(
+class TestSubprocessHook:
+ @pytest.mark.parametrize(
+ "env,expected",
[
- ('with env', {'ABC': '123', 'AAA': '456'}, {'ABC': '123', 'AAA':
'456', OS_ENV_KEY: ''}),
- ('empty env', {}, {OS_ENV_KEY: ''}),
- ('no env', None, {OS_ENV_KEY: OS_ENV_VAL}),
- ]
+ ({"ABC": "123", "AAA": "456"}, {"ABC": "123", "AAA": "456",
OS_ENV_KEY: ""}),
+ ({}, {OS_ENV_KEY: ""}),
+ (None, {OS_ENV_KEY: OS_ENV_VAL}),
+ ],
+ ids=["with env", "empty env", "no env"],
)
- def test_env(self, name, env, expected):
+ def test_env(self, env, expected):
"""
Test that env variables are exported correctly to the command
environment.
When ``env`` is ``None``, ``os.environ`` should be passed to ``Popen``.
@@ -63,13 +64,14 @@ class TestSubprocessHook(unittest.TestCase):
actual = dict([x.split('=') for x in
tmp_file.read_text().splitlines()])
assert actual == expected
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "val,expected",
[
('test-val', 'test-val'),
('test-val\ntest-val\n', ''),
('test-val\ntest-val', 'test-val'),
('', ''),
- ]
+ ],
)
def test_return_value(self, val, expected):
hook = SubprocessHook()
diff --git a/tests/kubernetes/models/test_secret.py
b/tests/kubernetes/models/test_secret.py
index 882b4850aa..90636ae4f9 100644
--- a/tests/kubernetes/models/test_secret.py
+++ b/tests/kubernetes/models/test_secret.py
@@ -17,7 +17,6 @@
from __future__ import annotations
import sys
-import unittest
import uuid
from unittest import mock
@@ -28,7 +27,7 @@ from airflow.kubernetes.pod_generator import PodGenerator
from airflow.kubernetes.secret import Secret
-class TestSecret(unittest.TestCase):
+class TestSecret:
def test_to_env_secret(self):
secret = Secret('env', 'name', 'secret', 'key')
assert secret.to_env_secret() == k8s.V1EnvVar(
diff --git a/tests/kubernetes/test_client.py b/tests/kubernetes/test_client.py
index 2f702adf2f..13fa8a3417 100644
--- a/tests/kubernetes/test_client.py
+++ b/tests/kubernetes/test_client.py
@@ -17,7 +17,6 @@
from __future__ import annotations
import socket
-import unittest
from unittest import mock
from kubernetes.client import Configuration
@@ -26,7 +25,7 @@ from urllib3.connection import HTTPConnection, HTTPSConnection
from airflow.kubernetes.kube_client import _disable_verify_ssl,
_enable_tcp_keepalive, get_kube_client
-class TestClient(unittest.TestCase):
+class TestClient:
@mock.patch('airflow.kubernetes.kube_client.config')
def test_load_cluster_config(self, config):
get_kube_client(in_cluster=True)
@@ -50,7 +49,7 @@ class TestClient(unittest.TestCase):
configuration = Configuration.get_default_copy()
else:
configuration = Configuration()
- self.assertFalse(configuration.verify_ssl)
+ assert not configuration.verify_ssl
def test_enable_tcp_keepalive(self):
socket_options = [
@@ -69,7 +68,7 @@ class TestClient(unittest.TestCase):
def test_disable_verify_ssl(self):
configuration = Configuration()
- self.assertTrue(configuration.verify_ssl)
+ assert configuration.verify_ssl
_disable_verify_ssl()
@@ -78,4 +77,4 @@ class TestClient(unittest.TestCase):
configuration = Configuration.get_default_copy()
else:
configuration = Configuration()
- self.assertFalse(configuration.verify_ssl)
+ assert not configuration.verify_ssl
diff --git a/tests/macros/test_hive.py b/tests/macros/test_hive.py
index c0a47c794a..f231724620 100644
--- a/tests/macros/test_hive.py
+++ b/tests/macros/test_hive.py
@@ -17,13 +17,12 @@
# under the License.
from __future__ import annotations
-import unittest
from datetime import datetime
from airflow.macros import hive
-class TestHive(unittest.TestCase):
+class TestHive:
def test_closest_ds_partition(self):
date1 = datetime.strptime('2017-04-24', '%Y-%m-%d')
date2 = datetime.strptime('2017-04-25', '%Y-%m-%d')
diff --git a/tests/models/test_dagcode.py b/tests/models/test_dagcode.py
index fe53601a3e..6f3c5d64cd 100644
--- a/tests/models/test_dagcode.py
+++ b/tests/models/test_dagcode.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from datetime import timedelta
from unittest.mock import patch
@@ -39,13 +38,13 @@ def make_example_dags(module):
return dagbag.dags
-class TestDagCode(unittest.TestCase):
+class TestDagCode:
"""Unit tests for DagCode."""
- def setUp(self):
+ def setup_method(self):
clear_db_dag_code()
- def tearDown(self):
+ def teardown_method(self):
clear_db_dag_code()
def _write_two_example_dags(self):
diff --git a/tests/models/test_param.py b/tests/models/test_param.py
index fcfcda4f71..b6b565490e 100644
--- a/tests/models/test_param.py
+++ b/tests/models/test_param.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations
-import unittest
from contextlib import nullcontext
import pytest
@@ -29,7 +28,7 @@ from airflow.utils.types import DagRunType
from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom
-class TestParam(unittest.TestCase):
+class TestParam:
def test_param_without_schema(self):
p = Param('test')
assert p.resolve() == 'test'
@@ -47,7 +46,6 @@ class TestParam(unittest.TestCase):
assert p.resolve() is None
assert p.resolve(None) is None
- p = Param(type="null")
p = Param(None, type='null')
assert p.resolve() is None
assert p.resolve(None) is None
diff --git a/tests/operators/test_branch_operator.py
b/tests/operators/test_branch_operator.py
index efc05a647e..64081fbc39 100644
--- a/tests/operators/test_branch_operator.py
+++ b/tests/operators/test_branch_operator.py
@@ -18,7 +18,6 @@
from __future__ import annotations
import datetime
-import unittest
from airflow.models import DAG, DagRun, TaskInstance as TI
from airflow.operators.branch import BaseBranchOperator
@@ -42,16 +41,14 @@ class ChooseBranchOneTwo(BaseBranchOperator):
return ['branch_1', 'branch_2']
-class TestBranchOperator(unittest.TestCase):
+class TestBranchOperator:
@classmethod
- def setUpClass(cls):
- super().setUpClass()
-
+ def setup_class(cls):
with create_session() as session:
session.query(DagRun).delete()
session.query(TI).delete()
- def setUp(self):
+ def setup_method(self):
self.dag = DAG(
'branch_operator_test',
default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
@@ -63,9 +60,7 @@ class TestBranchOperator(unittest.TestCase):
self.branch_3 = None
self.branch_op = None
- def tearDown(self):
- super().tearDown()
-
+ def teardown_method(self):
with create_session() as session:
session.query(DagRun).delete()
session.query(TI).delete()
diff --git a/tests/operators/test_trigger_dagrun.py
b/tests/operators/test_trigger_dagrun.py
index d9b88fedb7..b396caf43f 100644
--- a/tests/operators/test_trigger_dagrun.py
+++ b/tests/operators/test_trigger_dagrun.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import pathlib
import tempfile
from datetime import datetime
-from unittest import TestCase, mock
+from unittest import mock
import pytest
@@ -49,8 +49,8 @@ DAG_SCRIPT = (
).format(dag_id=TRIGGERED_DAG_ID)
-class TestDagRunOperator(TestCase):
- def setUp(self):
+class TestDagRunOperator:
+ def setup_method(self):
# Airflow relies on reading the DAG from disk when triggering it.
# Therefore write a temp file holding the DAG to trigger.
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
@@ -67,7 +67,7 @@ class TestDagRunOperator(TestCase):
dagbag.bag_dag(self.dag, root_dag=self.dag)
dagbag.sync_to_db()
- def tearDown(self):
+ def teardown_method(self):
"""Cleanup state after testing in DB."""
with create_session() as session:
session.query(Log).filter(Log.dag_id ==
TEST_DAG_ID).delete(synchronize_session=False)
diff --git a/tests/operators/test_weekday.py b/tests/operators/test_weekday.py
index bfda7e83fd..45ebe7bc8d 100644
--- a/tests/operators/test_weekday.py
+++ b/tests/operators/test_weekday.py
@@ -18,7 +18,6 @@
from __future__ import annotations
import datetime
-import unittest
import pytest
from freezegun import freeze_time
@@ -35,21 +34,36 @@ from airflow.utils.weekday import WeekDay
DEFAULT_DATE = timezone.datetime(2020, 2, 5) # Wednesday
INTERVAL = datetime.timedelta(hours=12)
-
-
-class TestBranchDayOfWeekOperator(unittest.TestCase):
+TEST_CASE_BRANCH_FOLLOW_TRUE = {
+ "with-string": "Monday",
+ "with-enum": WeekDay.MONDAY,
+ "with-enum-set": {WeekDay.MONDAY},
+ "with-enum-list": [WeekDay.MONDAY],
+ "with-enum-dict": {WeekDay.MONDAY: "some_value"},
+ "with-enum-set-2-items": {WeekDay.MONDAY, WeekDay.FRIDAY},
+ "with-enum-list-2-items": [WeekDay.MONDAY, WeekDay.FRIDAY],
+ "with-enum-dict-2-items": {WeekDay.MONDAY: "some_value", WeekDay.FRIDAY:
"some_value_2"},
+ "with-string-set": {"Monday"},
+ "with-string-set-2-items": {"Monday", "Friday"},
+ "with-set-mix-types": {"Monday", WeekDay.FRIDAY},
+ "with-list-mix-types": ["Monday", WeekDay.FRIDAY],
+ "with-dict-mix-types": {"Monday": "some_value", WeekDay.FRIDAY:
"some_value_2"},
+}
+
+
+class TestBranchDayOfWeekOperator:
"""
Tests for BranchDayOfWeekOperator
"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
with create_session() as session:
session.query(DagRun).delete()
session.query(TI).delete()
session.query(XCom).delete()
- def setUp(self):
+ def setup_method(self):
self.dag = DAG(
"branch_day_of_week_operator_test",
start_date=DEFAULT_DATE,
@@ -59,7 +73,7 @@ class TestBranchDayOfWeekOperator(unittest.TestCase):
self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag)
self.branch_3 = None
- def tearDown(self):
+ def teardown_method(self):
with create_session() as session:
session.query(DagRun).delete()
session.query(TI).delete()
@@ -74,31 +88,14 @@ class TestBranchDayOfWeekOperator(unittest.TestCase):
except KeyError:
raise ValueError(f'Invalid task id {ti.task_id} found!')
else:
- self.assertEqual(
- ti.state,
- expected_state,
- f"Task {ti.task_id} has state {ti.state} instead of
expected {expected_state}",
- )
+ assert_msg = f"Task {ti.task_id} has state {ti.state} instead
of expected {expected_state}"
+ assert ti.state == expected_state, assert_msg
- @parameterized.expand(
- [
- ("with-string", "Monday"),
- ("with-enum", WeekDay.MONDAY),
- ("with-enum-set", {WeekDay.MONDAY}),
- ("with-enum-list", [WeekDay.MONDAY]),
- ("with-enum-dict", {WeekDay.MONDAY: "some_value"}),
- ("with-enum-set-2-items", {WeekDay.MONDAY, WeekDay.FRIDAY}),
- ("with-enum-list-2-items", [WeekDay.MONDAY, WeekDay.FRIDAY]),
- ("with-enum-dict-2-items", {WeekDay.MONDAY: "some_value",
WeekDay.FRIDAY: "some_value_2"}),
- ("with-string-set", {"Monday"}),
- ("with-string-set-2-items", {"Monday", "Friday"}),
- ("with-set-mix-types", {"Monday", WeekDay.FRIDAY}),
- ("with-list-mix-types", ["Monday", WeekDay.FRIDAY]),
- ("with-dict-mix-types", {"Monday": "some_value", WeekDay.FRIDAY:
"some_value_2"}),
- ]
+ @pytest.mark.parametrize(
+ "weekday", TEST_CASE_BRANCH_FOLLOW_TRUE.values(),
ids=TEST_CASE_BRANCH_FOLLOW_TRUE.keys()
)
@freeze_time("2021-01-25") # Monday
- def test_branch_follow_true(self, _, weekday):
+ def test_branch_follow_true(self, weekday):
"""Checks if BranchDayOfWeekOperator follows true branch"""
print(datetime.datetime.now())
branch_op = BranchDayOfWeekOperator(
@@ -205,7 +202,7 @@ class TestBranchDayOfWeekOperator(unittest.TestCase):
def test_branch_with_no_weekday(self):
"""Check if BranchDayOfWeekOperator raises exception on missing
weekday"""
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
BranchDayOfWeekOperator(
task_id="make_choice",
follow_task_ids_if_true="branch_1",
diff --git a/tests/plugins/test_plugin_ignore.py
b/tests/plugins/test_plugin_ignore.py
index 3db8f44af3..8bff057930 100644
--- a/tests/plugins/test_plugin_ignore.py
+++ b/tests/plugins/test_plugin_ignore.py
@@ -20,19 +20,18 @@ from __future__ import annotations
import os
import shutil
import tempfile
-import unittest
from unittest.mock import patch
from airflow import settings
from airflow.utils.file import find_path_from_directory
-class TestIgnorePluginFile(unittest.TestCase):
+class TestIgnorePluginFile:
"""
Test that the .airflowignore work and whether the file is properly ignored.
"""
- def setUp(self):
+ def setup_method(self):
"""
Make tmp folder and files that should be ignored. And set base path.
"""
@@ -64,7 +63,7 @@ class TestIgnorePluginFile(unittest.TestCase):
settings, 'PLUGINS_FOLDER', return_value=self.plugin_folder_path
)
- def tearDown(self):
+ def teardown_method(self):
"""
Delete tmp folder
"""
diff --git a/tests/sensors/test_bash.py b/tests/sensors/test_bash.py
index 460e7aa0a2..6cef34a5ff 100644
--- a/tests/sensors/test_bash.py
+++ b/tests/sensors/test_bash.py
@@ -18,7 +18,6 @@
from __future__ import annotations
import datetime
-import unittest
import pytest
@@ -27,8 +26,8 @@ from airflow.models.dag import DAG
from airflow.sensors.bash import BashSensor
-class TestBashSensor(unittest.TestCase):
- def setUp(self):
+class TestBashSensor:
+ def setup_method(self):
args = {'owner': 'airflow', 'start_date': datetime.datetime(2017, 1,
1)}
dag = DAG('test_dag_id', default_args=args)
self.dag = dag
diff --git a/tests/sensors/test_filesystem.py b/tests/sensors/test_filesystem.py
index 57b462fa3b..0305263caf 100644
--- a/tests/sensors/test_filesystem.py
+++ b/tests/sensors/test_filesystem.py
@@ -20,7 +20,6 @@ from __future__ import annotations
import os
import shutil
import tempfile
-import unittest
import pytest
@@ -33,8 +32,8 @@ TEST_DAG_ID = 'unit_tests_file_sensor'
DEFAULT_DATE = datetime(2015, 1, 1)
-class TestFileSensor(unittest.TestCase):
- def setUp(self):
+class TestFileSensor:
+ def setup_method(self):
from airflow.hooks.filesystem import FSHook
hook = FSHook()
diff --git a/tests/sensors/test_time_delta.py b/tests/sensors/test_time_delta.py
index 27b95230ab..9c9e256a12 100644
--- a/tests/sensors/test_time_delta.py
+++ b/tests/sensors/test_time_delta.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from datetime import timedelta
from airflow.models import DagBag
@@ -30,8 +29,8 @@ DEV_NULL = '/dev/null'
TEST_DAG_ID = 'unit_tests'
-class TestTimedeltaSensor(unittest.TestCase):
- def setUp(self):
+class TestTimedeltaSensor:
+ def setup_method(self):
self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
self.dag = DAG(TEST_DAG_ID, default_args=self.args)
diff --git a/tests/sensors/test_timeout_sensor.py
b/tests/sensors/test_timeout_sensor.py
index 675a686a31..59798c827f 100644
--- a/tests/sensors/test_timeout_sensor.py
+++ b/tests/sensors/test_timeout_sensor.py
@@ -18,7 +18,6 @@
from __future__ import annotations
import time
-import unittest
from datetime import timedelta
import pytest
@@ -63,8 +62,8 @@ class TimeoutTestSensor(BaseSensorOperator):
self.log.info("Success criteria met. Exiting.")
-class TestSensorTimeout(unittest.TestCase):
- def setUp(self):
+class TestSensorTimeout:
+ def setup_method(self):
args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
self.dag = DAG(TEST_DAG_ID, default_args=args)
diff --git a/tests/sensors/test_weekday_sensor.py
b/tests/sensors/test_weekday_sensor.py
index 63be29aab4..868f76ab43 100644
--- a/tests/sensors/test_weekday_sensor.py
+++ b/tests/sensors/test_weekday_sensor.py
@@ -17,10 +17,7 @@
# under the License.
from __future__ import annotations
-import unittest
-
import pytest
-from parameterized import parameterized
from airflow.exceptions import AirflowSensorTimeout
from airflow.models import DagBag
@@ -35,42 +32,43 @@ WEEKDAY_DATE = datetime(2018, 12, 20)
WEEKEND_DATE = datetime(2018, 12, 22)
TEST_DAG_ID = 'weekday_sensor_dag'
DEV_NULL = '/dev/null'
+TEST_CASE_WEEKDAY_SENSOR_TRUE = {
+ "with-string": "Thursday",
+ "with-enum": WeekDay.THURSDAY,
+ "with-enum-set": {WeekDay.THURSDAY},
+ "with-enum-list": [WeekDay.THURSDAY],
+ "with-enum-dict": {WeekDay.THURSDAY: "some_value"},
+ "with-enum-set-2-items": {WeekDay.THURSDAY, WeekDay.FRIDAY},
+ "with-enum-list-2-items": [WeekDay.THURSDAY, WeekDay.FRIDAY],
+ "with-enum-dict-2-items": {WeekDay.THURSDAY: "some_value", WeekDay.FRIDAY:
"some_value_2"},
+ "with-string-set": {"Thursday"},
+ "with-string-set-2-items": {"Thursday", "Friday"},
+ "with-set-mix-types": {"Thursday", WeekDay.FRIDAY},
+ "with-list-mix-types": ["Thursday", WeekDay.FRIDAY],
+ "with-dict-mix-types": {"Thursday": "some_value", WeekDay.FRIDAY:
"some_value_2"},
+}
-class TestDayOfWeekSensor(unittest.TestCase):
+class TestDayOfWeekSensor:
@staticmethod
def clean_db():
db.clear_db_runs()
db.clear_db_task_fail()
- def setUp(self):
+ def setup_method(self):
self.clean_db()
self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
dag = DAG(TEST_DAG_ID, default_args=self.args)
self.dag = dag
- def tearDown(self):
+ def teardown_method(self):
self.clean_db()
- @parameterized.expand(
- [
- ("with-string", "Thursday"),
- ("with-enum", WeekDay.THURSDAY),
- ("with-enum-set", {WeekDay.THURSDAY}),
- ("with-enum-list", [WeekDay.THURSDAY]),
- ("with-enum-dict", {WeekDay.THURSDAY: "some_value"}),
- ("with-enum-set-2-items", {WeekDay.THURSDAY, WeekDay.FRIDAY}),
- ("with-enum-list-2-items", [WeekDay.THURSDAY, WeekDay.FRIDAY]),
- ("with-enum-dict-2-items", {WeekDay.THURSDAY: "some_value",
WeekDay.FRIDAY: "some_value_2"}),
- ("with-string-set", {"Thursday"}),
- ("with-string-set-2-items", {"Thursday", "Friday"}),
- ("with-set-mix-types", {"Thursday", WeekDay.FRIDAY}),
- ("with-list-mix-types", ["Thursday", WeekDay.FRIDAY]),
- ("with-dict-mix-types", {"Thursday": "some_value", WeekDay.FRIDAY:
"some_value_2"}),
- ]
+ @pytest.mark.parametrize(
+ "week_day", TEST_CASE_WEEKDAY_SENSOR_TRUE.values(),
ids=TEST_CASE_WEEKDAY_SENSOR_TRUE.keys()
)
- def test_weekday_sensor_true(self, _, week_day):
+ def test_weekday_sensor_true(self, week_day):
op = DayOfWeekSensor(
task_id='weekday_sensor_check_true', week_day=week_day,
use_task_logical_date=True, dag=self.dag
)
diff --git a/tests/task/task_runner/test_cgroup_task_runner.py
b/tests/task/task_runner/test_cgroup_task_runner.py
index a9473ce809..79ce5dcd84 100644
--- a/tests/task/task_runner/test_cgroup_task_runner.py
+++ b/tests/task/task_runner/test_cgroup_task_runner.py
@@ -17,13 +17,12 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
from airflow.task.task_runner.cgroup_task_runner import CgroupTaskRunner
-class TestCgroupTaskRunner(unittest.TestCase):
+class TestCgroupTaskRunner:
@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__")
@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.on_finish")
def test_cgroup_task_runner_super_calls(self, mock_super_on_finish,
mock_super_init):
diff --git a/tests/task/task_runner/test_task_runner.py
b/tests/task/task_runner/test_task_runner.py
index afeafd03b2..c498171563 100644
--- a/tests/task/task_runner/test_task_runner.py
+++ b/tests/task/task_runner/test_task_runner.py
@@ -16,10 +16,9 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
-from parameterized import parameterized
+import pytest
from airflow.task.task_runner import CORE_TASK_RUNNERS, get_task_runner
from airflow.utils.module_loading import import_string
@@ -27,8 +26,8 @@ from airflow.utils.module_loading import import_string
custom_task_runner = mock.MagicMock()
-class GetTaskRunner(unittest.TestCase):
- @parameterized.expand([(import_path,) for import_path in CORE_TASK_RUNNERS.values()])
+class TestGetTaskRunner:
+ @pytest.mark.parametrize("import_path", CORE_TASK_RUNNERS.values())
def test_should_have_valid_imports(self, import_path):
assert import_string(import_path) is not None
diff --git a/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py b/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py
index 6c716d3a5d..61fc43c966 100644
--- a/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py
+++ b/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py
@@ -17,14 +17,13 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest.mock import Mock
from airflow.models import TaskInstance
from airflow.ti_deps.deps.dag_ti_slots_available_dep import DagTISlotsAvailableDep
-class TestDagTISlotsAvailableDep(unittest.TestCase):
+class TestDagTISlotsAvailableDep:
def test_concurrency_reached(self):
"""
Test max_active_tasks reached should fail dep
diff --git a/tests/ti_deps/deps/test_dag_unpaused_dep.py b/tests/ti_deps/deps/test_dag_unpaused_dep.py
index 2aeaed40c5..514c070ac4 100644
--- a/tests/ti_deps/deps/test_dag_unpaused_dep.py
+++ b/tests/ti_deps/deps/test_dag_unpaused_dep.py
@@ -17,14 +17,13 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest.mock import Mock
from airflow.models import TaskInstance
from airflow.ti_deps.deps.dag_unpaused_dep import DagUnpausedDep
-class TestDagUnpausedDep(unittest.TestCase):
+class TestDagUnpausedDep:
def test_concurrency_reached(self):
"""
Test paused DAG should fail dependency
diff --git a/tests/ti_deps/deps/test_dagrun_exists_dep.py b/tests/ti_deps/deps/test_dagrun_exists_dep.py
index 54d98b587c..56347ad187 100644
--- a/tests/ti_deps/deps/test_dagrun_exists_dep.py
+++ b/tests/ti_deps/deps/test_dagrun_exists_dep.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest.mock import Mock, patch
from airflow.models import DAG, DagRun
@@ -25,7 +24,7 @@ from airflow.ti_deps.deps.dagrun_exists_dep import DagrunRunningDep
from airflow.utils.state import State
-class TestDagrunRunningDep(unittest.TestCase):
+class TestDagrunRunningDep:
@patch('airflow.models.DagRun.find', return_value=())
def test_dagrun_doesnt_exist(self, mock_dagrun_find):
"""
diff --git a/tests/ti_deps/deps/test_dagrun_id_dep.py b/tests/ti_deps/deps/test_dagrun_id_dep.py
index 09b8614dd7..36a3049e07 100644
--- a/tests/ti_deps/deps/test_dagrun_id_dep.py
+++ b/tests/ti_deps/deps/test_dagrun_id_dep.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest.mock import Mock
from airflow.models.dagrun import DagRun
@@ -25,7 +24,7 @@ from airflow.ti_deps.deps.dagrun_backfill_dep import DagRunNotBackfillDep
from airflow.utils.types import DagRunType
-class TestDagrunRunningDep(unittest.TestCase):
+class TestDagrunRunningDep:
def test_run_id_is_backfill(self):
"""
Task instances whose run_id is a backfill dagrun run_id should fail this dep.
diff --git a/tests/ti_deps/deps/test_not_in_retry_period_dep.py b/tests/ti_deps/deps/test_not_in_retry_period_dep.py
index 07715a3f97..eb4fc90768 100644
--- a/tests/ti_deps/deps/test_not_in_retry_period_dep.py
+++ b/tests/ti_deps/deps/test_not_in_retry_period_dep.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from datetime import timedelta
from unittest.mock import Mock
@@ -29,7 +28,7 @@ from airflow.utils.state import State
from airflow.utils.timezone import datetime
-class TestNotInRetryPeriodDep(unittest.TestCase):
+class TestNotInRetryPeriodDep:
def _get_task_instance(self, state, end_date=None, retry_delay=timedelta(minutes=15)):
task = Mock(retry_delay=retry_delay, retry_exponential_backoff=False)
ti = TaskInstance(task=task, state=state, execution_date=None)
diff --git a/tests/ti_deps/deps/test_pool_slots_available_dep.py b/tests/ti_deps/deps/test_pool_slots_available_dep.py
index 2cbb16ef1d..0aec7d8c8c 100644
--- a/tests/ti_deps/deps/test_pool_slots_available_dep.py
+++ b/tests/ti_deps/deps/test_pool_slots_available_dep.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest.mock import Mock, patch
from airflow.models import Pool
@@ -27,15 +26,15 @@ from airflow.utils.session import create_session
from tests.test_utils import db
-class TestPoolSlotsAvailableDep(unittest.TestCase):
- def setUp(self):
+class TestPoolSlotsAvailableDep:
+ def setup_method(self):
db.clear_db_pools()
with create_session() as session:
test_pool = Pool(pool='test_pool')
session.add(test_pool)
session.commit()
- def tearDown(self):
+ def teardown_method(self):
db.clear_db_pools()
@patch('airflow.models.Pool.open_slots', return_value=0)
diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
index 91dbc0deab..722847a391 100644
--- a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
+++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from datetime import timedelta
from unittest.mock import Mock, patch
@@ -28,7 +27,7 @@ from airflow.utils.state import State
from airflow.utils.timezone import utcnow
-class TestNotInReschedulePeriodDep(unittest.TestCase):
+class TestNotInReschedulePeriodDep:
def _get_task_instance(self, state):
dag = DAG('test_dag')
task = Mock(dag=dag, reschedule=True, is_mapped=False)
diff --git a/tests/ti_deps/deps/test_task_concurrency.py b/tests/ti_deps/deps/test_task_concurrency.py
index 55f8bd8858..f694beb430 100644
--- a/tests/ti_deps/deps/test_task_concurrency.py
+++ b/tests/ti_deps/deps/test_task_concurrency.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from datetime import datetime
from unittest.mock import Mock
@@ -27,7 +26,7 @@ from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep
-class TestTaskConcurrencyDep(unittest.TestCase):
+class TestTaskConcurrencyDep:
def _get_task(self, **kwargs):
return BaseOperator(task_id='test_task', dag=DAG('test_dag'), **kwargs)
diff --git a/tests/ti_deps/deps/test_task_not_running_dep.py b/tests/ti_deps/deps/test_task_not_running_dep.py
index 9a401d3e69..62a1a1f59d 100644
--- a/tests/ti_deps/deps/test_task_not_running_dep.py
+++ b/tests/ti_deps/deps/test_task_not_running_dep.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from datetime import datetime
from unittest.mock import Mock
@@ -25,7 +24,7 @@ from airflow.ti_deps.deps.task_not_running_dep import TaskNotRunningDep
from airflow.utils.state import State
-class TestTaskNotRunningDep(unittest.TestCase):
+class TestTaskNotRunningDep:
def test_not_running_state(self):
ti = Mock(state=State.QUEUED, end_date=datetime(2016, 1, 1))
assert TaskNotRunningDep().is_met(ti=ti)
diff --git a/tests/ti_deps/deps/test_valid_state_dep.py b/tests/ti_deps/deps/test_valid_state_dep.py
index f4528212b4..9fb8623417 100644
--- a/tests/ti_deps/deps/test_valid_state_dep.py
+++ b/tests/ti_deps/deps/test_valid_state_dep.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from datetime import datetime
from unittest.mock import Mock
@@ -28,7 +27,7 @@ from airflow.ti_deps.deps.valid_state_dep import ValidStateDep
from airflow.utils.state import State
-class TestValidStateDep(unittest.TestCase):
+class TestValidStateDep:
def test_valid_state(self):
"""
Valid state should pass this dep
diff --git a/tests/utils/log/test_file_processor_handler.py b/tests/utils/log/test_file_processor_handler.py
index ed563e5c83..1d479c71da 100644
--- a/tests/utils/log/test_file_processor_handler.py
+++ b/tests/utils/log/test_file_processor_handler.py
@@ -19,7 +19,6 @@ from __future__ import annotations
import os
import shutil
-import unittest
from datetime import timedelta
from freezegun import freeze_time
@@ -28,9 +27,8 @@ from airflow.utils import timezone
from airflow.utils.log.file_processor_handler import FileProcessorHandler
-class TestFileProcessorHandler(unittest.TestCase):
- def setUp(self):
- super().setUp()
+class TestFileProcessorHandler:
+ def setup_method(self):
self.base_log_folder = "/tmp/log_test"
self.filename = "{filename}"
self.filename_template = "{{ filename }}.log"
@@ -109,5 +107,5 @@ class TestFileProcessorHandler(unittest.TestCase):
with freeze_time(date1):
handler.set_context(filename=os.path.join(self.dag_dir, "log1"))
- def tearDown(self):
+ def teardown_method(self):
shutil.rmtree(self.base_log_folder, ignore_errors=True)
diff --git a/tests/utils/log/test_json_formatter.py b/tests/utils/log/test_json_formatter.py
index 98b409db6d..627da568b0 100644
--- a/tests/utils/log/test_json_formatter.py
+++ b/tests/utils/log/test_json_formatter.py
@@ -22,13 +22,12 @@ from __future__ import annotations
import json
import sys
-import unittest
from logging import makeLogRecord
from airflow.utils.log.json_formatter import JSONFormatter
-class TestJSONFormatter(unittest.TestCase):
+class TestJSONFormatter:
"""
TestJSONFormatter class combine all tests for JSONFormatter
"""
diff --git a/tests/utils/test_dag_cycle.py b/tests/utils/test_dag_cycle.py
index f6012bc2ea..731ea707f7 100644
--- a/tests/utils/test_dag_cycle.py
+++ b/tests/utils/test_dag_cycle.py
@@ -16,8 +16,6 @@
# under the License.
from __future__ import annotations
-import unittest
-
import pytest
from airflow import DAG
@@ -29,7 +27,7 @@ from airflow.utils.task_group import TaskGroup
from tests.models import DEFAULT_DATE
-class TestCycleTester(unittest.TestCase):
+class TestCycleTester:
def test_cycle_empty(self):
# test empty
dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
diff --git a/tests/utils/test_dates.py b/tests/utils/test_dates.py
index ae016cda63..029bfb5582 100644
--- a/tests/utils/test_dates.py
+++ b/tests/utils/test_dates.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from datetime import datetime, timedelta
import pendulum
@@ -28,7 +27,7 @@ from pytest import approx
from airflow.utils import dates, timezone
-class TestDates(unittest.TestCase):
+class TestDates:
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_days_ago(self):
today = pendulum.today()
@@ -106,7 +105,7 @@ class TestDates(unittest.TestCase):
assert arr4 == approx([2.3147, 1.1574], rel=1e-3)
-class TestUtilsDatesDateRange(unittest.TestCase):
+class TestUtilsDatesDateRange:
def test_no_delta(self):
assert dates.date_range(datetime(2016, 1, 1), datetime(2016, 1, 3)) == []
diff --git a/tests/utils/test_docs.py b/tests/utils/test_docs.py
index 852891aca1..3e3fe6df33 100644
--- a/tests/utils/test_docs.py
+++ b/tests/utils/test_docs.py
@@ -16,16 +16,16 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
-from parameterized import parameterized
+import pytest
from airflow.utils.docs import get_docs_url
-class TestGetDocsUrl(unittest.TestCase):
- @parameterized.expand(
+class TestGetDocsUrl:
+ @pytest.mark.parametrize(
+ "version, page, expected_url",
[
(
'2.0.0.dev0',
@@ -45,7 +45,7 @@ class TestGetDocsUrl(unittest.TestCase):
'project.html',
'https://airflow.apache.org/docs/apache-airflow/1.10.10/project.html',
),
- ]
+ ],
)
def test_should_return_link(self, version, page, expected_url):
with mock.patch('airflow.version.version', version):
diff --git a/tests/utils/test_email.py b/tests/utils/test_email.py
index 45d70d1161..1a29b4f8d4 100644
--- a/tests/utils/test_email.py
+++ b/tests/utils/test_email.py
@@ -19,7 +19,6 @@ from __future__ import annotations
import os
import tempfile
-import unittest
from email.mime.application import MIMEApplication
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
@@ -37,7 +36,7 @@ EMAILS = ['test1@example.com', 'test2@example.com']
send_email_test = mock.MagicMock()
-class TestEmail(unittest.TestCase):
+class TestEmail:
def test_get_email_address_single_email(self):
emails_string = 'test1@example.com'
@@ -147,7 +146,7 @@ class TestEmail(unittest.TestCase):
assert msg['To'] == ','.join(recipients)
-class TestEmailSmtp(unittest.TestCase):
+class TestEmailSmtp:
@mock.patch('airflow.utils.email.send_mime_email')
def test_send_smtp(self, mock_send_mime):
with tempfile.NamedTemporaryFile() as attachment:
diff --git a/tests/utils/test_event_scheduler.py b/tests/utils/test_event_scheduler.py
index 7e126bad26..641d8dd0f9 100644
--- a/tests/utils/test_event_scheduler.py
+++ b/tests/utils/test_event_scheduler.py
@@ -17,13 +17,12 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
from airflow.utils.event_scheduler import EventScheduler
-class TestEventScheduler(unittest.TestCase):
+class TestEventScheduler:
def test_call_regular_interval(self):
somefunction = mock.MagicMock()
diff --git a/tests/utils/test_file.py b/tests/utils/test_file.py
index f403408263..2036fbce7a 100644
--- a/tests/utils/test_file.py
+++ b/tests/utils/test_file.py
@@ -19,7 +19,6 @@ from __future__ import annotations
import os
import os.path
-import unittest
from pathlib import Path
from unittest import mock
@@ -29,7 +28,7 @@ from airflow.utils.file import correct_maybe_zipped, find_path_from_directory, o
from tests.models import TEST_DAGS_FOLDER
-class TestCorrectMaybeZipped(unittest.TestCase):
+class TestCorrectMaybeZipped:
@mock.patch("zipfile.is_zipfile")
def test_correct_maybe_zipped_normal_file(self, mocked_is_zipfile):
path = '/path/to/some/file.txt'
@@ -62,7 +61,7 @@ class TestCorrectMaybeZipped(unittest.TestCase):
assert dag_folder == '/path/to/archive.zip'
-class TestOpenMaybeZipped(unittest.TestCase):
+class TestOpenMaybeZipped:
def test_open_maybe_zipped_normal_file(self):
test_file_path = os.path.join(TEST_DAGS_FOLDER, "no_dags.py")
with open_maybe_zipped(test_file_path, 'r') as test_file:
diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py
index f0e65067f9..adbf6d1a3a 100644
--- a/tests/utils/test_json.py
+++ b/tests/utils/test_json.py
@@ -19,18 +19,16 @@ from __future__ import annotations
import decimal
import json
-import unittest
from datetime import date, datetime
import numpy as np
-import parameterized
import pendulum
import pytest
from airflow.utils import json as utils_json
-class TestAirflowJsonEncoder(unittest.TestCase):
+class TestAirflowJsonEncoder:
def test_encode_datetime(self):
obj = datetime.strptime('2017-05-21 00:00:00', '%Y-%m-%d %H:%M:%S')
assert json.dumps(obj, cls=utils_json.AirflowJsonEncoder) == '"2017-05-21T00:00:00+00:00"'
@@ -42,7 +40,8 @@ class TestAirflowJsonEncoder(unittest.TestCase):
def test_encode_date(self):
assert json.dumps(date(2017, 5, 21), cls=utils_json.AirflowJsonEncoder) == '"2017-05-21"'
- @parameterized.parameterized.expand(
+ @pytest.mark.parametrize(
+ "expr, expected",
[("1", "1"), ("52e4", "520000"), ("2e0", "2"), ("12e-2", "0.12"), ("12.34", "12.34")],
)
def test_encode_decimal(self, expr, expected):
diff --git a/tests/utils/test_logging_mixin.py b/tests/utils/test_logging_mixin.py
index ca736e9cfc..567729bdd8 100644
--- a/tests/utils/test_logging_mixin.py
+++ b/tests/utils/test_logging_mixin.py
@@ -17,15 +17,14 @@
# under the License.
from __future__ import annotations
-import unittest
import warnings
from unittest import mock
from airflow.utils.log.logging_mixin import StreamLogWriter, set_context
-class TestLoggingMixin(unittest.TestCase):
- def setUp(self):
+class TestLoggingMixin:
+ def setup_method(self):
warnings.filterwarnings(action='always')
def test_set_context(self):
@@ -53,7 +52,7 @@ class TestLoggingMixin(unittest.TestCase):
warnings.resetwarnings()
-class TestStreamLogWriter(unittest.TestCase):
+class TestStreamLogWriter:
def test_write(self):
logger = mock.MagicMock()
logger.log = mock.MagicMock()
diff --git a/tests/utils/test_module_loading.py b/tests/utils/test_module_loading.py
index bdb2c3af4a..2c52b66236 100644
--- a/tests/utils/test_module_loading.py
+++ b/tests/utils/test_module_loading.py
@@ -17,14 +17,12 @@
# under the License.
from __future__ import annotations
-import unittest
-
import pytest
from airflow.utils.module_loading import import_string
-class TestModuleImport(unittest.TestCase):
+class TestModuleImport:
def test_import_string(self):
cls = import_string('airflow.utils.module_loading.import_string')
assert cls == import_string
diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py
index bdd5a0ac9e..c633f511cb 100644
--- a/tests/utils/test_net.py
+++ b/tests/utils/test_net.py
@@ -18,7 +18,6 @@
from __future__ import annotations
import re
-import unittest
from unittest import mock
import pytest
@@ -32,7 +31,7 @@ def get_hostname():
return 'awesomehostname'
-class TestGetHostname(unittest.TestCase):
+class TestGetHostname:
@mock.patch('airflow.utils.net.getfqdn', return_value='first')
@conf_vars({('core', 'hostname_callable'): None})
def test_get_hostname_unset(self, mock_getfqdn):
diff --git a/tests/utils/test_operator_helpers.py b/tests/utils/test_operator_helpers.py
index c590b51f68..96368eaf7f 100644
--- a/tests/utils/test_operator_helpers.py
+++ b/tests/utils/test_operator_helpers.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from datetime import datetime
from unittest import mock
@@ -26,9 +25,8 @@ import pytest
from airflow.utils import operator_helpers
-class TestOperatorHelpers(unittest.TestCase):
- def setUp(self):
- super().setUp()
+class TestOperatorHelpers:
+ def setup_method(self):
self.dag_id = 'dag_id'
self.task_id = 'task_id'
self.try_number = 1
diff --git a/tests/utils/test_operator_resources.py b/tests/utils/test_operator_resources.py
index 64b5af3295..dd81f4a281 100644
--- a/tests/utils/test_operator_resources.py
+++ b/tests/utils/test_operator_resources.py
@@ -17,12 +17,10 @@
# under the License.
from __future__ import annotations
-import unittest
-
from airflow.utils.operator_resources import Resources
-class TestResources(unittest.TestCase):
+class TestResources:
def test_resource_eq(self):
r = Resources(cpus=0.1, ram=2048)
assert r not in [{}, [], None]
diff --git a/tests/utils/test_preexisting_python_virtualenv_decorator.py b/tests/utils/test_preexisting_python_virtualenv_decorator.py
index 1b54fa45c0..3934342062 100644
--- a/tests/utils/test_preexisting_python_virtualenv_decorator.py
+++ b/tests/utils/test_preexisting_python_virtualenv_decorator.py
@@ -17,12 +17,10 @@
# under the License.
from __future__ import annotations
-import unittest
-
from airflow.utils.decorators import remove_task_decorator
-class TestExternalPythonDecorator(unittest.TestCase):
+class TestExternalPythonDecorator:
def test_remove_task_decorator(self):
py_source = "@task.external_python(use_dill=True)\ndef f():\nimport funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
diff --git a/tests/utils/test_python_virtualenv.py b/tests/utils/test_python_virtualenv.py
index 89877de793..d7d7970bf1 100644
--- a/tests/utils/test_python_virtualenv.py
+++ b/tests/utils/test_python_virtualenv.py
@@ -18,14 +18,13 @@
from __future__ import annotations
import sys
-import unittest
from unittest import mock
from airflow.utils.decorators import remove_task_decorator
from airflow.utils.python_virtualenv import prepare_virtualenv
-class TestPrepareVirtualenv(unittest.TestCase):
+class TestPrepareVirtualenv:
@mock.patch('airflow.utils.python_virtualenv.execute_in_subprocess')
def test_should_create_virtualenv(self, mock_execute_in_subprocess):
python_bin = prepare_virtualenv(
diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py
index f8d9f21ab7..86c320f666 100644
--- a/tests/utils/test_sqlalchemy.py
+++ b/tests/utils/test_sqlalchemy.py
@@ -19,13 +19,11 @@ from __future__ import annotations
import datetime
import pickle
-import unittest
from unittest import mock
from unittest.mock import MagicMock
import pytest
from kubernetes.client import models as k8s
-from parameterized import parameterized
from pytest import param
from sqlalchemy.exc import StatementError
@@ -41,8 +39,8 @@ from airflow.utils.timezone import utcnow
TEST_POD = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")]))
-class TestSqlAlchemyUtils(unittest.TestCase):
- def setUp(self):
+class TestSqlAlchemyUtils:
+ def setup_method(self):
session = Session()
# make sure NOT to run in UTC. Only postgres supports storing
@@ -108,7 +106,8 @@ class TestSqlAlchemyUtils(unittest.TestCase):
)
dag.clear()
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "dialect, supports_for_update_of, expected_return_value",
[
(
"postgresql",
@@ -130,7 +129,7 @@ class TestSqlAlchemyUtils(unittest.TestCase):
False,
{'skip_locked': True},
),
- ]
+ ],
)
def test_skip_locked(self, dialect, supports_for_update_of, expected_return_value):
session = mock.Mock()
@@ -138,7 +137,8 @@ class TestSqlAlchemyUtils(unittest.TestCase):
session.bind.dialect.supports_for_update_of = supports_for_update_of
assert skip_locked(session=session) == expected_return_value
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "dialect, supports_for_update_of, expected_return_value",
[
(
"postgresql",
@@ -162,7 +162,7 @@ class TestSqlAlchemyUtils(unittest.TestCase):
'nowait': True,
},
),
- ]
+ ],
)
def test_nowait(self, dialect, supports_for_update_of, expected_return_value):
session = mock.Mock()
@@ -170,7 +170,8 @@ class TestSqlAlchemyUtils(unittest.TestCase):
session.bind.dialect.supports_for_update_of = supports_for_update_of
assert nowait(session=session) == expected_return_value
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "dialect, supports_for_update_of, use_row_level_lock_conf, expected_use_row_level_lock",
[
("postgresql", True, True, True),
("postgresql", True, False, False),
@@ -179,7 +180,7 @@ class TestSqlAlchemyUtils(unittest.TestCase):
("mysql", True, True, True),
("mysql", True, False, False),
("sqlite", False, True, True),
- ]
+ ],
)
def test_with_row_locks(
self, dialect, supports_for_update_of, use_row_level_lock_conf, expected_use_row_level_lock
@@ -232,7 +233,7 @@ class TestSqlAlchemyUtils(unittest.TestCase):
other_session.execute('SELECT 1')
other_session.commit()
- def tearDown(self):
+ def teardown_method(self):
self.session.close()
settings.engine.dispose()
diff --git a/tests/utils/test_timezone.py b/tests/utils/test_timezone.py
index 729e514fa1..e006d990df 100644
--- a/tests/utils/test_timezone.py
+++ b/tests/utils/test_timezone.py
@@ -18,7 +18,6 @@
from __future__ import annotations
import datetime
-import unittest
import pendulum
import pytest
@@ -32,7 +31,7 @@ ICT = pendulum.tz.timezone('Asia/Bangkok') # Asia/Bangkok
UTC = timezone.utc
-class TestTimezone(unittest.TestCase):
+class TestTimezone:
def test_is_aware(self):
assert timezone.is_localized(datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT))
assert not timezone.is_localized(datetime.datetime(2011, 9, 1, 13, 20, 30))
diff --git a/tests/utils/test_trigger_rule.py b/tests/utils/test_trigger_rule.py
index f8f598c398..44866f8a8b 100644
--- a/tests/utils/test_trigger_rule.py
+++ b/tests/utils/test_trigger_rule.py
@@ -17,14 +17,12 @@
# under the License.
from __future__ import annotations
-import unittest
-
import pytest
from airflow.utils.trigger_rule import TriggerRule
-class TestTriggerRule(unittest.TestCase):
+class TestTriggerRule:
def test_valid_trigger_rules(self):
assert TriggerRule.is_valid(TriggerRule.ALL_SUCCESS)
assert TriggerRule.is_valid(TriggerRule.ALL_FAILED)
diff --git a/tests/utils/test_weekday.py b/tests/utils/test_weekday.py
index c478e834de..0dcadfd1ac 100644
--- a/tests/utils/test_weekday.py
+++ b/tests/utils/test_weekday.py
@@ -17,16 +17,14 @@
# under the License.
from __future__ import annotations
-import unittest
from enum import Enum
import pytest
-from parameterized import parameterized
from airflow.utils.weekday import WeekDay
-class TestWeekDay(unittest.TestCase):
+class TestWeekDay:
def test_weekday_enum_length(self):
assert len(WeekDay) == 7
@@ -44,36 +42,44 @@ class TestWeekDay(unittest.TestCase):
assert isinstance(weekday_enum, int)
assert isinstance(weekday_enum, Enum)
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "weekday, expected",
[
- ("with-string", "Monday", 1),
- ("with-enum", WeekDay.MONDAY, 1),
- ]
+ ("Monday", 1),
+ (WeekDay.MONDAY, 1),
+ ],
+ ids=["with-string", "with-enum"],
)
- def test_convert(self, _, weekday, expected):
+ def test_convert(self, weekday, expected):
result = WeekDay.convert(weekday)
- self.assertEqual(result, expected)
+ assert result == expected
def test_convert_with_incorrect_input(self):
invalid = "Sun"
- with self.assertRaisesRegex(
- AttributeError,
- f'Invalid Week Day passed: "{invalid}"',
- ):
+ error_message = fr'Invalid Week Day passed: "{invalid}"'
+ with pytest.raises(AttributeError, match=error_message):
WeekDay.convert(invalid)
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "weekday, expected",
[
- ("with-string", "Monday", {WeekDay.MONDAY}),
- ("with-enum", WeekDay.MONDAY, {WeekDay.MONDAY}),
- ("with-dict", {"Thursday": "1"}, {WeekDay.THURSDAY}),
- ("with-list", ["Thursday"], {WeekDay.THURSDAY}),
- ("with-mix", ["Thursday", WeekDay.MONDAY], {WeekDay.MONDAY, WeekDay.THURSDAY}),
- ]
+ ("Monday", {WeekDay.MONDAY}),
+ (WeekDay.MONDAY, {WeekDay.MONDAY}),
+ ({"Thursday": "1"}, {WeekDay.THURSDAY}),
+ (["Thursday"], {WeekDay.THURSDAY}),
+ (["Thursday", WeekDay.MONDAY], {WeekDay.MONDAY, WeekDay.THURSDAY}),
+ ],
+ ids=[
+ "with-string",
+ "with-enum",
+ "with-dict",
+ "with-list",
+ "with-mix",
+ ],
)
- def test_validate_week_day(self, _, weekday, expected):
+ def test_validate_week_day(self, weekday, expected):
result = WeekDay.validate_week_day(weekday)
- self.assertEqual(expected, result)
+ assert expected == result
def test_validate_week_day_with_invalid_type(self):
invalid_week_day = 5
diff --git a/tests/utils/test_weight_rule.py b/tests/utils/test_weight_rule.py
index 7be17e7604..73abafe782 100644
--- a/tests/utils/test_weight_rule.py
+++ b/tests/utils/test_weight_rule.py
@@ -17,14 +17,12 @@
# under the License.
from __future__ import annotations
-import unittest
-
import pytest
from airflow.utils.weight_rule import WeightRule
-class TestWeightRule(unittest.TestCase):
+class TestWeightRule:
def test_valid_weight_rules(self):
assert WeightRule.is_valid(WeightRule.DOWNSTREAM)
assert WeightRule.is_valid(WeightRule.UPSTREAM)
diff --git a/tests/www/test_app.py b/tests/www/test_app.py
index d82dda1d7a..106001cc84 100644
--- a/tests/www/test_app.py
+++ b/tests/www/test_app.py
@@ -19,7 +19,6 @@ from __future__ import annotations
import runpy
import sys
-import unittest
from datetime import timedelta
from unittest import mock
@@ -33,9 +32,9 @@ from tests.test_utils.config import conf_vars
from tests.test_utils.decorators import dont_initialize_flask_app_submodules
-class TestApp(unittest.TestCase):
+class TestApp:
@classmethod
- def setUpClass(cls) -> None:
+ def setup_class(cls) -> None:
from airflow import settings
settings.configure_orm()
diff --git a/tests/www/test_init_views.py b/tests/www/test_init_views.py
index 7f23e439e1..49d06ee04e 100644
--- a/tests/www/test_init_views.py
+++ b/tests/www/test_init_views.py
@@ -17,7 +17,6 @@
from __future__ import annotations
import re
-import unittest
from unittest import mock
import pytest
@@ -26,7 +25,7 @@ from airflow.www.extensions import init_views
from tests.test_utils.config import conf_vars
-class TestInitApiExperimental(unittest.TestCase):
+class TestInitApiExperimental:
@conf_vars({('api', 'enable_experimental_api'): 'true'})
def test_should_raise_deprecation_warning_when_enabled(self):
app = mock.MagicMock()
diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py
index 1fc59567f7..f1a2012a94 100644
--- a/tests/www/test_utils.py
+++ b/tests/www/test_utils.py
@@ -18,7 +18,6 @@
from __future__ import annotations
import re
-import unittest
from datetime import datetime
from urllib.parse import parse_qs
@@ -28,7 +27,7 @@ from airflow.www import utils
from airflow.www.utils import wrapped_markdown
-class TestUtils(unittest.TestCase):
+class TestUtils:
def check_generate_pages_html(self, current_page, total_pages, window=7, check_middle=False):
extra_links = 4 # first, prev, next, last
search = "'>\"/><img src=x onerror=alert(1)>"
@@ -156,8 +155,8 @@ class TestUtils(unittest.TestCase):
assert '<b2>' not in html
-class TestAttrRenderer(unittest.TestCase):
- def setUp(self):
+class TestAttrRenderer:
+ def setup_method(self):
self.attr_renderer = utils.get_attr_renderer()
def test_python_callable(self):
@@ -178,11 +177,11 @@ class TestAttrRenderer(unittest.TestCase):
assert "<li>bar</li>" in rendered
def test_markdown_none(self):
- rendered = self.attr_renderer["python_callable"](None)
- assert "" == rendered
+ rendered = self.attr_renderer["doc_md"](None)
+ assert rendered is None
-class TestWrappedMarkdown(unittest.TestCase):
+class TestWrappedMarkdown:
def test_wrapped_markdown_with_docstring_curly_braces(self):
rendered = wrapped_markdown("{braces}", css_class="a_class")
assert (
diff --git a/tests/www/test_validators.py b/tests/www/test_validators.py
index eaac039436..499c9a8608 100644
--- a/tests/www/test_validators.py
+++ b/tests/www/test_validators.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
import pytest
@@ -25,9 +24,8 @@ import pytest
from airflow.www import validators
-class TestGreaterEqualThan(unittest.TestCase):
- def setUp(self):
- super().setUp()
+class TestGreaterEqualThan:
+ def setup_method(self):
self.form_field_mock = mock.MagicMock(data='2017-05-06')
self.form_field_mock.gettext.side_effect = lambda msg: msg
self.other_field_mock = mock.MagicMock(data='2017-05-05')
@@ -89,9 +87,8 @@ class TestGreaterEqualThan(unittest.TestCase):
)
-class TestValidJson(unittest.TestCase):
- def setUp(self):
- super().setUp()
+class TestValidJson:
+ def setup_method(self):
self.form_field_mock = mock.MagicMock(data='{"valid":"True"}')
self.form_field_mock.gettext.side_effect = lambda msg: msg
self.form_mock = mock.MagicMock(spec_set=dict)