This is an automated email from the ASF dual-hosted git repository. johnbodley pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push: new bc65c245fe chore(tests): Remove unnecessary/problematic app contexts (#28159) bc65c245fe is described below commit bc65c245fef814491a99967296e45956c68d590b Author: John Bodley <4567245+john-bod...@users.noreply.github.com> AuthorDate: Wed Apr 24 13:46:35 2024 -0700 chore(tests): Remove unnecessary/problematic app contexts (#28159) --- tests/integration_tests/access_tests.py | 30 -- .../annotation_layers/fixtures.py | 59 ++- tests/integration_tests/charts/api_tests.py | 242 ++++++----- tests/integration_tests/charts/data/api_tests.py | 8 +- tests/integration_tests/conftest.py | 11 +- tests/integration_tests/dashboard_tests.py | 63 ++- tests/integration_tests/dashboards/api_tests.py | 6 - .../dashboards/security/security_dataset_tests.py | 63 ++- .../databases/commands/upload_test.py | 15 +- tests/integration_tests/embedded/dao_tests.py | 1 - tests/integration_tests/fixtures/datasource.py | 59 ++- tests/integration_tests/fixtures/public_role.py | 27 +- tests/integration_tests/fixtures/tags.py | 16 +- tests/integration_tests/fixtures/users.py | 42 +- .../fixtures/world_bank_dashboard.py | 13 +- ...migrate_can_view_and_drill_permission__tests.py | 36 +- tests/integration_tests/reports/alert_tests.py | 62 ++- tests/integration_tests/reports/commands_tests.py | 463 ++++++++++----------- tests/integration_tests/reports/scheduler_tests.py | 249 ++++++----- .../security/analytics_db_safety_tests.py | 11 +- tests/integration_tests/sqla_models_tests.py | 46 +- tests/integration_tests/utils/core_tests.py | 3 +- tests/integration_tests/utils_tests.py | 32 +- tests/unit_tests/databases/api_test.py | 367 ++++++++-------- 24 files changed, 912 insertions(+), 1012 deletions(-) diff --git a/tests/integration_tests/access_tests.py b/tests/integration_tests/access_tests.py index 37461c8ca6..6ece4a081b 100644 --- a/tests/integration_tests/access_tests.py +++ b/tests/integration_tests/access_tests.py @@ -81,36 +81,6 @@ DB_ACCESS_ROLE = "db_access_role" SCHEMA_ACCESS_ROLE = "schema_access_role" -class TestRequestAccess(SupersetTestCase): - @classmethod - def setUpClass(cls): - with app.app_context(): - security_manager.add_role("override_me") - security_manager.add_role(TEST_ROLE_1) - security_manager.add_role(TEST_ROLE_2) - security_manager.add_role(DB_ACCESS_ROLE) - security_manager.add_role(SCHEMA_ACCESS_ROLE) - db.session.commit() - - @classmethod - def tearDownClass(cls): - with app.app_context(): - override_me = security_manager.find_role("override_me") - db.session.delete(override_me) - db.session.delete(security_manager.find_role(TEST_ROLE_1)) - db.session.delete(security_manager.find_role(TEST_ROLE_2)) - db.session.delete(security_manager.find_role(DB_ACCESS_ROLE)) - db.session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE)) - db.session.commit() - - def tearDown(self): - override_me = security_manager.find_role("override_me") - override_me.permissions = [] - db.session.commit() - db.session.close() - super().tearDown() - - @pytest.mark.parametrize( "username,user_id", [ diff --git a/tests/integration_tests/annotation_layers/fixtures.py b/tests/integration_tests/annotation_layers/fixtures.py index 8243d7e474..ac25d28d42 100644 --- a/tests/integration_tests/annotation_layers/fixtures.py +++ b/tests/integration_tests/annotation_layers/fixtures.py @@ -14,18 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# isort:skip_file -import pytest from datetime import datetime from typing import Optional +import pytest +from flask.ctx import AppContext from superset import db from superset.models.annotations import Annotation, AnnotationLayer - from tests.integration_tests.test_app import app - ANNOTATION_LAYERS_COUNT = 10 ANNOTATIONS_COUNT = 5 @@ -70,36 +68,35 @@ def _insert_annotation( @pytest.fixture() -def create_annotation_layers(): +def create_annotation_layers(app_context: AppContext): """ Creates ANNOTATION_LAYERS_COUNT-1 layers with no annotations and a final one with ANNOTATION_COUNT children :return: """ - with app.app_context(): - annotation_layers = [] - annotations = [] - for cx in range(ANNOTATION_LAYERS_COUNT - 1): - annotation_layers.append( - _insert_annotation_layer(name=f"name{cx}", descr=f"descr{cx}") + annotation_layers = [] + annotations = [] + for cx in range(ANNOTATION_LAYERS_COUNT - 1): + annotation_layers.append( + _insert_annotation_layer(name=f"name{cx}", descr=f"descr{cx}") + ) + layer_with_annotations = _insert_annotation_layer("layer_with_annotations") + annotation_layers.append(layer_with_annotations) + for cx in range(ANNOTATIONS_COUNT): + annotations.append( + _insert_annotation( + layer_with_annotations, + short_descr=f"short_descr{cx}", + long_descr=f"long_descr{cx}", + start_dttm=get_start_dttm(cx), + end_dttm=get_end_dttm(cx), ) - layer_with_annotations = _insert_annotation_layer("layer_with_annotations") - annotation_layers.append(layer_with_annotations) - for cx in range(ANNOTATIONS_COUNT): - annotations.append( - _insert_annotation( - layer_with_annotations, - short_descr=f"short_descr{cx}", - long_descr=f"long_descr{cx}", - start_dttm=get_start_dttm(cx), - end_dttm=get_end_dttm(cx), - ) - ) - yield annotation_layers - - # rollback changes - for annotation_layer in annotation_layers: - db.session.delete(annotation_layer) - for annotation in annotations: - db.session.delete(annotation) - db.session.commit() + ) + yield annotation_layers + + # rollback changes + for annotation_layer in annotation_layers: + db.session.delete(annotation_layer) + for annotation in annotations: + db.session.delete(annotation) + db.session.commit() diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index f607e014ec..16d44fe5cd 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -23,6 +23,7 @@ from zipfile import is_zipfile, ZipFile import prison import pytest import yaml +from flask.ctx import AppContext from flask_babel import lazy_gettext as _ from parameterized import parameterized from sqlalchemy import and_ @@ -82,121 +83,115 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase): resource_name = "chart" @pytest.fixture(autouse=True) - def clear_data_cache(self): - with app.app_context(): - cache_manager.data_cache.clear() - yield + def clear_data_cache(self, app_context: AppContext): + cache_manager.data_cache.clear() + yield @pytest.fixture() def create_charts(self): - with self.create_app().app_context(): - charts = [] - admin = self.get_user("admin") - for cx in range(CHARTS_FIXTURE_COUNT - 1): - charts.append(self.insert_chart(f"name{cx}", [admin.id], 1)) - fav_charts = [] - for cx in range(round(CHARTS_FIXTURE_COUNT / 2)): - fav_star = FavStar( - user_id=admin.id, class_name="slice", obj_id=charts[cx].id - ) - db.session.add(fav_star) - db.session.commit() - fav_charts.append(fav_star) - yield charts - - # rollback changes - for chart in charts: - db.session.delete(chart) - for fav_chart in fav_charts: - db.session.delete(fav_chart) + charts = [] + admin = self.get_user("admin") + for cx in range(CHARTS_FIXTURE_COUNT - 1): + charts.append(self.insert_chart(f"name{cx}", [admin.id], 1)) + fav_charts = [] + for cx in range(round(CHARTS_FIXTURE_COUNT / 2)): + fav_star = FavStar( + user_id=admin.id, class_name="slice", obj_id=charts[cx].id + ) + db.session.add(fav_star) db.session.commit() + fav_charts.append(fav_star) + yield charts + + # rollback changes + for chart in charts: + db.session.delete(chart) + for fav_chart in fav_charts: + db.session.delete(fav_chart) + db.session.commit() @pytest.fixture() def create_charts_created_by_gamma(self): - with self.create_app().app_context(): - charts = [] - user = self.get_user("gamma") - for cx in range(CHARTS_FIXTURE_COUNT - 1): - charts.append(self.insert_chart(f"gamma{cx}", [user.id], 1)) - yield charts - # rollback changes - for chart in charts: - db.session.delete(chart) - db.session.commit() + charts = [] + user = self.get_user("gamma") + for cx in range(CHARTS_FIXTURE_COUNT - 1): + charts.append(self.insert_chart(f"gamma{cx}", [user.id], 1)) + yield charts + # rollback changes + for chart in charts: + db.session.delete(chart) + db.session.commit() @pytest.fixture() def create_certified_charts(self): - with self.create_app().app_context(): - certified_charts = [] - admin = self.get_user("admin") - for cx in range(CHARTS_FIXTURE_COUNT): - certified_charts.append( - self.insert_chart( - f"certified{cx}", - [admin.id], - 1, - certified_by="John Doe", - certification_details="Sample certification", - ) + certified_charts = [] + admin = self.get_user("admin") + for cx in range(CHARTS_FIXTURE_COUNT): + certified_charts.append( + self.insert_chart( + f"certified{cx}", + [admin.id], + 1, + certified_by="John Doe", + certification_details="Sample certification", ) + ) - yield certified_charts + yield certified_charts - # rollback changes - for chart in certified_charts: - db.session.delete(chart) - db.session.commit() + # rollback changes + for chart in certified_charts: + db.session.delete(chart) + db.session.commit() @pytest.fixture() def create_chart_with_report(self): - with self.create_app().app_context(): - admin = self.get_user("admin") - chart = self.insert_chart(f"chart_report", [admin.id], 1) - report_schedule = ReportSchedule( - type=ReportScheduleType.REPORT, - name="report_with_chart", - crontab="* * * * *", - chart=chart, - ) - db.session.commit() + admin = self.get_user("admin") + chart = self.insert_chart(f"chart_report", [admin.id], 1) + report_schedule = ReportSchedule( + type=ReportScheduleType.REPORT, + name="report_with_chart", + crontab="* * * * *", + chart=chart, + ) + db.session.commit() - yield chart + yield chart - # rollback changes - db.session.delete(report_schedule) - db.session.delete(chart) - db.session.commit() + # rollback changes + db.session.delete(report_schedule) + db.session.delete(chart) + db.session.commit() @pytest.fixture() def add_dashboard_to_chart(self): - with self.create_app().app_context(): - admin = self.get_user("admin") - - self.chart = self.insert_chart("My chart", [admin.id], 1) - - self.original_dashboard = Dashboard() - self.original_dashboard.dashboard_title = "Original Dashboard" - self.original_dashboard.slug = "slug" - self.original_dashboard.owners = [admin] - self.original_dashboard.slices = [self.chart] - self.original_dashboard.published = False - db.session.add(self.original_dashboard) - - self.new_dashboard = Dashboard() - self.new_dashboard.dashboard_title = "New Dashboard" - self.new_dashboard.slug = "new_slug" - self.new_dashboard.owners = [admin] - self.new_dashboard.published = False - db.session.add(self.new_dashboard) + admin = self.get_user("admin") - db.session.commit() + self.chart = self.insert_chart("My chart", [admin.id], 1) - yield self.chart + self.original_dashboard = Dashboard() + self.original_dashboard.dashboard_title = "Original Dashboard" + self.original_dashboard.slug = "slug" + self.original_dashboard.owners = [admin] + self.original_dashboard.slices = [self.chart] + self.original_dashboard.published = False + db.session.add(self.original_dashboard) - db.session.delete(self.original_dashboard) - db.session.delete(self.new_dashboard) - db.session.delete(self.chart) - db.session.commit() + self.new_dashboard = Dashboard() + self.new_dashboard.dashboard_title = "New Dashboard" + self.new_dashboard.slug = "new_slug" + self.new_dashboard.owners = [admin] + self.new_dashboard.published = False + db.session.add(self.new_dashboard) + + db.session.commit() + + yield self.chart + + db.session.delete(self.original_dashboard) + db.session.delete(self.new_dashboard) + db.session.delete(self.chart) + db.session.commit() def test_info_security_chart(self): """ @@ -1127,40 +1122,39 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase): @pytest.fixture() def load_energy_charts(self): - with app.app_context(): - admin = self.get_user("admin") - energy_table = ( - db.session.query(SqlaTable) - .filter_by(table_name="energy_usage") - .one_or_none() - ) - energy_table_id = 1 - if energy_table: - energy_table_id = energy_table.id - chart1 = self.insert_chart( - "foo_a", [admin.id], energy_table_id, description="ZY_bar" - ) - chart2 = self.insert_chart( - "zy_foo", [admin.id], energy_table_id, description="desc1" - ) - chart3 = self.insert_chart( - "foo_b", [admin.id], energy_table_id, description="desc1zy_" - ) - chart4 = self.insert_chart( - "foo_c", [admin.id], energy_table_id, viz_type="viz_zy_" - ) - chart5 = self.insert_chart( - "bar", [admin.id], energy_table_id, description="foo" - ) + admin = self.get_user("admin") + energy_table = ( + db.session.query(SqlaTable) + .filter_by(table_name="energy_usage") + .one_or_none() + ) + energy_table_id = 1 + if energy_table: + energy_table_id = energy_table.id + chart1 = self.insert_chart( + "foo_a", [admin.id], energy_table_id, description="ZY_bar" + ) + chart2 = self.insert_chart( + "zy_foo", [admin.id], energy_table_id, description="desc1" + ) + chart3 = self.insert_chart( + "foo_b", [admin.id], energy_table_id, description="desc1zy_" + ) + chart4 = self.insert_chart( + "foo_c", [admin.id], energy_table_id, viz_type="viz_zy_" + ) + chart5 = self.insert_chart( + "bar", [admin.id], energy_table_id, description="foo" + ) - yield - # rollback changes - db.session.delete(chart1) - db.session.delete(chart2) - db.session.delete(chart3) - db.session.delete(chart4) - db.session.delete(chart5) - db.session.commit() + yield + # rollback changes + db.session.delete(chart1) + db.session.delete(chart2) + db.session.delete(chart3) + db.session.delete(chart4) + db.session.delete(chart5) + db.session.commit() @pytest.mark.usefixtures("load_energy_charts") def test_get_charts_custom_filter(self): diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 061cdece50..1dd5e7113c 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -27,6 +27,7 @@ from unittest import mock from zipfile import ZipFile from flask import Response +from flask.ctx import AppContext from tests.integration_tests.conftest import with_feature_flags from superset.charts.data.api import ChartDataRestApi from superset.models.sql_lab import Query @@ -88,10 +89,9 @@ INCOMPATIBLE_ADHOC_COLUMN_FIXTURE: AdhocColumn = { @pytest.fixture(autouse=True) -def skip_by_backend(): - with app.app_context(): - if backend() == "hive": - pytest.skip("Skipping tests for Hive backend") +def skip_by_backend(app_context: AppContext): + if backend() == "hive": + pytest.skip("Skipping tests for Hive backend") class BaseTestChartDataApi(SupersetTestCase): diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index cc11c4df47..77ddbe1d87 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -118,8 +118,8 @@ def get_or_create_user(get_user, create_user) -> ab_models.User: @pytest.fixture(autouse=True, scope="session") def setup_sample_data() -> Any: # TODO(john-bodley): Determine a cleaner way of setting up the sample data without - # relying on `tests.integration_tests.test_app.app` leveraging an `app` fixture which is purposely - # scoped to the function level to ensure tests remain idempotent. + # relying on `tests.integration_tests.test_app.app` leveraging an `app` fixture + # which is purposely scoped to the function level to ensure tests remain idempotent. with app.app_context(): setup_presto_if_needed() @@ -135,7 +135,6 @@ def setup_sample_data() -> Any: with app.app_context(): # drop sqlalchemy tables - db.session.commit() from sqlalchemy.ext import declarative @@ -163,12 +162,12 @@ def example_db_provider() -> Callable[[], Database]: # type: ignore _db: Database | None = None def __call__(self) -> Database: - with app.app_context(): - if self._db is None: + if self._db is None: + with app.app_context(): self._db = get_example_database() self._load_lazy_data_to_decouple_from_session() - return self._db + return self._db def _load_lazy_data_to_decouple_from_session(self) -> None: self._db._get_sqla_engine() # type: ignore diff --git a/tests/integration_tests/dashboard_tests.py b/tests/integration_tests/dashboard_tests.py index 3668eae474..57c73f83d0 100644 --- a/tests/integration_tests/dashboard_tests.py +++ b/tests/integration_tests/dashboard_tests.py @@ -58,39 +58,36 @@ from .base_tests import SupersetTestCase class TestDashboard(SupersetTestCase): @pytest.fixture def load_dashboard(self): - with app.app_context(): - table = ( - db.session.query(SqlaTable).filter_by(table_name="energy_usage").one() - ) - # get a slice from the allowed table - slice = db.session.query(Slice).filter_by(slice_name="Energy Sankey").one() - - self.grant_public_access_to_table(table) - - pytest.hidden_dash_slug = f"hidden_dash_{random()}" - pytest.published_dash_slug = f"published_dash_{random()}" - - # Create a published and hidden dashboard and add them to the database - published_dash = Dashboard() - published_dash.dashboard_title = "Published Dashboard" - published_dash.slug = pytest.published_dash_slug - published_dash.slices = [slice] - published_dash.published = True - - hidden_dash = Dashboard() - hidden_dash.dashboard_title = "Hidden Dashboard" - hidden_dash.slug = pytest.hidden_dash_slug - hidden_dash.slices = [slice] - hidden_dash.published = False - - db.session.add(published_dash) - db.session.add(hidden_dash) - yield db.session.commit() - - self.revoke_public_access_to_table(table) - db.session.delete(published_dash) - db.session.delete(hidden_dash) - db.session.commit() + table = db.session.query(SqlaTable).filter_by(table_name="energy_usage").one() + # get a slice from the allowed table + slice = db.session.query(Slice).filter_by(slice_name="Energy Sankey").one() + + self.grant_public_access_to_table(table) + + pytest.hidden_dash_slug = f"hidden_dash_{random()}" + pytest.published_dash_slug = f"published_dash_{random()}" + + # Create a published and hidden dashboard and add them to the database + published_dash = Dashboard() + published_dash.dashboard_title = "Published Dashboard" + published_dash.slug = pytest.published_dash_slug + published_dash.slices = [slice] + published_dash.published = True + + hidden_dash = Dashboard() + hidden_dash.dashboard_title = "Hidden Dashboard" + hidden_dash.slug = pytest.hidden_dash_slug + hidden_dash.slices = [slice] + hidden_dash.published = False + + db.session.add(published_dash) + db.session.add(hidden_dash) + yield db.session.commit() + + self.revoke_public_access_to_table(table) + db.session.delete(published_dash) + db.session.delete(hidden_dash) + db.session.commit() def get_mock_positions(self, dash): positions = {"DASHBOARD_VERSION_KEY": "v2"} diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index 7532e12164..fd63666c2b 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -2088,8 +2088,6 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas self.assertNotEqual(result["uuid"], "") self.assertEqual(result["allowed_domains"], allowed_domains) - db.session.expire_all() - # get returns value resp = self.get_assert_metric(uri, "get_embedded") self.assertEqual(resp.status_code, 200) @@ -2110,8 +2108,6 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas self.assertNotEqual(result["uuid"], "") self.assertEqual(result["allowed_domains"], []) - db.session.expire_all() - # get returns changed value resp = self.get_assert_metric(uri, "get_embedded") self.assertEqual(resp.status_code, 200) @@ -2123,8 +2119,6 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas resp = self.delete_assert_metric(uri, "delete_embedded") self.assertEqual(resp.status_code, 200) - db.session.expire_all() - # get returns 404 resp = self.get_assert_metric(uri, "get_embedded") self.assertEqual(resp.status_code, 404) diff --git a/tests/integration_tests/dashboards/security/security_dataset_tests.py b/tests/integration_tests/dashboards/security/security_dataset_tests.py index f470654d61..1ca0b0bd71 100644 --- a/tests/integration_tests/dashboards/security/security_dataset_tests.py +++ b/tests/integration_tests/dashboards/security/security_dataset_tests.py @@ -37,39 +37,36 @@ from tests.integration_tests.fixtures.energy_dashboard import ( class TestDashboardDatasetSecurity(DashboardTestCase): @pytest.fixture def load_dashboard(self): - with app.app_context(): - table = ( - db.session.query(SqlaTable).filter_by(table_name="energy_usage").one() - ) - # get a slice from the allowed table - slice = db.session.query(Slice).filter_by(slice_name="Energy Sankey").one() - - self.grant_public_access_to_table(table) - - pytest.hidden_dash_slug = f"hidden_dash_{random_slug()}" - pytest.published_dash_slug = f"published_dash_{random_slug()}" - - # Create a published and hidden dashboard and add them to the database - published_dash = Dashboard() - published_dash.dashboard_title = "Published Dashboard" - published_dash.slug = pytest.published_dash_slug - published_dash.slices = [slice] - published_dash.published = True - - hidden_dash = Dashboard() - hidden_dash.dashboard_title = "Hidden Dashboard" - hidden_dash.slug = pytest.hidden_dash_slug - hidden_dash.slices = [slice] - hidden_dash.published = False - - db.session.add(published_dash) - db.session.add(hidden_dash) - yield db.session.commit() - - self.revoke_public_access_to_table(table) - db.session.delete(published_dash) - db.session.delete(hidden_dash) - db.session.commit() + table = db.session.query(SqlaTable).filter_by(table_name="energy_usage").one() + # get a slice from the allowed table + slice = db.session.query(Slice).filter_by(slice_name="Energy Sankey").one() + + self.grant_public_access_to_table(table) + + pytest.hidden_dash_slug = f"hidden_dash_{random_slug()}" + pytest.published_dash_slug = f"published_dash_{random_slug()}" + + # Create a published and hidden dashboard and add them to the database + published_dash = Dashboard() + published_dash.dashboard_title = "Published Dashboard" + published_dash.slug = pytest.published_dash_slug + published_dash.slices = [slice] + published_dash.published = True + + hidden_dash = Dashboard() + hidden_dash.dashboard_title = "Hidden Dashboard" + hidden_dash.slug = pytest.hidden_dash_slug + hidden_dash.slices = [slice] + hidden_dash.published = False + + db.session.add(published_dash) + db.session.add(hidden_dash) + yield db.session.commit() + + self.revoke_public_access_to_table(table) + db.session.delete(published_dash) + db.session.delete(hidden_dash) + db.session.commit() def test_dashboard_access__admin_can_access_all(self): # arrange diff --git a/tests/integration_tests/databases/commands/upload_test.py b/tests/integration_tests/databases/commands/upload_test.py index f08be099c4..695e3e8900 100644 --- a/tests/integration_tests/databases/commands/upload_test.py +++ b/tests/integration_tests/databases/commands/upload_test.py @@ -20,6 +20,7 @@ from __future__ import annotations import json import pytest +from flask.ctx import AppContext from superset import db, security_manager from superset.commands.database.exceptions import ( @@ -84,16 +85,14 @@ def get_upload_db(): return db.session.query(Database).filter_by(database_name=CSV_UPLOAD_DATABASE).one() -@pytest.fixture(scope="function") -def setup_csv_upload_with_context(): - with app.app_context(): - yield from _setup_csv_upload() +@pytest.fixture() +def setup_csv_upload_with_context(app_context: AppContext): + yield from _setup_csv_upload() -@pytest.fixture(scope="function") -def setup_csv_upload_with_context_schema(): - with app.app_context(): - yield from _setup_csv_upload(["public"]) +@pytest.fixture() +def setup_csv_upload_with_context_schema(app_context: AppContext): + yield from _setup_csv_upload(["public"]) @pytest.mark.usefixtures("setup_csv_upload_with_context") diff --git a/tests/integration_tests/embedded/dao_tests.py b/tests/integration_tests/embedded/dao_tests.py index 8d62fc0f6d..ca96354baf 100644 --- a/tests/integration_tests/embedded/dao_tests.py +++ b/tests/integration_tests/embedded/dao_tests.py @@ -46,6 +46,5 @@ class TestEmbeddedDashboardDAO(SupersetTestCase): def test_get_by_uuid(self): dash = db.session.query(Dashboard).filter_by(slug="world_health").first() uuid = str(EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]).uuid) - db.session.expire_all() embedded = EmbeddedDashboardDAO.find_by_id(uuid) self.assertIsNotNone(embedded) diff --git a/tests/integration_tests/fixtures/datasource.py b/tests/integration_tests/fixtures/datasource.py index fd7c69deca..fc0b73bde8 100644 --- a/tests/integration_tests/fixtures/datasource.py +++ b/tests/integration_tests/fixtures/datasource.py @@ -173,39 +173,38 @@ def get_datasource_post() -> dict[str, Any]: @pytest.fixture() +@pytest.mark.usefixtures("app_conntext") def load_dataset_with_columns() -> Generator[SqlaTable, None, None]: - with app.app_context(): - engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"], echo=True) - meta = MetaData() + engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"], echo=True) + meta = MetaData() - students = Table( - "students", - meta, - Column("id", Integer, primary_key=True), - Column("name", String(255)), - Column("lastname", String(255)), - Column("ds", Date), - ) - meta.create_all(engine) + students = Table( + "students", + meta, + Column("id", Integer, primary_key=True), + Column("name", String(255)), + Column("lastname", String(255)), + Column("ds", Date), + ) + meta.create_all(engine) - students.insert().values(name="George", ds="2021-01-01") + students.insert().values(name="George", ds="2021-01-01") - dataset = SqlaTable( - database_id=db.session.query(Database).first().id, table_name="students" - ) - column = TableColumn(table_id=dataset.id, column_name="name") - dataset.columns = [column] - db.session.add(dataset) - db.session.commit() - yield dataset + dataset = SqlaTable( + database_id=db.session.query(Database).first().id, table_name="students" + ) + column = TableColumn(table_id=dataset.id, column_name="name") + dataset.columns = [column] + db.session.add(dataset) + db.session.commit() + yield dataset - # cleanup - students_table = meta.tables.get("students") - if students_table is not None: - base = declarative_base() - # needed for sqlite - db.session.commit() - base.metadata.drop_all(engine, [students_table], checkfirst=True) - db.session.delete(dataset) - db.session.delete(column) + # cleanup + if (students_table := meta.tables.get("students")) is not None: + base = declarative_base() + # needed for sqlite db.session.commit() + base.metadata.drop_all(engine, [students_table], checkfirst=True) + db.session.delete(dataset) + db.session.delete(column) + db.session.commit() diff --git a/tests/integration_tests/fixtures/public_role.py b/tests/integration_tests/fixtures/public_role.py index 892098b401..eeb4c798d8 100644 --- a/tests/integration_tests/fixtures/public_role.py +++ b/tests/integration_tests/fixtures/public_role.py @@ -15,30 +15,29 @@ # specific language governing permissions and limitations # under the License. import pytest +from flask.ctx import AppContext from superset.extensions import db, security_manager from tests.integration_tests.test_app import app @pytest.fixture() -def public_role_like_gamma(): - with app.app_context(): - app.config["PUBLIC_ROLE_LIKE"] = "Gamma" - security_manager.sync_role_definitions() +def public_role_like_gamma(app_context: AppContext): + app.config["PUBLIC_ROLE_LIKE"] = "Gamma" + security_manager.sync_role_definitions() - yield + yield - security_manager.get_public_role().permissions = [] - db.session.commit() + security_manager.get_public_role().permissions = [] + db.session.commit() @pytest.fixture() -def public_role_like_test_role(): - with app.app_context(): - app.config["PUBLIC_ROLE_LIKE"] = "TestRole" - security_manager.sync_role_definitions() +def public_role_like_test_role(app_context: AppContext): + app.config["PUBLIC_ROLE_LIKE"] = "TestRole" + security_manager.sync_role_definitions() - yield + yield - security_manager.get_public_role().permissions = [] - db.session.commit() + security_manager.get_public_role().permissions = [] + db.session.commit() diff --git a/tests/integration_tests/fixtures/tags.py b/tests/integration_tests/fixtures/tags.py index 57fd4ec719..493b3295d8 100644 --- a/tests/integration_tests/fixtures/tags.py +++ b/tests/integration_tests/fixtures/tags.py @@ -22,12 +22,12 @@ from tests.integration_tests.test_app import app @pytest.fixture +@pytest.mark.usefixtures("app_context") def with_tagging_system_feature(): - with app.app_context(): - is_enabled = app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] - if not is_enabled: - app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = True - register_sqla_event_listeners() - yield - app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = False - clear_sqla_event_listeners() + is_enabled = app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] + if not is_enabled: + app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = True + register_sqla_event_listeners() + yield + app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = False + clear_sqla_event_listeners() diff --git a/tests/integration_tests/fixtures/users.py b/tests/integration_tests/fixtures/users.py index 1dc2b8b912..6cc228d510 100644 --- a/tests/integration_tests/fixtures/users.py +++ b/tests/integration_tests/fixtures/users.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import pytest +from flask.ctx import AppContext from flask_appbuilder.security.sqla.models import Role, User from superset import db, security_manager @@ -23,27 +24,24 @@ from tests.integration_tests.test_app import app @pytest.fixture() -def create_gamma_sqllab_no_data(): - with app.app_context(): - gamma_role = db.session.query(Role).filter(Role.name == "Gamma").one_or_none() - sqllab_role = ( - db.session.query(Role).filter(Role.name == "sql_lab").one_or_none() - ) +def create_gamma_sqllab_no_data(app_context: AppContext): + gamma_role = db.session.query(Role).filter(Role.name == "Gamma").one_or_none() + sqllab_role = db.session.query(Role).filter(Role.name == "sql_lab").one_or_none() - security_manager.add_user( - GAMMA_SQLLAB_NO_DATA_USERNAME, - "gamma_sqllab_no_data", - "gamma_sqllab_no_data", - "gamma_sqllab_no_d...@apache.org", - [gamma_role, sqllab_role], - password="general", - ) + security_manager.add_user( + GAMMA_SQLLAB_NO_DATA_USERNAME, + "gamma_sqllab_no_data", + "gamma_sqllab_no_data", + "gamma_sqllab_no_d...@apache.org", + [gamma_role, sqllab_role], + password="general", + ) - yield - user = ( - db.session.query(User) - .filter(User.username == GAMMA_SQLLAB_NO_DATA_USERNAME) - .one_or_none() - ) - db.session.delete(user) - db.session.commit() + yield + user = ( + db.session.query(User) + .filter(User.username == GAMMA_SQLLAB_NO_DATA_USERNAME) + .one_or_none() + ) + db.session.delete(user) + db.session.commit() diff --git a/tests/integration_tests/fixtures/world_bank_dashboard.py b/tests/integration_tests/fixtures/world_bank_dashboard.py index 6c3d29eb43..56531d7781 100644 --- a/tests/integration_tests/fixtures/world_bank_dashboard.py +++ b/tests/integration_tests/fixtures/world_bank_dashboard.py @@ -93,13 +93,12 @@ def load_world_bank_dashboard_with_slices_class_scope(load_world_bank_data): def create_dashboard_for_loaded_data(): - with app.app_context(): - table = create_table_metadata(WB_HEALTH_POPULATION, get_example_database()) - slices = _create_world_bank_slices(table) - dash = _create_world_bank_dashboard(table) - slices_ids_to_delete = [slice.id for slice in slices] - dash_id_to_delete = dash.id - return dash_id_to_delete, slices_ids_to_delete + table = create_table_metadata(WB_HEALTH_POPULATION, get_example_database()) + slices = _create_world_bank_slices(table) + dash = _create_world_bank_dashboard(table) + slices_ids_to_delete = [slice.id for slice in slices] + dash_id_to_delete = dash.id + return dash_id_to_delete, slices_ids_to_delete def _create_world_bank_slices(table: SqlaTable) -> list[Slice]: diff --git a/tests/integration_tests/migrations/87d38ad83218_migrate_can_view_and_drill_permission__tests.py b/tests/integration_tests/migrations/87d38ad83218_migrate_can_view_and_drill_permission__tests.py index e8b825a3e4..a2ccd5948b 100644 --- a/tests/integration_tests/migrations/87d38ad83218_migrate_can_view_and_drill_permission__tests.py +++ b/tests/integration_tests/migrations/87d38ad83218_migrate_can_view_and_drill_permission__tests.py @@ -16,6 +16,8 @@ # under the License. from importlib import import_module +import pytest + from superset import db from superset.migrations.shared.security_converge import ( _find_pvm, @@ -34,28 +36,28 @@ upgrade = migration_module.do_upgrade downgrade = migration_module.do_downgrade +@pytest.mark.usefixtures("app_context") def test_migration_upgrade(): - with app.app_context(): - pre_perm = PermissionView( - permission=Permission(name="can_view_and_drill"), - view_menu=db.session.query(ViewMenu).filter_by(name="Dashboard").one(), - ) - db.session.add(pre_perm) - db.session.commit() + pre_perm = PermissionView( + permission=Permission(name="can_view_and_drill"), + view_menu=db.session.query(ViewMenu).filter_by(name="Dashboard").one(), + ) + db.session.add(pre_perm) + db.session.commit() - assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is not None + assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is not None - upgrade(db.session) + upgrade(db.session) - assert _find_pvm(db.session, "Dashboard", "can_view_chart_as_table") is not None - assert _find_pvm(db.session, "Dashboard", "can_view_query") is not None - assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is None + assert _find_pvm(db.session, "Dashboard", "can_view_chart_as_table") is not None + assert _find_pvm(db.session, "Dashboard", "can_view_query") is not None + assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is None +@pytest.mark.usefixtures("app_context") def test_migration_downgrade(): - with app.app_context(): - downgrade(db.session) + downgrade(db.session) - assert _find_pvm(db.session, "Dashboard", "can_view_chart_as_table") is None - assert _find_pvm(db.session, "Dashboard", "can_view_query") is None - assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is not None + assert _find_pvm(db.session, "Dashboard", "can_view_chart_as_table") is None + assert _find_pvm(db.session, "Dashboard", "can_view_query") is None + assert _find_pvm(db.session, "Dashboard", "can_view_and_drill") is not None diff --git a/tests/integration_tests/reports/alert_tests.py b/tests/integration_tests/reports/alert_tests.py index 6664d65a9b..b364c72015 100644 --- a/tests/integration_tests/reports/alert_tests.py +++ b/tests/integration_tests/reports/alert_tests.py @@ -20,6 +20,7 @@ from typing import Optional, Union import pandas as pd import pytest +from flask.ctx import AppContext from pytest_mock import MockFixture from superset.commands.report.exceptions import AlertQueryError @@ -61,43 +62,40 @@ def test_execute_query_as_report_executor( config: list[ExecutorType], expected_result: Union[tuple[ExecutorType, str], Exception], mocker: MockFixture, - app_context: None, + app_context: AppContext, get_user, ) -> None: from superset.commands.report.alert import AlertCommand from superset.reports.models import ReportSchedule - with app.app_context(): - original_config = app.config["ALERT_REPORTS_EXECUTE_AS"] - app.config["ALERT_REPORTS_EXECUTE_AS"] = config - owners = [get_user(owner_name) for owner_name in owner_names] - report_schedule = ReportSchedule( - created_by=get_user(creator_name) if creator_name else None, - owners=owners, - type=ReportScheduleType.ALERT, - description="description", - crontab="0 9 * * *", - creation_method=ReportCreationMethod.ALERTS_REPORTS, - sql="SELECT 1", - grace_period=14400, - working_timeout=3600, - database=get_example_database(), - validator_config_json='{"op": "==", "threshold": 1}', - ) - command = AlertCommand(report_schedule=report_schedule) - override_user_mock = mocker.patch( - "superset.commands.report.alert.override_user" - ) - cm = ( - pytest.raises(type(expected_result)) - if isinstance(expected_result, Exception) - else nullcontext() - ) - with cm: - command.run() - assert override_user_mock.call_args[0][0].username == expected_result - - app.config["ALERT_REPORTS_EXECUTE_AS"] = original_config + original_config = app.config["ALERT_REPORTS_EXECUTE_AS"] + app.config["ALERT_REPORTS_EXECUTE_AS"] = config + owners = [get_user(owner_name) for owner_name in owner_names] + report_schedule = ReportSchedule( + created_by=get_user(creator_name) if creator_name else None, + owners=owners, + type=ReportScheduleType.ALERT, + description="description", + crontab="0 9 * * *", + creation_method=ReportCreationMethod.ALERTS_REPORTS, + sql="SELECT 1", + grace_period=14400, + working_timeout=3600, + database=get_example_database(), + validator_config_json='{"op": "==", "threshold": 1}', + ) + command = AlertCommand(report_schedule=report_schedule) + override_user_mock = mocker.patch("superset.commands.report.alert.override_user") + cm = ( + pytest.raises(type(expected_result)) + if isinstance(expected_result, Exception) + else nullcontext() + ) + with cm: + command.run() + assert override_user_mock.call_args[0][0].username == expected_result + + app.config["ALERT_REPORTS_EXECUTE_AS"] = original_config def test_execute_query_succeeded_no_retry( diff --git a/tests/integration_tests/reports/commands_tests.py b/tests/integration_tests/reports/commands_tests.py index 79102654d5..6fde3d2369 100644 --- a/tests/integration_tests/reports/commands_tests.py +++ b/tests/integration_tests/reports/commands_tests.py @@ -23,6 +23,7 @@ from uuid import uuid4 import pytest from flask import current_app +from flask.ctx import AppContext from flask_appbuilder.security.sqla.models import User from flask_sqlalchemy import BaseQuery from freezegun import freeze_time @@ -162,190 +163,219 @@ def create_test_table_context(database: Database): @pytest.fixture() def create_report_email_chart(): - with app.app_context(): - chart = db.session.query(Slice).first() - report_schedule = create_report_notification( - email_target="tar...@email.com", chart=chart - ) - yield report_schedule + chart = db.session.query(Slice).first() + report_schedule = create_report_notification( + email_target="tar...@email.com", chart=chart + ) + yield report_schedule - cleanup_report_schedule(report_schedule) + cleanup_report_schedule(report_schedule) @pytest.fixture() def create_report_email_chart_alpha_owner(get_user): - with app.app_context(): - owners = [get_user("alpha")] - chart = db.session.query(Slice).first() - report_schedule = create_report_notification( - email_target="tar...@email.com", chart=chart, owners=owners - ) - yield report_schedule + owners = [get_user("alpha")] + chart = db.session.query(Slice).first() + report_schedule = create_report_notification( + email_target="tar...@email.com", chart=chart, owners=owners + ) + yield report_schedule - cleanup_report_schedule(report_schedule) + cleanup_report_schedule(report_schedule) @pytest.fixture() def create_report_email_chart_force_screenshot(): - with app.app_context(): - chart = db.session.query(Slice).first() - report_schedule = create_report_notification( - email_target="tar...@email.com", chart=chart, force_screenshot=True - ) - yield report_schedule + chart = db.session.query(Slice).first() + report_schedule = create_report_notification( + email_target="tar...@email.com", chart=chart, force_screenshot=True + ) + yield report_schedule - cleanup_report_schedule(report_schedule) + cleanup_report_schedule(report_schedule) @pytest.fixture() def create_report_email_chart_with_csv(): - with app.app_context(): - chart = db.session.query(Slice).first() - chart.query_context = '{"mock": "query_context"}' - report_schedule = create_report_notification( - email_target="tar...@email.com", - chart=chart, - report_format=ReportDataFormat.CSV, - ) - yield report_schedule - cleanup_report_schedule(report_schedule) + chart = db.session.query(Slice).first() + chart.query_context = '{"mock": "query_context"}' + report_schedule = create_report_notification( + email_target="tar...@email.com", + chart=chart, + report_format=ReportDataFormat.CSV, + ) + yield report_schedule + cleanup_report_schedule(report_schedule) @pytest.fixture() def create_report_email_chart_with_text(): - with app.app_context(): - chart = db.session.query(Slice).first() - chart.query_context = '{"mock": "query_context"}' - report_schedule = create_report_notification( - email_target="tar...@email.com", - chart=chart, - report_format=ReportDataFormat.TEXT, - ) - yield report_schedule - cleanup_report_schedule(report_schedule) + chart = db.session.query(Slice).first() + chart.query_context = '{"mock": "query_context"}' + report_schedule = create_report_notification( + email_target="tar...@email.com", + chart=chart, + report_format=ReportDataFormat.TEXT, + ) + yield report_schedule + cleanup_report_schedule(report_schedule) @pytest.fixture() def create_report_email_chart_with_csv_no_query_context(): - with app.app_context(): - chart = db.session.query(Slice).first() - chart.query_context = None - report_schedule = create_report_notification( - email_target="tar...@email.com", - chart=chart, - report_format=ReportDataFormat.CSV, - name="report_csv_no_query_context", - ) - yield report_schedule - cleanup_report_schedule(report_schedule) + chart = db.session.query(Slice).first() + chart.query_context = None + report_schedule = create_report_notification( + email_target="tar...@email.com", + chart=chart, + report_format=ReportDataFormat.CSV, + name="report_csv_no_query_context", + ) + yield report_schedule + cleanup_report_schedule(report_schedule) @pytest.fixture() def create_report_email_dashboard(): - with app.app_context(): - dashboard = db.session.query(Dashboard).first() - report_schedule = create_report_notification( - email_target="tar...@email.com", dashboard=dashboard - ) - yield report_schedule + dashboard = db.session.query(Dashboard).first() + report_schedule = create_report_notification( + email_target="tar...@email.com", dashboard=dashboard + ) + yield report_schedule - cleanup_report_schedule(report_schedule) + cleanup_report_schedule(report_schedule) @pytest.fixture() def create_report_email_dashboard_force_screenshot(): - with app.app_context(): - dashboard = db.session.query(Dashboard).first() - report_schedule = create_report_notification( - email_target="tar...@email.com", dashboard=dashboard, force_screenshot=True - ) - yield report_schedule + dashboard = db.session.query(Dashboard).first() + report_schedule = create_report_notification( + email_target="tar...@email.com", dashboard=dashboard, force_screenshot=True + ) + yield report_schedule - cleanup_report_schedule(report_schedule) + cleanup_report_schedule(report_schedule) @pytest.fixture() def create_report_slack_chart(): - with app.app_context(): - chart = db.session.query(Slice).first() - report_schedule = create_report_notification( - slack_channel="slack_channel", chart=chart - ) - yield report_schedule + chart = db.session.query(Slice).first() + report_schedule = create_report_notification( + slack_channel="slack_channel", chart=chart + ) + yield report_schedule - cleanup_report_schedule(report_schedule) + cleanup_report_schedule(report_schedule) @pytest.fixture() def create_report_slack_chart_with_csv(): - with app.app_context(): - chart = db.session.query(Slice).first() - chart.query_context = '{"mock": "query_context"}' - report_schedule = create_report_notification( - slack_channel="slack_channel", - chart=chart, - report_format=ReportDataFormat.CSV, - ) - yield report_schedule + chart = db.session.query(Slice).first() + chart.query_context = '{"mock": "query_context"}' + report_schedule = create_report_notification( + slack_channel="slack_channel", + chart=chart, + report_format=ReportDataFormat.CSV, + ) + yield report_schedule - cleanup_report_schedule(report_schedule) + cleanup_report_schedule(report_schedule) @pytest.fixture() def create_report_slack_chart_with_text(): - with app.app_context(): - chart = db.session.query(Slice).first() - chart.query_context = '{"mock": "query_context"}' - report_schedule = create_report_notification( - slack_channel="slack_channel", - chart=chart, - report_format=ReportDataFormat.TEXT, - ) - yield report_schedule + chart = db.session.query(Slice).first() + chart.query_context = '{"mock": "query_context"}' + report_schedule = create_report_notification( + slack_channel="slack_channel", + chart=chart, + report_format=ReportDataFormat.TEXT, + ) + yield report_schedule - cleanup_report_schedule(report_schedule) + cleanup_report_schedule(report_schedule) @pytest.fixture() def create_report_slack_chart_working(): - with app.app_context(): - chart = db.session.query(Slice).first() - report_schedule = create_report_notification( - slack_channel="slack_channel", chart=chart - ) - report_schedule.last_state = ReportState.WORKING - report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0) - report_schedule.last_value = None - report_schedule.last_value_row_json = None - db.session.commit() - log = ReportExecutionLog( - scheduled_dttm=report_schedule.last_eval_dttm, - start_dttm=report_schedule.last_eval_dttm, - end_dttm=report_schedule.last_eval_dttm, - value=report_schedule.last_value, - value_row_json=report_schedule.last_value_row_json, - state=ReportState.WORKING, - report_schedule=report_schedule, - uuid=uuid4(), - ) - db.session.add(log) - db.session.commit() + chart = db.session.query(Slice).first() + report_schedule = create_report_notification( + slack_channel="slack_channel", chart=chart + ) + report_schedule.last_state = ReportState.WORKING + report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0) + report_schedule.last_value = None + report_schedule.last_value_row_json = None + db.session.commit() + log = ReportExecutionLog( + scheduled_dttm=report_schedule.last_eval_dttm, + start_dttm=report_schedule.last_eval_dttm, + end_dttm=report_schedule.last_eval_dttm, + value=report_schedule.last_value, + value_row_json=report_schedule.last_value_row_json, + state=ReportState.WORKING, + report_schedule=report_schedule, + uuid=uuid4(), + ) + db.session.add(log) + db.session.commit() - yield report_schedule + yield report_schedule - cleanup_report_schedule(report_schedule) + cleanup_report_schedule(report_schedule) @pytest.fixture() def create_alert_slack_chart_success(): - with app.app_context(): - chart = db.session.query(Slice).first() + chart = db.session.query(Slice).first() + report_schedule = create_report_notification( + slack_channel="slack_channel", + chart=chart, + report_type=ReportScheduleType.ALERT, + ) + report_schedule.last_state = ReportState.SUCCESS + report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0) + + log = ReportExecutionLog( + report_schedule=report_schedule, + state=ReportState.SUCCESS, + start_dttm=report_schedule.last_eval_dttm, + end_dttm=report_schedule.last_eval_dttm, + scheduled_dttm=report_schedule.last_eval_dttm, + ) + db.session.add(log) + db.session.commit() + yield report_schedule + + cleanup_report_schedule(report_schedule) + + +@pytest.fixture( + params=[ + "alert1", + ] +) +def create_alert_slack_chart_grace(request): + param_config = { + "alert1": { + "sql": "SELECT count(*) from test_table", + "validator_type": ReportScheduleValidatorType.OPERATOR, + "validator_config_json": '{"op": "<", "threshold": 10}', + }, + } + chart = db.session.query(Slice).first() + example_database = get_example_database() + with create_test_table_context(example_database): report_schedule = create_report_notification( slack_channel="slack_channel", chart=chart, report_type=ReportScheduleType.ALERT, + database=example_database, + sql=param_config[request.param]["sql"], + validator_type=param_config[request.param]["validator_type"], + validator_config_json=param_config[request.param]["validator_config_json"], ) - report_schedule.last_state = ReportState.SUCCESS + report_schedule.last_state = ReportState.GRACE report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0) log = ReportExecutionLog( @@ -362,51 +392,6 @@ def create_alert_slack_chart_success(): cleanup_report_schedule(report_schedule) -@pytest.fixture( - params=[ - "alert1", - ] -) -def create_alert_slack_chart_grace(request): - param_config = { - "alert1": { - "sql": "SELECT count(*) from test_table", - "validator_type": ReportScheduleValidatorType.OPERATOR, - "validator_config_json": '{"op": "<", "threshold": 10}', - }, - } - with app.app_context(): - chart = db.session.query(Slice).first() - example_database = get_example_database() - with create_test_table_context(example_database): - report_schedule = create_report_notification( - slack_channel="slack_channel", - chart=chart, - report_type=ReportScheduleType.ALERT, - database=example_database, - sql=param_config[request.param]["sql"], - validator_type=param_config[request.param]["validator_type"], - validator_config_json=param_config[request.param][ - "validator_config_json" - ], - ) - report_schedule.last_state = ReportState.GRACE - report_schedule.last_eval_dttm = datetime(2020, 1, 1, 0, 0) - - log = ReportExecutionLog( - report_schedule=report_schedule, - state=ReportState.SUCCESS, - start_dttm=report_schedule.last_eval_dttm, - end_dttm=report_schedule.last_eval_dttm, - scheduled_dttm=report_schedule.last_eval_dttm, - ) - db.session.add(log) - db.session.commit() - yield report_schedule - - cleanup_report_schedule(report_schedule) - - @pytest.fixture( params=[ "alert1", @@ -462,25 +447,22 @@ def create_alert_email_chart(request): "validator_config_json": '{"op": ">", "threshold": 54.999}', }, } - with app.app_context(): - chart = db.session.query(Slice).first() - example_database = get_example_database() - with create_test_table_context(example_database): - report_schedule = create_report_notification( - email_target="tar...@email.com", - chart=chart, - report_type=ReportScheduleType.ALERT, - database=example_database, - sql=param_config[request.param]["sql"], - validator_type=param_config[request.param]["validator_type"], - validator_config_json=param_config[request.param][ - "validator_config_json" - ], - force_screenshot=True, - ) - yield report_schedule + chart = db.session.query(Slice).first() + example_database = get_example_database() + with create_test_table_context(example_database): + report_schedule = create_report_notification( + email_target="tar...@email.com", + chart=chart, + report_type=ReportScheduleType.ALERT, + database=example_database, + sql=param_config[request.param]["sql"], + validator_type=param_config[request.param]["validator_type"], + validator_config_json=param_config[request.param]["validator_config_json"], + force_screenshot=True, + ) + yield report_schedule - cleanup_report_schedule(report_schedule) + cleanup_report_schedule(report_schedule) @pytest.fixture( @@ -544,24 +526,21 @@ def create_no_alert_email_chart(request): "validator_config_json": '{"op": ">", "threshold": 0}', }, } - with app.app_context(): - chart = db.session.query(Slice).first() - example_database = get_example_database() - with create_test_table_context(example_database): - report_schedule = create_report_notification( - email_target="tar...@email.com", - chart=chart, - report_type=ReportScheduleType.ALERT, - database=example_database, - sql=param_config[request.param]["sql"], - validator_type=param_config[request.param]["validator_type"], - validator_config_json=param_config[request.param][ - "validator_config_json" - ], - ) - yield report_schedule + chart = db.session.query(Slice).first() + example_database = get_example_database() + with create_test_table_context(example_database): + report_schedule = create_report_notification( + email_target="tar...@email.com", + chart=chart, + report_type=ReportScheduleType.ALERT, + database=example_database, + sql=param_config[request.param]["sql"], + validator_type=param_config[request.param]["validator_type"], + validator_config_json=param_config[request.param]["validator_config_json"], + ) + yield report_schedule - cleanup_report_schedule(report_schedule) + cleanup_report_schedule(report_schedule) @pytest.fixture(params=["alert1", "alert2"]) @@ -578,28 +557,25 @@ def create_mul_alert_email_chart(request): "validator_config_json": '{"op": "<", "threshold": 10}', }, } - with app.app_context(): - chart = db.session.query(Slice).first() - example_database = get_example_database() - with create_test_table_context(example_database): - report_schedule = create_report_notification( - email_target="tar...@email.com", - chart=chart, - report_type=ReportScheduleType.ALERT, - database=example_database, - sql=param_config[request.param]["sql"], - validator_type=param_config[request.param]["validator_type"], - validator_config_json=param_config[request.param][ - "validator_config_json" - ], - ) - yield report_schedule + chart = db.session.query(Slice).first() + example_database = get_example_database() + with create_test_table_context(example_database): + report_schedule = create_report_notification( + email_target="tar...@email.com", + chart=chart, + report_type=ReportScheduleType.ALERT, + database=example_database, + sql=param_config[request.param]["sql"], + validator_type=param_config[request.param]["validator_type"], + validator_config_json=param_config[request.param]["validator_config_json"], + ) + yield report_schedule - cleanup_report_schedule(report_schedule) + cleanup_report_schedule(report_schedule) @pytest.fixture(params=["alert1", "alert2"]) -def create_invalid_sql_alert_email_chart(request): +def create_invalid_sql_alert_email_chart(request, app_context: AppContext): param_config = { "alert1": { "sql": "SELECT 'string' ", @@ -612,25 +588,22 @@ def create_invalid_sql_alert_email_chart(request): "validator_config_json": '{"op": "<", "threshold": 10}', }, } - with app.app_context(): - chart = db.session.query(Slice).first() - example_database = get_example_database() - with create_test_table_context(example_database): - report_schedule = create_report_notification( - email_target="tar...@email.com", - chart=chart, - report_type=ReportScheduleType.ALERT, - database=example_database, - sql=param_config[request.param]["sql"], - validator_type=param_config[request.param]["validator_type"], - validator_config_json=param_config[request.param][ - "validator_config_json" - ], - grace_period=60 * 60, - ) - yield report_schedule + chart = db.session.query(Slice).first() + example_database = get_example_database() + with create_test_table_context(example_database): + report_schedule = create_report_notification( + email_target="tar...@email.com", + chart=chart, + report_type=ReportScheduleType.ALERT, + database=example_database, + sql=param_config[request.param]["sql"], + validator_type=param_config[request.param]["validator_type"], + validator_config_json=param_config[request.param]["validator_config_json"], + grace_period=60 * 60, + ) + yield report_schedule - cleanup_report_schedule(report_schedule) + cleanup_report_schedule(report_schedule) @pytest.mark.usefixtures( @@ -835,7 +808,8 @@ def test_email_chart_report_dry_run( @pytest.mark.usefixtures( - "load_birth_names_dashboard_with_slices", "create_report_email_chart_with_csv" + "load_birth_names_dashboard_with_slices", + "create_report_email_chart_with_csv", ) @patch("superset.utils.csv.urllib.request.urlopen") @patch("superset.utils.csv.urllib.request.OpenerDirector.open") @@ -923,7 +897,8 @@ def test_email_chart_report_schedule_with_csv_no_query_context( @pytest.mark.usefixtures( - "load_birth_names_dashboard_with_slices", "create_report_email_chart_with_text" + "load_birth_names_dashboard_with_slices", + "create_report_email_chart_with_text", ) @patch("superset.utils.csv.urllib.request.urlopen") @patch("superset.utils.csv.urllib.request.OpenerDirector.open") @@ -1545,7 +1520,8 @@ def test_slack_chart_alert_no_attachment(email_mock, create_alert_email_chart): @pytest.mark.usefixtures( - "load_birth_names_dashboard_with_slices", "create_report_slack_chart" + "load_birth_names_dashboard_with_slices", + "create_report_slack_chart", ) @patch("superset.utils.slack.WebClient") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") @@ -1571,7 +1547,7 @@ def test_slack_token_callable_chart_report( assert_log(ReportState.SUCCESS) -@pytest.mark.usefixtures("create_no_alert_email_chart") +@pytest.mark.usefixtures("app_context") def test_email_chart_no_alert(create_no_alert_email_chart): """ ExecuteReport Command: Test chart email no alert @@ -1583,7 +1559,7 @@ def test_email_chart_no_alert(create_no_alert_email_chart): assert_log(ReportState.NOOP) -@pytest.mark.usefixtures("create_mul_alert_email_chart") +@pytest.mark.usefixtures("app_context") def test_email_mul_alert(create_mul_alert_email_chart): """ ExecuteReport Command: Test chart email multiple rows @@ -1824,7 +1800,6 @@ def test_email_disable_screenshot(email_mock, create_alert_email_chart): assert_log(ReportState.SUCCESS) -@pytest.mark.usefixtures("create_invalid_sql_alert_email_chart") @patch("superset.reports.notifications.email.send_email_smtp") def test_invalid_sql_alert(email_mock, create_invalid_sql_alert_email_chart): """ @@ -1841,7 +1816,6 @@ def test_invalid_sql_alert(email_mock, create_invalid_sql_alert_email_chart): assert_log(ReportState.ERROR) -@pytest.mark.usefixtures("create_invalid_sql_alert_email_chart") @patch("superset.reports.notifications.email.send_email_smtp") def test_grace_period_error(email_mock, create_invalid_sql_alert_email_chart): """ @@ -1884,7 +1858,6 @@ def test_grace_period_error(email_mock, create_invalid_sql_alert_email_chart): ) -@pytest.mark.usefixtures("create_invalid_sql_alert_email_chart") @patch("superset.reports.notifications.email.send_email_smtp") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") def test_grace_period_error_flap( diff --git a/tests/integration_tests/reports/scheduler_tests.py b/tests/integration_tests/reports/scheduler_tests.py index ee93ef48a4..ae25b575ae 100644 --- a/tests/integration_tests/reports/scheduler_tests.py +++ b/tests/integration_tests/reports/scheduler_tests.py @@ -35,150 +35,144 @@ def owners(get_user) -> list[User]: return [get_user("admin")] -@pytest.mark.usefixtures("owners") +@pytest.mark.usefixtures("app_context") @patch("superset.tasks.scheduler.execute.apply_async") def test_scheduler_celery_timeout_ny(execute_mock, owners): """ Reports scheduler: Test scheduler setting celery soft and hard timeout """ - with app.app_context(): - report_schedule = insert_report_schedule( - type=ReportScheduleType.ALERT, - name="report", - crontab="0 4 * * *", - timezone="America/New_York", - owners=owners, - ) - - with freeze_time("2020-01-01T09:00:00Z"): - scheduler() - assert execute_mock.call_args[1]["soft_time_limit"] == 3601 - assert execute_mock.call_args[1]["time_limit"] == 3610 - db.session.delete(report_schedule) - db.session.commit() - - -@pytest.mark.usefixtures("owners") + report_schedule = insert_report_schedule( + type=ReportScheduleType.ALERT, + name="report", + crontab="0 4 * * *", + timezone="America/New_York", + owners=owners, + ) + + with freeze_time("2020-01-01T09:00:00Z"): + scheduler() + assert execute_mock.call_args[1]["soft_time_limit"] == 3601 + assert execute_mock.call_args[1]["time_limit"] == 3610 + db.session.delete(report_schedule) + db.session.commit() + + +@pytest.mark.usefixtures("app_context") @patch("superset.tasks.scheduler.execute.apply_async") def test_scheduler_celery_no_timeout_ny(execute_mock, owners): """ Reports scheduler: Test scheduler setting celery soft and hard timeout """ - with app.app_context(): - app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = False - report_schedule = insert_report_schedule( - type=ReportScheduleType.ALERT, - name="report", - crontab="0 4 * * *", - timezone="America/New_York", - owners=owners, - ) - - with freeze_time("2020-01-01T09:00:00Z"): - scheduler() - assert execute_mock.call_args[1] == {"eta": FakeDatetime(2020, 1, 1, 9, 0)} - db.session.delete(report_schedule) - db.session.commit() - app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = True - - -@pytest.mark.usefixtures("owners") + app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = False + report_schedule = insert_report_schedule( + type=ReportScheduleType.ALERT, + name="report", + crontab="0 4 * * *", + timezone="America/New_York", + owners=owners, + ) + + with freeze_time("2020-01-01T09:00:00Z"): + scheduler() + assert execute_mock.call_args[1] == {"eta": FakeDatetime(2020, 1, 1, 9, 0)} + db.session.delete(report_schedule) + db.session.commit() + app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = True + + +@pytest.mark.usefixtures("app_context") @patch("superset.tasks.scheduler.execute.apply_async") def test_scheduler_celery_timeout_utc(execute_mock, owners): """ Reports scheduler: Test scheduler setting celery soft and hard timeout """ - with app.app_context(): - report_schedule = insert_report_schedule( - type=ReportScheduleType.ALERT, - name="report", - crontab="0 9 * * *", - timezone="UTC", - owners=owners, - ) - - with freeze_time("2020-01-01T09:00:00Z"): - scheduler() - assert execute_mock.call_args[1]["soft_time_limit"] == 3601 - assert execute_mock.call_args[1]["time_limit"] == 3610 - db.session.delete(report_schedule) - db.session.commit() - - -@pytest.mark.usefixtures("owners") + report_schedule = insert_report_schedule( + type=ReportScheduleType.ALERT, + name="report", + crontab="0 9 * * *", + timezone="UTC", + owners=owners, + ) + + with freeze_time("2020-01-01T09:00:00Z"): + scheduler() + assert execute_mock.call_args[1]["soft_time_limit"] == 3601 + assert execute_mock.call_args[1]["time_limit"] == 3610 + db.session.delete(report_schedule) + db.session.commit() + + +@pytest.mark.usefixtures("app_context") @patch("superset.tasks.scheduler.execute.apply_async") def test_scheduler_celery_no_timeout_utc(execute_mock, owners): """ Reports scheduler: Test scheduler setting celery soft and hard timeout """ - with app.app_context(): - app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = False - report_schedule = insert_report_schedule( - type=ReportScheduleType.ALERT, - name="report", - crontab="0 9 * * *", - timezone="UTC", - owners=owners, - ) - - with freeze_time("2020-01-01T09:00:00Z"): - scheduler() - assert execute_mock.call_args[1] == {"eta": FakeDatetime(2020, 1, 1, 9, 0)} - db.session.delete(report_schedule) - db.session.commit() - app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = True - - -@pytest.mark.usefixtures("owners") + app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = False + report_schedule = insert_report_schedule( + type=ReportScheduleType.ALERT, + name="report", + crontab="0 9 * * *", + timezone="UTC", + owners=owners, + ) + + with freeze_time("2020-01-01T09:00:00Z"): + scheduler() + assert execute_mock.call_args[1] == {"eta": FakeDatetime(2020, 1, 1, 9, 0)} + db.session.delete(report_schedule) + db.session.commit() + app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = True + + +@pytest.mark.usefixtures("app_context") @patch("superset.tasks.scheduler.is_feature_enabled") @patch("superset.tasks.scheduler.execute.apply_async") def test_scheduler_feature_flag_off(execute_mock, is_feature_enabled, owners): """ Reports scheduler: Test scheduler with feature flag off """ - with app.app_context(): - is_feature_enabled.return_value = False - report_schedule = insert_report_schedule( - type=ReportScheduleType.ALERT, - name="report", - crontab="0 9 * * *", - timezone="UTC", - owners=owners, - ) - - with freeze_time("2020-01-01T09:00:00Z"): - scheduler() - execute_mock.assert_not_called() - db.session.delete(report_schedule) - db.session.commit() - - -@pytest.mark.usefixtures("owners") + is_feature_enabled.return_value = False + report_schedule = insert_report_schedule( + type=ReportScheduleType.ALERT, + name="report", + crontab="0 9 * * *", + timezone="UTC", + owners=owners, + ) + + with freeze_time("2020-01-01T09:00:00Z"): + scheduler() + execute_mock.assert_not_called() + db.session.delete(report_schedule) + db.session.commit() + + +@pytest.mark.usefixtures("app_context") @patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.__init__") @patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.run") @patch("superset.tasks.scheduler.execute.update_state") def test_execute_task(update_state_mock, command_mock, init_mock, owners): from superset.commands.report.exceptions import ReportScheduleUnexpectedError - with app.app_context(): - report_schedule = insert_report_schedule( - type=ReportScheduleType.ALERT, - name=f"report-{randint(0,1000)}", - crontab="0 4 * * *", - timezone="America/New_York", - owners=owners, - ) - init_mock.return_value = None - command_mock.side_effect = ReportScheduleUnexpectedError("Unexpected error") - with freeze_time("2020-01-01T09:00:00Z"): - execute(report_schedule.id) - update_state_mock.assert_called_with(state="FAILURE") + report_schedule = insert_report_schedule( + type=ReportScheduleType.ALERT, + name=f"report-{randint(0,1000)}", + crontab="0 4 * * *", + timezone="America/New_York", + owners=owners, + ) + init_mock.return_value = None + command_mock.side_effect = ReportScheduleUnexpectedError("Unexpected error") + with freeze_time("2020-01-01T09:00:00Z"): + execute(report_schedule.id) + update_state_mock.assert_called_with(state="FAILURE") - db.session.delete(report_schedule) - db.session.commit() + db.session.delete(report_schedule) + db.session.commit() -@pytest.mark.usefixtures("owners") +@pytest.mark.usefixtures("app_context") @patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.__init__") @patch("superset.commands.report.execute.AsyncExecuteReportScheduleCommand.run") @patch("superset.tasks.scheduler.execute.update_state") @@ -188,23 +182,22 @@ def test_execute_task_with_command_exception( ): from superset.commands.exceptions import CommandException - with app.app_context(): - report_schedule = insert_report_schedule( - type=ReportScheduleType.ALERT, - name=f"report-{randint(0,1000)}", - crontab="0 4 * * *", - timezone="America/New_York", - owners=owners, + report_schedule = insert_report_schedule( + type=ReportScheduleType.ALERT, + name=f"report-{randint(0,1000)}", + crontab="0 4 * * *", + timezone="America/New_York", + owners=owners, + ) + init_mock.return_value = None + command_mock.side_effect = CommandException("Unexpected error") + with freeze_time("2020-01-01T09:00:00Z"): + execute(report_schedule.id) + update_state_mock.assert_called_with(state="FAILURE") + logger_mock.exception.assert_called_with( + "A downstream exception occurred while generating a report: None. Unexpected error", + exc_info=True, ) - init_mock.return_value = None - command_mock.side_effect = CommandException("Unexpected error") - with freeze_time("2020-01-01T09:00:00Z"): - execute(report_schedule.id) - update_state_mock.assert_called_with(state="FAILURE") - logger_mock.exception.assert_called_with( - "A downstream exception occurred while generating a report: None. Unexpected error", - exc_info=True, - ) - - db.session.delete(report_schedule) - db.session.commit() + + db.session.delete(report_schedule) + db.session.commit() diff --git a/tests/integration_tests/security/analytics_db_safety_tests.py b/tests/integration_tests/security/analytics_db_safety_tests.py index 9c40050c0a..3b686497a3 100644 --- a/tests/integration_tests/security/analytics_db_safety_tests.py +++ b/tests/integration_tests/security/analytics_db_safety_tests.py @@ -84,10 +84,9 @@ from tests.integration_tests.test_app import app def test_check_sqlalchemy_uri( sqlalchemy_uri: str, error: bool, error_message: Optional[str] ): - with app.app_context(): - if error: - with pytest.raises(SupersetSecurityException) as excinfo: - check_sqlalchemy_uri(make_url(sqlalchemy_uri)) - assert str(excinfo.value) == error_message - else: + if error: + with pytest.raises(SupersetSecurityException) as excinfo: check_sqlalchemy_uri(make_url(sqlalchemy_uri)) + assert str(excinfo.value) == error_message + else: + check_sqlalchemy_uri(make_url(sqlalchemy_uri)) diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 0359317e3a..23bdbe4963 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -24,7 +24,7 @@ import pytest import numpy as np import pandas as pd -from flask import Flask +from flask.ctx import AppContext from pytest_mock import MockFixture from sqlalchemy.sql import text from sqlalchemy.sql.elements import TextClause @@ -598,26 +598,25 @@ class TestDatabaseModel(SupersetTestCase): db.session.commit() -@pytest.fixture -def text_column_table(): - with app.app_context(): - table = SqlaTable( - table_name="text_column_table", - sql=( - "SELECT 'foo' as foo " - "UNION SELECT '' " - "UNION SELECT NULL " - "UNION SELECT 'null' " - "UNION SELECT '\"text in double quotes\"' " - "UNION SELECT '''text in single quotes''' " - "UNION SELECT 'double quotes \" in text' " - "UNION SELECT 'single quotes '' in text' " - ), - database=get_example_database(), - ) - TableColumn(column_name="foo", type="VARCHAR(255)", table=table) - SqlMetric(metric_name="count", expression="count(*)", table=table) - yield table +@pytest.fixture() +def text_column_table(app_context: AppContext): + table = SqlaTable( + table_name="text_column_table", + sql=( + "SELECT 'foo' as foo " + "UNION SELECT '' " + "UNION SELECT NULL " + "UNION SELECT 'null' " + "UNION SELECT '\"text in double quotes\"' " + "UNION SELECT '''text in single quotes''' " + "UNION SELECT 'double quotes \" in text' " + "UNION SELECT 'single quotes '' in text' " + ), + database=get_example_database(), + ) + TableColumn(column_name="foo", type="VARCHAR(255)", table=table) + SqlMetric(metric_name="count", expression="count(*)", table=table) + yield table def test_values_for_column_on_text_column(text_column_table): @@ -836,6 +835,7 @@ def test_none_operand_in_filter(login_as_admin, physical_dataset): ) +@pytest.mark.usefixtures("app_context") @pytest.mark.parametrize( "row,dimension,result", [ @@ -857,7 +857,6 @@ def test_none_operand_in_filter(login_as_admin, physical_dataset): ], ) def test__normalize_prequery_result_type( - app_context: Flask, mocker: MockFixture, row: pd.Series, dimension: str, @@ -927,7 +926,8 @@ def test__normalize_prequery_result_type( assert normalized == result -def test__temporal_range_operator_in_adhoc_filter(app_context, physical_dataset): +@pytest.mark.usefixtures("app_context") +def test__temporal_range_operator_in_adhoc_filter(physical_dataset): result = physical_dataset.query( { "columns": ["col1", "col2"], diff --git a/tests/integration_tests/utils/core_tests.py b/tests/integration_tests/utils/core_tests.py index 1a2fa6a521..6954a0610c 100644 --- a/tests/integration_tests/utils/core_tests.py +++ b/tests/integration_tests/utils/core_tests.py @@ -82,5 +82,4 @@ def test_form_data_to_adhoc_incorrect_clause_type(): form_data = {"where": "1 = 1", "having": "count(*) > 1"} with pytest.raises(ValueError): - with app.app_context(): - form_data_to_adhoc(form_data, "foobar") + form_data_to_adhoc(form_data, "foobar") diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py index 697b858542..18e2ab8017 100644 --- a/tests/integration_tests/utils_tests.py +++ b/tests/integration_tests/utils_tests.py @@ -685,20 +685,19 @@ class TestUtils(SupersetTestCase): self.assertIsNotNone(parse_js_uri_path_item("item")) def test_get_stacktrace(self): - with app.app_context(): - app.config["SHOW_STACKTRACE"] = True - try: - raise Exception("NONONO!") - except Exception: - stacktrace = get_stacktrace() - self.assertIn("NONONO", stacktrace) - - app.config["SHOW_STACKTRACE"] = False - try: - raise Exception("NONONO!") - except Exception: - stacktrace = get_stacktrace() - assert stacktrace is None + app.config["SHOW_STACKTRACE"] = True + try: + raise Exception("NONONO!") + except Exception: + stacktrace = get_stacktrace() + self.assertIn("NONONO", stacktrace) + + app.config["SHOW_STACKTRACE"] = False + try: + raise Exception("NONONO!") + except Exception: + stacktrace = get_stacktrace() + assert stacktrace is None def test_split(self): self.assertEqual(list(split("a b")), ["a", "b"]) @@ -839,9 +838,8 @@ class TestUtils(SupersetTestCase): ) def test_get_form_data_default(self) -> None: - with app.test_request_context(): - form_data, slc = get_form_data() - self.assertEqual(slc, None) + form_data, slc = get_form_data() + self.assertEqual(slc, None) def test_get_form_data_request_args(self) -> None: with app.test_request_context( diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index 40bb7a019a..d27e5a8739 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -411,72 +411,71 @@ def test_delete_ssh_tunnel( """ Test that we can delete SSH Tunnel """ - with app.app_context(): - from superset.daos.database import DatabaseDAO - from superset.databases.api import DatabaseRestApi - from superset.databases.ssh_tunnel.models import SSHTunnel - from superset.models.core import Database + from superset.daos.database import DatabaseDAO + from superset.databases.api import DatabaseRestApi + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database - DatabaseRestApi.datamodel.session = session + DatabaseRestApi.datamodel.session = session - # create table for databases - Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member + # create table for databases + Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member - # Create our Database - database = Database( - database_name="my_database", - sqlalchemy_uri="gsheets://", - encrypted_extra=json.dumps( - { - "service_account_info": { - "type": "service_account", - "project_id": "black-sanctum-314419", - "private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", - "private_key": "SECRET", - "client_email": "google-spreadsheets-demo-se...@black-sanctum-314419.iam.gserviceaccount.com", - "client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT", - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": "https://oauth2.googleapis.com/token", - "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", - "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com", - }, - } - ), - ) - db.session.add(database) - db.session.commit() - - # mock the lookup so that we don't need to include the driver - mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") - mocker.patch("superset.utils.log.DBEventLogger.log") - mocker.patch( - "superset.commands.database.ssh_tunnel.delete.is_feature_enabled", - return_value=True, - ) + # Create our Database + database = Database( + database_name="my_database", + sqlalchemy_uri="gsheets://", + encrypted_extra=json.dumps( + { + "service_account_info": { + "type": "service_account", + "project_id": "black-sanctum-314419", + "private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", + "private_key": "SECRET", + "client_email": "google-spreadsheets-demo-se...@black-sanctum-314419.iam.gserviceaccount.com", + "client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com", + }, + } + ), + ) + db.session.add(database) + db.session.commit() - # Create our SSHTunnel - tunnel = SSHTunnel( - database_id=1, - database=database, - ) + # mock the lookup so that we don't need to include the driver + mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") + mocker.patch("superset.utils.log.DBEventLogger.log") + mocker.patch( + "superset.commands.database.ssh_tunnel.delete.is_feature_enabled", + return_value=True, + ) - db.session.add(tunnel) - db.session.commit() + # Create our SSHTunnel + tunnel = SSHTunnel( + database_id=1, + database=database, + ) - # Get our recently created SSHTunnel - response_tunnel = DatabaseDAO.get_ssh_tunnel(1) - assert response_tunnel - assert isinstance(response_tunnel, SSHTunnel) - assert 1 == response_tunnel.database_id + db.session.add(tunnel) + db.session.commit() - # Delete the recently created SSHTunnel - response_delete_tunnel = client.delete( - f"/api/v1/database/{database.id}/ssh_tunnel/" - ) - assert response_delete_tunnel.json["message"] == "OK" + # Get our recently created SSHTunnel + response_tunnel = DatabaseDAO.get_ssh_tunnel(1) + assert response_tunnel + assert isinstance(response_tunnel, SSHTunnel) + assert 1 == response_tunnel.database_id - response_tunnel = DatabaseDAO.get_ssh_tunnel(1) - assert response_tunnel is None + # Delete the recently created SSHTunnel + response_delete_tunnel = client.delete( + f"/api/v1/database/{database.id}/ssh_tunnel/" + ) + assert response_delete_tunnel.json["message"] == "OK" + + response_tunnel = DatabaseDAO.get_ssh_tunnel(1) + assert response_tunnel is None def test_delete_ssh_tunnel_not_found( @@ -489,70 +488,69 @@ def test_delete_ssh_tunnel_not_found( """ Test that we cannot delete a tunnel that does not exist """ - with app.app_context(): - from superset.daos.database import DatabaseDAO - from superset.databases.api import DatabaseRestApi - from superset.databases.ssh_tunnel.models import SSHTunnel - from superset.models.core import Database + from superset.daos.database import DatabaseDAO + from superset.databases.api import DatabaseRestApi + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database - DatabaseRestApi.datamodel.session = session + DatabaseRestApi.datamodel.session = session - # create table for databases - Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member + # create table for databases + Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member - # Create our Database - database = Database( - database_name="my_database", - sqlalchemy_uri="gsheets://", - encrypted_extra=json.dumps( - { - "service_account_info": { - "type": "service_account", - "project_id": "black-sanctum-314419", - "private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", - "private_key": "SECRET", - "client_email": "google-spreadsheets-demo-se...@black-sanctum-314419.iam.gserviceaccount.com", - "client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT", - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": "https://oauth2.googleapis.com/token", - "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", - "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com", - }, - } - ), - ) - db.session.add(database) - db.session.commit() - - # mock the lookup so that we don't need to include the driver - mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") - mocker.patch("superset.utils.log.DBEventLogger.log") - mocker.patch( - "superset.commands.database.ssh_tunnel.delete.is_feature_enabled", - return_value=True, - ) + # Create our Database + database = Database( + database_name="my_database", + sqlalchemy_uri="gsheets://", + encrypted_extra=json.dumps( + { + "service_account_info": { + "type": "service_account", + "project_id": "black-sanctum-314419", + "private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", + "private_key": "SECRET", + "client_email": "google-spreadsheets-demo-se...@black-sanctum-314419.iam.gserviceaccount.com", + "client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com", + }, + } + ), + ) + db.session.add(database) + db.session.commit() - # Create our SSHTunnel - tunnel = SSHTunnel( - database_id=1, - database=database, - ) + # mock the lookup so that we don't need to include the driver + mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") + mocker.patch("superset.utils.log.DBEventLogger.log") + mocker.patch( + "superset.commands.database.ssh_tunnel.delete.is_feature_enabled", + return_value=True, + ) - db.session.add(tunnel) - db.session.commit() + # Create our SSHTunnel + tunnel = SSHTunnel( + database_id=1, + database=database, + ) - # Delete the recently created SSHTunnel - response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/") - assert response_delete_tunnel.json["message"] == "Not found" + db.session.add(tunnel) + db.session.commit() - # Get our recently created SSHTunnel - response_tunnel = DatabaseDAO.get_ssh_tunnel(1) - assert response_tunnel - assert isinstance(response_tunnel, SSHTunnel) - assert 1 == response_tunnel.database_id + # Delete the recently created SSHTunnel + response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/") + assert response_delete_tunnel.json["message"] == "Not found" - response_tunnel = DatabaseDAO.get_ssh_tunnel(2) - assert response_tunnel is None + # Get our recently created SSHTunnel + response_tunnel = DatabaseDAO.get_ssh_tunnel(1) + assert response_tunnel + assert isinstance(response_tunnel, SSHTunnel) + assert 1 == response_tunnel.database_id + + response_tunnel = DatabaseDAO.get_ssh_tunnel(2) + assert response_tunnel is None def test_apply_dynamic_database_filter( @@ -568,88 +566,87 @@ def test_apply_dynamic_database_filter( defining a filter function and patching the config to get the filtered results. """ - with app.app_context(): - from superset.daos.database import DatabaseDAO - from superset.databases.api import DatabaseRestApi - from superset.databases.ssh_tunnel.models import SSHTunnel - from superset.models.core import Database + from superset.daos.database import DatabaseDAO + from superset.databases.api import DatabaseRestApi + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + DatabaseRestApi.datamodel.session = session - DatabaseRestApi.datamodel.session = session + # create table for databases + Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member - # create table for databases - Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member + # Create our First Database + database = Database( + database_name="first-database", + sqlalchemy_uri="gsheets://", + encrypted_extra=json.dumps( + { + "metadata_params": {}, + "engine_params": {}, + "metadata_cache_timeout": {}, + "schemas_allowed_for_file_upload": [], + } + ), + ) + db.session.add(database) + db.session.commit() - # Create our First Database - database = Database( - database_name="first-database", - sqlalchemy_uri="gsheets://", - encrypted_extra=json.dumps( - { - "metadata_params": {}, - "engine_params": {}, - "metadata_cache_timeout": {}, - "schemas_allowed_for_file_upload": [], - } - ), - ) - db.session.add(database) - db.session.commit() - - # Create our Second Database - database = Database( - database_name="second-database", - sqlalchemy_uri="gsheets://", - encrypted_extra=json.dumps( - { - "metadata_params": {}, - "engine_params": {}, - "metadata_cache_timeout": {}, - "schemas_allowed_for_file_upload": [], - } - ), - ) - db.session.add(database) - db.session.commit() - - # mock the lookup so that we don't need to include the driver - mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") - mocker.patch("superset.utils.log.DBEventLogger.log") - mocker.patch( - "superset.commands.database.ssh_tunnel.delete.is_feature_enabled", - return_value=False, - ) + # Create our Second Database + database = Database( + database_name="second-database", + sqlalchemy_uri="gsheets://", + encrypted_extra=json.dumps( + { + "metadata_params": {}, + "engine_params": {}, + "metadata_cache_timeout": {}, + "schemas_allowed_for_file_upload": [], + } + ), + ) + db.session.add(database) + db.session.commit() - def _base_filter(query): - from superset.models.core import Database + # mock the lookup so that we don't need to include the driver + mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") + mocker.patch("superset.utils.log.DBEventLogger.log") + mocker.patch( + "superset.commands.database.ssh_tunnel.delete.is_feature_enabled", + return_value=False, + ) + + def _base_filter(query): + from superset.models.core import Database - return query.filter(Database.database_name.startswith("second")) + return query.filter(Database.database_name.startswith("second")) - # Create a mock object - base_filter_mock = Mock(side_effect=_base_filter) + # Create a mock object + base_filter_mock = Mock(side_effect=_base_filter) - # Get our recently created Databases - response_databases = DatabaseDAO.find_all() - assert response_databases - expected_db_names = ["first-database", "second-database"] - actual_db_names = [db.database_name for db in response_databases] - assert actual_db_names == expected_db_names + # Get our recently created Databases + response_databases = DatabaseDAO.find_all() + assert response_databases + expected_db_names = ["first-database", "second-database"] + actual_db_names = [db.database_name for db in response_databases] + assert actual_db_names == expected_db_names - # Ensure that the filter has not been called because it's not in our config - assert base_filter_mock.call_count == 0 + # Ensure that the filter has not been called because it's not in our config + assert base_filter_mock.call_count == 0 - original_config = current_app.config.copy() - original_config["EXTRA_DYNAMIC_QUERY_FILTERS"] = {"databases": base_filter_mock} + original_config = current_app.config.copy() + original_config["EXTRA_DYNAMIC_QUERY_FILTERS"] = {"databases": base_filter_mock} - mocker.patch("superset.views.filters.current_app.config", new=original_config) - # Get filtered list - response_databases = DatabaseDAO.find_all() - assert response_databases - expected_db_names = ["second-database"] - actual_db_names = [db.database_name for db in response_databases] - assert actual_db_names == expected_db_names + mocker.patch("superset.views.filters.current_app.config", new=original_config) + # Get filtered list + response_databases = DatabaseDAO.find_all() + assert response_databases + expected_db_names = ["second-database"] + actual_db_names = [db.database_name for db in response_databases] + assert actual_db_names == expected_db_names - # Ensure that the filter has been called once - assert base_filter_mock.call_count == 1 + # Ensure that the filter has been called once + assert base_filter_mock.call_count == 1 def test_oauth2_happy_path(