This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch postgres-catalog in repository https://gitbox.apache.org/repos/asf/superset.git
commit 0dd60df7f7df4fecfa169ba7114cf1cb5c3b01d9 Author: Beto Dealmeida <[email protected]> AuthorDate: Sat Apr 27 13:17:44 2024 -0400 feat(SIP-95): permissions for catalogs --- superset/commands/database/create.py | 33 +- superset/commands/database/tables.py | 31 +- superset/commands/database/update.py | 178 +++++++-- superset/config.py | 6 +- superset/connectors/sqla/models.py | 25 +- superset/constants.py | 1 + superset/databases/api.py | 81 +++- superset/databases/filters.py | 13 +- superset/databases/schemas.py | 15 + superset/db_engine_specs/base.py | 20 +- superset/db_engine_specs/bigquery.py | 4 +- superset/db_engine_specs/clickhouse.py | 4 +- superset/db_engine_specs/impala.py | 7 +- superset/db_engine_specs/postgres.py | 34 +- superset/db_engine_specs/presto.py | 6 +- superset/db_engine_specs/snowflake.py | 6 +- superset/extensions/metadb.py | 5 - ...-52_58d051681a3b_add_catalog_perm_to_tables.py} | 48 +-- superset/models/core.py | 69 +++- superset/security/manager.py | 422 +++++++++++++++++---- superset/utils/cache.py | 10 +- superset/utils/core.py | 15 +- superset/utils/filters.py | 2 + superset/views/database/mixins.py | 28 +- 24 files changed, 850 insertions(+), 213 deletions(-) diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index 4903938eb9..13d44b04e7 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -97,12 +97,35 @@ class CreateDatabaseCommand(BaseCommand): db.session.commit() - # adding a new database we always want to force refresh schema list - schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel) - for schema in schemas: - security_manager.add_permission_view_menu( - "schema_access", security_manager.get_schema_perm(database, schema) + # add catalog/schema permissions + if database.db_engine_spec.supports_catalog: + catalogs = database.get_all_catalog_names( + cache=False, + ssh_tunnel=ssh_tunnel, ) + for catalog in catalogs: + 
security_manager.add_permission_view_menu( + "catalog_access", + security_manager.get_catalog_perm(database, catalog), + ) + else: + # add a dummy catalog for DBs that don't support them + catalogs = [None] + + for catalog in catalogs: + for schema in database.get_all_schema_names( + catalog=catalog, + cache=False, + ssh_tunnel=ssh_tunnel, + ): + security_manager.add_permission_view_menu( + "schema_access", + security_manager.get_schema_perm( + database.database_name, + catalog, + schema, + ), + ) except ( SSHTunnelInvalidError, diff --git a/superset/commands/database/tables.py b/superset/commands/database/tables.py index 055c0be9ae..b16fcfc504 100644 --- a/superset/commands/database/tables.py +++ b/superset/commands/database/tables.py @@ -14,6 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +from __future__ import annotations + import logging from typing import Any, cast @@ -29,7 +32,6 @@ from superset.daos.database import DatabaseDAO from superset.exceptions import SupersetException from superset.extensions import db, security_manager from superset.models.core import Database -from superset.utils.core import DatasourceName logger = logging.getLogger(__name__) @@ -37,8 +39,15 @@ logger = logging.getLogger(__name__) class TablesDatabaseCommand(BaseCommand): _model: Database - def __init__(self, db_id: int, schema_name: str, force: bool): + def __init__( + self, + db_id: int, + catalog_name: str | None, + schema_name: str, + force: bool, + ): self._db_id = db_id + self._catalog_name = catalog_name self._schema_name = schema_name self._force = force @@ -47,11 +56,11 @@ class TablesDatabaseCommand(BaseCommand): try: tables = security_manager.get_datasources_accessible_by_user( database=self._model, + catalog=self._catalog_name, schema=self._schema_name, datasource_names=sorted( - DatasourceName(*datasource_name) - for datasource_name in 
self._model.get_all_table_names_in_schema( - catalog=None, + self._model.get_all_table_names_in_schema( + catalog=self._catalog_name, schema=self._schema_name, force=self._force, cache=self._model.table_cache_enabled, @@ -62,11 +71,11 @@ class TablesDatabaseCommand(BaseCommand): views = security_manager.get_datasources_accessible_by_user( database=self._model, + catalog=self._catalog_name, schema=self._schema_name, datasource_names=sorted( - DatasourceName(*datasource_name) - for datasource_name in self._model.get_all_view_names_in_schema( - catalog=None, + self._model.get_all_view_names_in_schema( + catalog=self._catalog_name, schema=self._schema_name, force=self._force, cache=self._model.table_cache_enabled, @@ -81,11 +90,15 @@ class TablesDatabaseCommand(BaseCommand): db.session.query(SqlaTable) .filter( SqlaTable.database_id == self._model.id, + SqlaTable.catalog == self._catalog_name, SqlaTable.schema == self._schema_name, ) .options( load_only( - SqlaTable.schema, SqlaTable.table_name, SqlaTable.extra + SqlaTable.catalog, + SqlaTable.schema, + SqlaTable.table_name, + SqlaTable.extra, ), lazyload(SqlaTable.columns), lazyload(SqlaTable.metrics), diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index 5e0968954c..5085b376e3 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -18,7 +18,7 @@ from __future__ import annotations import logging -from typing import Any, Optional +from typing import Any from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -50,12 +50,12 @@ logger = logging.getLogger(__name__) class UpdateDatabaseCommand(BaseCommand): - _model: Optional[Database] + _model: Database | None def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model_id = model_id - self._model: Optional[Database] = None + self._model: Database | None = None def run(self) -> Model: self._model = 
DatabaseDAO.find_by_id(self._model_id) @@ -85,7 +85,7 @@ class UpdateDatabaseCommand(BaseCommand): ) database.set_sqlalchemy_uri(database.sqlalchemy_uri) ssh_tunnel = self._handle_ssh_tunnel(database) - self._refresh_schemas(database, original_database_name, ssh_tunnel) + self._refresh_catalogs(database, original_database_name, ssh_tunnel) except SSHTunnelError as ex: # allow exception to bubble for debugbing information raise ex @@ -121,62 +121,192 @@ class UpdateDatabaseCommand(BaseCommand): ssh_tunnel_properties, ).run() - def _refresh_schemas( + def _get_catalog_names( self, database: Database, - original_database_name: str, - ssh_tunnel: Optional[SSHTunnel], - ) -> None: + ssh_tunnel: SSHTunnel | None, + ) -> set[str]: """ - Add permissions for any new schemas. + Helper method to load catalogs. + + This method captures a generic exception, since errors could potentially come + from any of the 50+ database drivers we support. """ try: - schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel) + return database.get_all_catalog_names( + force=True, + ssh_tunnel=ssh_tunnel, + ) except Exception as ex: db.session.rollback() raise DatabaseConnectionFailedError() from ex + def _get_schema_names( + self, + database: Database, + catalog: str | None, + ssh_tunnel: SSHTunnel | None, + ) -> set[str]: + """ + Helper method to load schemas. + + This method captures a generic exception, since errors could potentially come + from any of the 50+ database drivers we support. + """ + try: + return database.get_all_schema_names( + force=True, + catalog=catalog, + ssh_tunnel=ssh_tunnel, + ) + except Exception as ex: + db.session.rollback() + raise DatabaseConnectionFailedError() from ex + + def _refresh_catalogs( + self, + database: Database, + original_database_name: str, + ssh_tunnel: SSHTunnel | None, + ) -> None: + """ + Add permissions for any new catalogs and schemas. 
+ """ + catalogs = ( + self._get_catalog_names(database, ssh_tunnel) + if database.db_engine_spec.supports_catalog + else [None] + ) + + for catalog in catalogs: + schemas = self._get_schema_names(database, catalog, ssh_tunnel) + + perm = security_manager.get_catalog_perm( + original_database_name, + catalog, + ) + existing_pvm = security_manager.find_permission_view_menu( + "catalog_access", + perm, + ) + if not existing_pvm: + # new catalog + security_manager.add_permission_view_menu( + "catalog_access", + security_manager.get_catalog_perm(database.database_name, catalog), + ) + for schema in schemas: + security_manager.add_permission_view_menu( + "schema_access", + security_manager.get_schema_perm( + database.database_name, + catalog, + schema, + ), + ) + continue + + # add possible new schemas in catalog + self._refresh_schemas( + database, + original_database_name, + catalog, + schemas, + ) + + if original_database_name != database.database_name: + self._rename_existing_permissions( + database, + original_database_name, + catalog, + schemas, + ) + + db.session.commit() + + def _refresh_schemas( + self, + database: Database, + original_database_name: str, + catalog: str | None, + schemas: set[str], + ) -> None: for schema in schemas: - original_vm = security_manager.get_schema_perm( + perm = security_manager.get_schema_perm( original_database_name, + catalog, schema, ) existing_pvm = security_manager.find_permission_view_menu( "schema_access", - original_vm, + perm, ) if not existing_pvm: # new schema security_manager.add_permission_view_menu( "schema_access", - security_manager.get_schema_perm(database.database_name, schema), + security_manager.get_schema_perm( + database.database_name, + catalog, + schema, + ), ) - continue - if original_database_name == database.database_name: - continue + def _rename_existing_permissions( + self, + database: Database, + original_database_name: str, + catalog: str | None, + schemas: set[str], + ) -> None: + new_name = 
security_manager.get_catalog_perm( + database.database_name, + catalog, + ) - # rename existing schema permission - existing_pvm.view_menu.name = security_manager.get_schema_perm( + # rename existing catalog permission + perm = security_manager.get_catalog_perm( + original_database_name, + catalog, + ) + existing_pvm = security_manager.find_permission_view_menu( + "catalog_access", + perm, + ) + existing_pvm.view_menu.name = new_name + + for schema in schemas: + new_name = security_manager.get_schema_perm( database.database_name, + catalog, schema, ) + # rename existing schema permission + perm = security_manager.get_schema_perm( + original_database_name, + catalog, + schema, + ) + existing_pvm = security_manager.find_permission_view_menu( + "schema_access", + perm, + ) + existing_pvm.view_menu.name = new_name + # rename permissions on datasets and charts for dataset in DatabaseDAO.get_datasets( database.id, - catalog=None, + catalog=catalog, schema=schema, ): - dataset.schema_perm = existing_pvm.view_menu.name + dataset.schema_perm = new_name for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]: - chart.schema_perm = existing_pvm.view_menu.name - - db.session.commit() + chart.schema_perm = new_name def validate(self) -> None: exceptions: list[ValidationError] = [] - database_name: Optional[str] = self._properties.get("database_name") + database_name: str | None = self._properties.get("database_name") if database_name: # Check database_name uniqueness if not DatabaseDAO.validate_update_uniqueness( diff --git a/superset/config.py b/superset/config.py index 9388edbe84..7851938b72 100644 --- a/superset/config.py +++ b/superset/config.py @@ -564,9 +564,9 @@ IS_FEATURE_ENABLED_FUNC: Callable[[str, bool | None], bool] | None = None # # Takes as a parameter the common bootstrap payload before transformations. # Returns a dict containing data that should be added or overridden to the payload. 
-COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[[dict[str, Any]], dict[str, Any]] = ( # noqa: E731 - lambda data: {} -) # default: empty dict +COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[ # noqa: E731 + [dict[str, Any]], dict[str, Any] +] = lambda data: {} # EXTRA_CATEGORICAL_COLOR_SCHEMES is used for adding custom categorical color schemes # example code for "My custom warm to hot" color scheme diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 719d5af588..330e4dfe30 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -211,6 +211,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable= params = Column(String(1000)) perm = Column(String(1000)) schema_perm = Column(String(1000)) + catalog_perm = Column(String(1000)) is_managed_externally = Column(Boolean, nullable=False, default=False) external_url = Column(Text, nullable=True) @@ -1261,9 +1262,20 @@ class SqlaTable( anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>' return Markup(anchor) + def get_catalog_perm(self) -> str | None: + """Returns catalog permission if present, database one otherwise.""" + return security_manager.get_catalog_perm( + self.database.database_name, + self.catalog, + ) + def get_schema_perm(self) -> str | None: """Returns schema permission if present, database one otherwise.""" - return security_manager.get_schema_perm(self.database, self.schema or None) + return security_manager.get_schema_perm( + self.database.database_name, + self.catalog, + self.schema or None, + ) def get_perm(self) -> str: """ @@ -1282,7 +1294,10 @@ class SqlaTable( @property def full_name(self) -> str: return utils.get_datasource_full_name( - self.database, self.table_name, schema=self.schema + self.database, + self.table_name, + catalog=self.catalog, + schema=self.schema, ) @property @@ -1870,6 +1885,7 @@ class SqlaTable( cls, database: Database, datasource_name: str, + catalog: str | None = None, 
schema: str | None = None, ) -> list[SqlaTable]: query = ( @@ -1877,6 +1893,8 @@ class SqlaTable( .filter_by(database_id=database.id) .filter_by(table_name=datasource_name) ) + if catalog: + query = query.filter_by(catalog=catalog) if schema: query = query.filter_by(schema=schema) return query.all() @@ -1886,9 +1904,9 @@ class SqlaTable( cls, database: Database, permissions: set[str], + catalog_perms: set[str], schema_perms: set[str], ) -> list[SqlaTable]: - # TODO(hughhhh): add unit test return ( db.session.query(cls) .filter_by(database_id=database.id) @@ -1896,6 +1914,7 @@ class SqlaTable( or_( SqlaTable.perm.in_(permissions), SqlaTable.schema_perm.in_(schema_perms), + SqlaTable.catalog_perm.in_(catalog_perms), ) ) .all() diff --git a/superset/constants.py b/superset/constants.py index 28902ded6c..8e1563c9d3 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -132,6 +132,7 @@ MODEL_API_RW_METHOD_PERMISSION_MAP = { "related_objects": "read", "tables": "read", "schemas": "read", + "catalogs": "read", "select_star": "read", "table_metadata": "read", "table_metadata_deprecated": "read", diff --git a/superset/databases/api.py b/superset/databases/api.py index a77019123b..955babdb82 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -72,7 +72,9 @@ from superset.daos.database import DatabaseDAO, DatabaseUserOAuth2TokensDAO from superset.databases.decorators import check_table_access from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter from superset.databases.schemas import ( + CatalogsResponseSchema, CSVUploadPostSchema, + database_catalogs_query_schema, database_schemas_query_schema, database_tables_query_schema, DatabaseConnectionSchema, @@ -140,6 +142,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "table_extra_metadata", "table_extra_metadata_deprecated", "select_star", + "catalogs", "schemas", "test_connection", "related_objects", @@ -256,6 +259,7 @@ class 
DatabaseRestApi(BaseSupersetModelRestApi): edit_model_schema = DatabasePutSchema() apispec_parameter_schemas = { + "database_catalogs_query_schema": database_catalogs_query_schema, "database_schemas_query_schema": database_schemas_query_schema, "database_tables_query_schema": database_tables_query_schema, "get_export_ids_schema": get_export_ids_schema, @@ -263,6 +267,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): openapi_spec_tag = "Database" openapi_spec_component_schemas = ( + CatalogsResponseSchema, CSVUploadPostSchema, DatabaseConnectionSchema, DatabaseFunctionNamesResponse, @@ -589,6 +594,69 @@ class DatabaseRestApi(BaseSupersetModelRestApi): ) return self.response_422(message=str(ex)) + @expose("/<int:pk>/catalogs/") + @protect() + @safe + @rison(database_catalogs_query_schema) + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" f".catalogs", + log_to_statsd=False, + ) + def catalogs(self, pk: int, **kwargs: Any) -> FlaskResponse: + """Get all catalogs from a database. 
+ --- + get: + summary: Get all catalogs from a database + parameters: + - in: path + schema: + type: integer + name: pk + description: The database id + - in: query + name: q + content: + application/json: + schema: + $ref: '#/components/schemas/database_catalogs_query_schema' + responses: + 200: + description: A List of all catalogs from the database + content: + application/json: + schema: + $ref: "#/components/schemas/CatalogsResponseSchema" + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + database = self.datamodel.get(pk, self._base_filters) + if not database: + return self.response_404() + try: + catalogs = database.get_all_catalog_names( + cache=database.catalog_cache_enabled, + cache_timeout=database.catalog_cache_timeout or None, + force=kwargs["rison"].get("force", False), + ) + catalogs = security_manager.get_catalogs_accessible_by_user( + database, + catalogs, + ) + return self.response(200, result=list(catalogs)) + except OperationalError: + return self.response( + 500, message="There was an error connecting to the database" + ) + except SupersetException as ex: + return self.response(ex.status, message=ex.message) + @expose("/<int:pk>/schemas/") @protect() @safe @@ -635,13 +703,19 @@ class DatabaseRestApi(BaseSupersetModelRestApi): if not database: return self.response_404() try: + catalog = kwargs["rison"].get("catalog") schemas = database.get_all_schema_names( + catalog=catalog, cache=database.schema_cache_enabled, cache_timeout=database.schema_cache_timeout or None, force=kwargs["rison"].get("force", False), ) - schemas = security_manager.get_schemas_accessible_by_user(database, schemas) - return self.response(200, result=schemas) + schemas = security_manager.get_schemas_accessible_by_user( + database, + catalog, + schemas, + ) + return self.response(200, result=list(schemas)) except OperationalError: return 
self.response( 500, message="There was an error connecting to the database" @@ -703,10 +777,11 @@ class DatabaseRestApi(BaseSupersetModelRestApi): $ref: '#/components/responses/500' """ force = kwargs["rison"].get("force", False) + catalog_name = kwargs["rison"].get("catalog_name") schema_name = kwargs["rison"].get("schema_name", "") try: - command = TablesDatabaseCommand(pk, schema_name, force) + command = TablesDatabaseCommand(pk, catalog_name, schema_name, force) payload = command.run() return self.response(200, **payload) except DatabaseNotFoundError: diff --git a/superset/databases/filters.py b/superset/databases/filters.py index 33748da4b6..625ceb7bd5 100644 --- a/superset/databases/filters.py +++ b/superset/databases/filters.py @@ -28,11 +28,9 @@ from superset.models.core import Database from superset.views.base import BaseFilter -def can_access_databases( - view_menu_name: str, -) -> set[str]: +def can_access_databases(view_menu_name: str) -> set[str]: return { - security_manager.unpack_database_and_schema(vm).database + security_manager.unpack_database_catalog_schema(vm).database for vm in security_manager.user_view_menu_names(view_menu_name) } @@ -58,6 +56,7 @@ class DatabaseFilter(BaseFilter): # pylint: disable=too-few-public-methods return query database_perms = security_manager.user_view_menu_names("database_access") schema_access_databases = can_access_databases("schema_access") + catalog_access_databases = can_access_databases("catalog_access") datasource_access_databases = can_access_databases("datasource_access") @@ -65,7 +64,11 @@ class DatabaseFilter(BaseFilter): # pylint: disable=too-few-public-methods or_( self.model.perm.in_(database_perms), self.model.database_name.in_( - [*schema_access_databases, *datasource_access_databases] + [ + *catalog_access_databases, + *schema_access_databases, + *datasource_access_databases, + ] ), ) ) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 1bc0af7472..486a258a9a 100644 --- 
a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -56,6 +56,14 @@ from superset.security.analytics_db_safety import check_sqlalchemy_uri from superset.utils.core import markdown, parse_ssl_cert database_schemas_query_schema = { + "type": "object", + "properties": { + "force": {"type": "boolean"}, + "catalog": {"type": "string"}, + }, +} + +database_catalogs_query_schema = { "type": "object", "properties": {"force": {"type": "boolean"}}, } @@ -65,6 +73,7 @@ database_tables_query_schema = { "properties": { "force": {"type": "boolean"}, "schema_name": {"type": "string"}, + "catalog_name": {"type": "string"}, }, "required": ["schema_name"], } @@ -712,6 +721,12 @@ class SchemasResponseSchema(Schema): ) +class CatalogsResponseSchema(Schema): + result = fields.List( + fields.String(metadata={"description": "A database catalog name"}) + ) + + class DatabaseTablesResponse(Schema): extra = fields.Dict( metadata={"description": "Extra data used to specify column metadata"} diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 3cc1315129..0a4955a19f 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -638,10 +638,20 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return driver in cls.drivers + @classmethod + def get_default_catalog( + cls, + database: Database, # pylint: disable=unused-argument + ) -> str | None: + """ + Return the default catalog for a given database. + """ + return None + @classmethod def get_default_schema(cls, database: Database, catalog: str | None) -> str | None: """ - Return the default schema in a given database. + Return the default schema for a catalog in a given database. 
""" with database.get_inspector(catalog=catalog) as inspector: return inspector.default_schema_name @@ -1412,24 +1422,24 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls, database: Database, inspector: Inspector, - ) -> list[str]: + ) -> set[str]: """ Get all catalogs from database. This needs to be implemented per database, since SQLAlchemy doesn't offer an abstraction. """ - return [] + return set() @classmethod - def get_schema_names(cls, inspector: Inspector) -> list[str]: + def get_schema_names(cls, inspector: Inspector) -> set[str]: """ Get all schemas from database :param inspector: SqlAlchemy inspector :return: All schemas in the database """ - return sorted(inspector.get_schema_names()) + return set(inspector.get_schema_names()) @classmethod def get_table_names( # pylint: disable=unused-argument diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 8a2612f5b0..8e508b0e0f 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -464,7 +464,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met cls, database: Database, inspector: Inspector, - ) -> list[str]: + ) -> set[str]: """ Get all catalogs. 
@@ -475,7 +475,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met client = cls._get_client(engine) projects = client.list_projects() - return sorted(project.project_id for project in projects) + return {project.project_id for project in projects} @classmethod def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: diff --git a/superset/db_engine_specs/clickhouse.py b/superset/db_engine_specs/clickhouse.py index 4346f77d64..d2dc7307d9 100644 --- a/superset/db_engine_specs/clickhouse.py +++ b/superset/db_engine_specs/clickhouse.py @@ -278,7 +278,7 @@ class ClickHouseConnectEngineSpec(BasicParametersMixin, ClickHouseEngineSpec): @classmethod def get_function_names(cls, database: Database) -> list[str]: - # pylint: disable=import-outside-toplevel,import-error + # pylint: disable=import-outside-toplevel from clickhouse_connect.driver.exceptions import ClickHouseError if cls._function_names: @@ -340,7 +340,7 @@ class ClickHouseConnectEngineSpec(BasicParametersMixin, ClickHouseEngineSpec): def validate_parameters( cls, properties: BasicPropertiesType ) -> list[SupersetError]: - # pylint: disable=import-outside-toplevel,import-error + # pylint: disable=import-outside-toplevel from clickhouse_connect.driver import default_port parameters = properties.get("parameters", {}) diff --git a/superset/db_engine_specs/impala.py b/superset/db_engine_specs/impala.py index 1d3ec4e9e5..d7d1862aaf 100644 --- a/superset/db_engine_specs/impala.py +++ b/superset/db_engine_specs/impala.py @@ -74,13 +74,12 @@ class ImpalaEngineSpec(BaseEngineSpec): return None @classmethod - def get_schema_names(cls, inspector: Inspector) -> list[str]: - schemas = [ + def get_schema_names(cls, inspector: Inspector) -> set[str]: + return { row[0] for row in inspector.engine.execute("SHOW SCHEMAS") if not row[0].startswith("_") - ] - return schemas + } @classmethod def has_implicit_cancel(cls) -> bool: diff --git a/superset/db_engine_specs/postgres.py 
b/superset/db_engine_specs/postgres.py index ce87aa1f9b..bba2157e0a 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -101,8 +101,6 @@ class PostgresBaseEngineSpec(BaseEngineSpec): engine = "" engine_name = "PostgreSQL" - supports_catalog = True - _time_grain_expressions = { None: "{col}", TimeGrain.SECOND: "DATE_TRUNC('second', {col})", @@ -199,7 +197,10 @@ class PostgresBaseEngineSpec(BaseEngineSpec): class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): engine = "postgresql" engine_aliases = {"postgres"} + supports_dynamic_schema = True + supports_catalog = True + supports_dynamic_catalog = True default_driver = "psycopg2" sqlalchemy_uri_placeholder = ( @@ -296,6 +297,29 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): return super().get_default_schema_for_query(database, query) + @classmethod + def adjust_engine_params( + cls, + uri: URL, + connect_args: dict[str, Any], + catalog: str | None = None, + schema: str | None = None, + ) -> tuple[URL, dict[str, Any]]: + """ + Set the catalog (database). + """ + if catalog: + uri = uri.set(database=catalog) + + return uri, connect_args + + @classmethod + def get_default_catalog(cls, database: Database) -> str | None: + """ + Return the default catalog for a given database. + """ + return database.url_object.database + @classmethod def get_prequeries( cls, @@ -346,13 +370,13 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): cls, database: Database, inspector: Inspector, - ) -> list[str]: + ) -> set[str]: """ Return all catalogs. In Postgres, a catalog is called a "database". 
""" - return sorted( + return { catalog for (catalog,) in inspector.bind.execute( """ @@ -360,7 +384,7 @@ SELECT datname FROM pg_database WHERE datistemplate = false; """ ) - ) + } @classmethod def get_table_names( diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 34c47eb522..5a2c3afa5a 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -648,6 +648,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): engine_name = "Presto" allows_alias_to_source_column = False + supports_catalog = True + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { COLUMN_DOES_NOT_EXIST_REGEX: ( __( @@ -815,11 +817,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): cls, database: Database, inspector: Inspector, - ) -> list[str]: + ) -> set[str]: """ Get all catalogs. """ - return [catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")] + return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")} @classmethod def _create_column_info( diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 83d382cda1..0f03de2188 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -174,18 +174,18 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): cls, database: "Database", inspector: Inspector, - ) -> list[str]: + ) -> set[str]: """ Return all catalogs. In Snowflake, a catalog is called a "database". 
""" - return sorted( + return { catalog for (catalog,) in inspector.bind.execute( "SELECT DATABASE_NAME from information_schema.databases" ) - ) + } @classmethod def epoch_to_dttm(cls) -> str: diff --git a/superset/extensions/metadb.py b/superset/extensions/metadb.py index 2d8444cc99..fd697aea82 100644 --- a/superset/extensions/metadb.py +++ b/superset/extensions/metadb.py @@ -270,11 +270,6 @@ class SupersetShillelaghAdapter(Adapter): self.schema = parts.pop(-1) if parts else None self.catalog = parts.pop(-1) if parts else None - if self.catalog: - # TODO (betodealmeida): when SIP-95 is implemented we should check to see if - # the database has multi-catalog enabled, and if so, give access. - raise NotImplementedError("Catalogs are not currently supported") - # If the table has a single integer primary key we use that as the row ID in order # to perform updates and deletes. Otherwise we can only do inserts and selects. self._rowid: str | None = None diff --git a/superset/utils/filters.py b/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py similarity index 51% copy from superset/utils/filters.py copy to superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py index 88154a40b3..ca04eac2d1 100644 --- a/superset/utils/filters.py +++ b/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py @@ -14,28 +14,28 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any - -from flask_appbuilder import Model -from sqlalchemy import or_ -from sqlalchemy.sql.elements import BooleanClauseList - - -def get_dataset_access_filters( - base_model: type[Model], - *args: Any, -) -> BooleanClauseList: - # pylint: disable=import-outside-toplevel - from superset import security_manager - from superset.connectors.sqla.models import Database - - database_ids = security_manager.get_accessible_databases() - perms = security_manager.user_view_menu_names("datasource_access") - schema_perms = security_manager.user_view_menu_names("schema_access") - - return or_( - Database.id.in_(database_ids), - base_model.perm.in_(perms), - base_model.schema_perm.in_(schema_perms), - *args, +"""Add catalog_perm to tables + +Revision ID: 58d051681a3b +Revises: 5f57af97bc3f +Create Date: 2024-05-01 10:52:31.458433 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "58d051681a3b" +down_revision = "5f57af97bc3f" + + +def upgrade(): + op.add_column( + "tables", + sa.Column("catalog_perm", sa.String(length=1000), nullable=True), ) + + +def downgrade(): + op.drop_column("tables", "catalog_perm") diff --git a/superset/models/core.py b/superset/models/core.py index 9a4a1de403..50428f4de9 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -1,4 +1,4 @@ -# Licensed to the Apache Software Foundation (ASF) under one +# get_all_table_names_in_schema Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. 
The ASF licenses this file @@ -78,7 +78,7 @@ from superset.sql_parse import Table from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType from superset.utils import cache as cache_util, core as utils from superset.utils.backports import StrEnum -from superset.utils.core import get_username +from superset.utils.core import DatasourceName, get_username from superset.utils.oauth2 import get_oauth2_access_token config = app.config @@ -313,6 +313,14 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable def metadata_cache_timeout(self) -> dict[str, Any]: return self.get_extra().get("metadata_cache_timeout", {}) + @property + def catalog_cache_enabled(self) -> bool: + return "catalog_cache_timeout" in self.metadata_cache_timeout + + @property + def catalog_cache_timeout(self) -> int | None: + return self.metadata_cache_timeout.get("catalog_cache_timeout") + @property def schema_cache_enabled(self) -> bool: return "schema_cache_timeout" in self.metadata_cache_timeout @@ -549,6 +557,18 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable yield conn + def get_default_catalog(self) -> str | None: + """ + Return the default configured catalog for the database. + """ + return self.db_engine_spec.get_default_catalog(self) + + def get_default_schema(self, catalog: str | None) -> str | None: + """ + Return the default schema for the database. + """ + return self.db_engine_spec.get_default_schema(self, catalog) + def get_default_schema_for_query(self, query: Query) -> str | None: """ Return the default schema for a given query. @@ -713,12 +733,13 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable cache: bool = False, cache_timeout: int | None = None, force: bool = False, - ) -> set[tuple[str, str]]: + ) -> set[DatasourceName]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in cache_util.memoized_func decorator. 
+ :param catalog: optional catalog name :param schema: schema name :param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache @@ -728,7 +749,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable try: with self.get_inspector(catalog=catalog, schema=schema) as inspector: return { - (table, schema) + DatasourceName(table, schema, catalog) for table in self.db_engine_spec.get_table_names( database=self, inspector=inspector, @@ -792,22 +813,17 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable key="db:{self.id}:schema_list", cache=cache_manager.cache, ) - def get_all_schema_names( # pylint: disable=unused-argument + def get_all_schema_names( self, + *, catalog: str | None = None, - cache: bool = False, - cache_timeout: int | None = None, - force: bool = False, ssh_tunnel: SSHTunnel | None = None, - ) -> list[str]: - """Parameters need to be passed as keyword arguments. - - For unused parameters, they are referenced in - cache_util.memoized_func decorator. 
+ ) -> set[str]: + """ + Return the schemas in a given database - :param cache: whether cache is enabled for the function - :param cache_timeout: timeout in seconds for the cache - :param force: whether to force refresh the cache + :param catalog: override default catalog + :param ssh_tunnel: SSH tunnel information needed to establish a connection :return: schema list """ try: @@ -819,6 +835,27 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex + @cache_util.memoized_func( + key="db:{self.id}:catalog_list", + cache=cache_manager.cache, + ) + def get_all_catalog_names( + self, + *, + ssh_tunnel: SSHTunnel | None = None, + ) -> set[str]: + """ + Return the catalogs in a given database + + :param ssh_tunnel: SSH tunnel information needed to establish a connection + :return: catalog list + """ + try: + with self.get_inspector(ssh_tunnel=ssh_tunnel) as inspector: + return self.db_engine_spec.get_catalog_names(self, inspector) + except Exception as ex: + raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex + @property def db_engine_spec(self) -> builtins.type[db_engine_specs.BaseEngineSpec]: url = make_url_safe(self.sqlalchemy_uri_decrypted) diff --git a/superset/security/manager.py b/superset/security/manager.py index a84c0cec0d..5b2ad50e3e 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -22,7 +22,7 @@ import logging import re import time from collections import defaultdict -from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING from flask import current_app, Flask, g, Request from flask_appbuilder import Model @@ -67,7 +67,7 @@ from superset.security.guest_token import ( GuestTokenUser, GuestUser, ) -from superset.sql_parse import extract_tables_from_jinja_sql +from superset.sql_parse import 
extract_tables_from_jinja_sql, Table from superset.superset_typing import Metric from superset.utils.core import ( DatasourceName, @@ -89,7 +89,6 @@ if TYPE_CHECKING: from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.models.sql_lab import Query - from superset.sql_parse import Table from superset.viz import BaseViz logger = logging.getLogger(__name__) @@ -97,8 +96,9 @@ logger = logging.getLogger(__name__) DATABASE_PERM_REGEX = re.compile(r"^\[.+\]\.\(id\:(?P<id>\d+)\)$") -class DatabaseAndSchema(NamedTuple): +class DatabaseCatalogSchema(NamedTuple): database: str + catalog: Optional[str] schema: str @@ -346,17 +346,55 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods return self.get_guest_user_from_request(request) return None + def get_catalog_perm( + self, + database: str, + catalog: Optional[str] = None, + ) -> Optional[str]: + """ + Return the database specific catalog permission. + + :param database: The database name + :param catalog: The database catalog name + :return: The database specific catalog permission + """ + if catalog is None: + return None + + return f"[{database}].[{catalog}]" + def get_schema_perm( - self, database: Union["Database", str], schema: Optional[str] = None + self, + database: str, + catalog: Optional[str] = None, + schema: Optional[str] = None, ) -> Optional[str]: """ Return the database specific schema permission. - :param database: The Superset database or database name - :param schema: The Superset schema name + Catalogs were added in SIP-95, and not all databases support them. Because of + this, the format used for permissions is different depending on whether a + catalog is passed or not: + + [database].[schema] + [database].[catalog].[schema] + + For backwards compatibility, when processing the first format Superset should + use the default catalog when the database supports them.
This way, migrating + existing permissions is not necessary. + + :param database: The database name + :param catalog: The database catalog name + :param schema: The database schema name :return: The database specific schema permission """ - return f"[{database}].[{schema}]" if schema else None + if schema is None: + return None + + if catalog: + return f"[{database}].[{catalog}].[{schema}]" + + return f"[{database}].[{schema}]" @staticmethod def get_database_perm(database_id: int, database_name: str) -> Optional[str]: @@ -370,12 +408,22 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ) -> Optional[str]: return f"[{database_name}].[{dataset_name}](id:{dataset_id})" - def unpack_database_and_schema(self, schema_permission: str) -> DatabaseAndSchema: - # [database_name].[schema|table] + def unpack_database_catalog_schema( + self, + schema_permission: str, + ) -> DatabaseCatalogSchema: + """ + Split permission into database/catalog/schema. + """ + parts = [part[1:-1] for part in schema_permission.split(".")] + if not 2 <= len(parts) <= 3: + raise ValueError("Invalid schema permission format") + + database = parts[0] + schema = parts[-1] + catalog = parts[1] if len(parts) == 3 else None - schema_name = schema_permission.split(".")[1][1:-1] - database_name = schema_permission.split(".")[0][1:-1] - return DatabaseAndSchema(database_name, schema_name) + return DatabaseCatalogSchema(database, catalog, schema) def can_access(self, permission_name: str, view_name: str) -> bool: """ @@ -435,6 +483,16 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods or self.can_access("database_access", database.perm) # type: ignore ) + def can_access_catalog(self, database: "Database", catalog: str) -> bool: + """ + Return if the user can access the specified catalog. 
+ """ + return ( + self.can_access_all_datasources() + or self.can_access_database(database) + or self.can_access("catalog_access", f"[{database}].[{catalog}]") + ) + def can_access_schema(self, datasource: "BaseDatasource") -> bool: """ Return True if the user can access the schema associated with specified @@ -447,6 +505,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods return ( self.can_access_all_datasources() or self.can_access_database(datasource.database) + or self.can_access_catalog(datasource.database, datasource.catalog) or self.can_access("schema_access", datasource.schema_perm or "") ) @@ -705,55 +764,150 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ] def get_schemas_accessible_by_user( - self, database: "Database", schemas: list[str], hierarchical: bool = True - ) -> list[str]: + self, + database: "Database", + catalog: Optional[str], + schemas: set[str], + hierarchical: bool = True, + ) -> set[str]: """ - Return the list of SQL schemas accessible by the user. + Return a filtered set of the schemas accessible by the user. + + If no catalog is specified, the default catalog is used.
:param database: The SQL database - :param schemas: The list of eligible SQL schemas + :param catalog: An optional database catalog + :param schemas: A set of candidate schemas :param hierarchical: Whether to check using the hierarchical permission logic - :returns: The list of accessible SQL schemas + :returns: The set of accessible database schemas """ # pylint: disable=import-outside-toplevel from superset.connectors.sqla.models import SqlaTable - if hierarchical and self.can_access_database(database): + if hierarchical and ( + self.can_access_database(database) + or (catalog and self.can_access_catalog(database, catalog)) + ): return schemas # schema_access - accessible_schemas = { - self.unpack_database_and_schema(s).schema - for s in self.user_view_menu_names("schema_access") - if s.startswith(f"[{database}].") - } + accessible_schemas: set[str] = set() + schema_access = self.user_view_menu_names("schema_access") + default_catalog = database.get_default_catalog() + default_schema = database.get_default_schema(default_catalog) + + for perm in schema_access: + parts = [part[1:-1] for part in perm.split(".")] + + if parts[0] != database.database_name: + continue + + # [database].[schema] matches when no catalog is specified, or when the user + # specifies the default catalog + if len(parts) == 2 and (catalog is None or catalog == default_catalog): + accessible_schemas.add(parts[1]) + + # [database].[catalog].[schema] matches when the catalog is equal to the + # requested catalog or, when no catalog specified, it's equal to the default + # catalog. 
+ elif len(parts) == 3 and parts[1] == (catalog or default_catalog): + accessible_schemas.add(parts[2]) # datasource_access if perms := self.user_view_menu_names("datasource_access"): tables = ( self.get_session.query(SqlaTable.schema) .filter(SqlaTable.database_id == database.id) - .filter(SqlaTable.schema.isnot(None)) - .filter(SqlaTable.schema != "") .filter(or_(SqlaTable.perm.in_(perms))) .distinct() ) - accessible_schemas.update([table.schema for table in tables]) + accessible_schemas.update( + { + table.schema or default_schema # type: ignore + for table in tables + if (table.schema or default_schema) + } + ) + + return schemas & accessible_schemas + + def get_catalogs_accessible_by_user( + self, + database: "Database", + catalogs: set[str], + hierarchical: bool = True, + ) -> set[str]: + """ + Return a filtered set of the catalogs accessible by the user. + + :param database: The SQL database + :param catalogs: A set of candidate catalogs + :param hierarchical: Whether to check using the hierarchical permission logic + :returns: The set of accessible database catalogs + """ + # pylint: disable=import-outside-toplevel + from superset.connectors.sqla.models import SqlaTable + + if hierarchical and self.can_access_database(database): + return catalogs - return [s for s in schemas if s in accessible_schemas] + # catalog access + accessible_catalogs: set[str] = set() + catalog_access = self.user_view_menu_names("catalog_access") + default_catalog = database.get_default_catalog() + + for perm in catalog_access: + parts = [part[1:-1] for part in perm.split(".")] + if parts[0] == database.database_name: + accessible_catalogs.add(parts[1]) + + # schema access + schema_access = self.user_view_menu_names("schema_access") + for perm in schema_access: + parts = [part[1:-1] for part in perm.split(".")] + + if parts[0] != database.database_name: + continue + if len(parts) == 2 and default_catalog: + accessible_catalogs.add(default_catalog) + elif len(parts) == 3: +
accessible_catalogs.add(parts[1]) # [database].[catalog].[schema] + + # datasource_access + if perms := self.user_view_menu_names("datasource_access"): + tables = ( + self.get_session.query(SqlaTable.catalog) + .filter(SqlaTable.database_id == database.id) + .filter(or_(SqlaTable.perm.in_(perms))) + .distinct() + ) + accessible_catalogs.update( + { + table.catalog or default_catalog # type: ignore + for table in tables + if (table.catalog or default_catalog) + } + ) + + return catalogs & accessible_catalogs def get_datasources_accessible_by_user( # pylint: disable=invalid-name self, database: "Database", datasource_names: list[DatasourceName], + catalog: Optional[str] = None, schema: Optional[str] = None, ) -> list[DatasourceName]: """ - Return the list of SQL tables accessible by the user. + Filter list of SQL tables to the ones accessible by the user. + + When catalog and/or schema are specified, it's assumed that all datasources in + `datasource_names` are in the given catalog/schema. :param database: The SQL database :param datasource_names: The list of eligible SQL tables w/ schema + :param catalog: The fallback SQL catalog if not present in the table name :param schema: The fallback SQL schema if not present in the table name :returns: The list of accessible SQL tables w/ schema """ @@ -763,22 +917,34 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods if self.can_access_database(database): return datasource_names + if catalog: + catalog_perm = self.get_catalog_perm(database, catalog) + if catalog_perm and self.can_access("catalog_access", catalog_perm): + return datasource_names + if schema: - schema_perm = self.get_schema_perm(database, schema) + schema_perm = self.get_schema_perm(database, catalog, schema) if schema_perm and self.can_access("schema_access", schema_perm): return datasource_names user_perms = self.user_view_menu_names("datasource_access") + catalog_perms = self.user_view_menu_names("catalog_access") schema_perms =
self.user_view_menu_names("schema_access") - user_datasources = SqlaTable.query_datasources_by_permissions( - database, user_perms, schema_perms - ) - if schema: - names = {d.table_name for d in user_datasources if d.schema == schema} - return [d for d in datasource_names if d.table in names] + user_datasources = { + DatasourceName(table.table_name, table.schema, table.catalog) + for table in SqlaTable.query_datasources_by_permissions( + database, + user_perms, + catalog_perms, + schema_perms, + ) + } - full_names = {d.full_name for d in user_datasources} - return [d for d in datasource_names if f"[{database}].[{d}]" in full_names] + return [ + datasource + for datasource in datasource_names + if datasource in user_datasources + ] def merge_perm(self, permission_name: str, view_menu_name: str) -> None: """ @@ -843,6 +1009,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods for datasource in datasources: merge_pv("datasource_access", datasource.get_perm()) merge_pv("schema_access", datasource.get_schema_perm()) + merge_pv("catalog_access", datasource.get_catalog_perm()) logger.info("Creating missing database permissions.") databases = self.get_session.query(models.Database).all() @@ -1211,7 +1378,12 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods self.get_session.query(self.permissionview_model) .join(self.permission_model) .join(self.viewmenu_model) - .filter(self.permission_model.name == "schema_access") + .filter( + or_( + self.permission_model.name == "schema_access", + self.permission_model.name == "catalog_access", + ) + ) .filter(self.viewmenu_model.name.like(f"[{database_name}].[%]")) .all() ) @@ -1398,18 +1570,43 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods .values(perm=dataset_perm) ) + # update catalog and schema perms + values: dict[str, Optional[str]] = {} + if target.schema: dataset_schema_perm = self.get_schema_perm( - database.database_name, target.schema + 
database.database_name, + target.catalog, + target.schema, ) self._insert_pvm_on_sqla_event( - mapper, connection, "schema_access", dataset_schema_perm + mapper, + connection, + "schema_access", + dataset_schema_perm, ) target.schema_perm = dataset_schema_perm + values["schema_perm"] = dataset_schema_perm + + if target.catalog: + dataset_catalog_perm = self.get_catalog_perm( + database.database_name, + target.catalog, + ) + self._insert_pvm_on_sqla_event( + mapper, + connection, + "catalog_access", + dataset_catalog_perm, + ) + target.catalog_perm = dataset_catalog_perm + values["catalog_perm"] = dataset_catalog_perm + + if values: connection.execute( dataset_table.update() .where(dataset_table.c.id == target.id) - .values(schema_perm=dataset_schema_perm) + .values(**values) ) def dataset_after_delete( @@ -1466,6 +1663,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods table.select().where(table.c.id == target.id) ).one() current_db_id = current_dataset.database_id + current_catalog = current_dataset.catalog current_schema = current_dataset.schema current_table_name = current_dataset.table_name @@ -1478,14 +1676,21 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods mapper, connection, target.perm, new_dataset_vm_name, target ) - # Updates schema permissions - new_dataset_schema_name = self.get_schema_perm( - target.database.database_name, target.schema + # Updates catalog/schema permissions + dataset_catalog_name = self.get_catalog_perm( + target.database.database_name, + target.catalog, + ) + dataset_schema_name = self.get_schema_perm( + target.database.database_name, + target.catalog, + target.schema, ) - self._update_dataset_schema_perm( + self._update_dataset_catalog_schema_perm( mapper, connection, - new_dataset_schema_name, + dataset_catalog_name, + dataset_schema_name, target, ) @@ -1501,23 +1706,32 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods mapper, connection, 
old_dataset_vm_name, new_dataset_vm_name, target ) - # When schema changes - if current_schema != target.schema: - new_dataset_schema_name = self.get_schema_perm( - target.database.database_name, target.schema + # When catalog/schema change + if current_catalog != target.catalog or current_schema != target.schema: + dataset_catalog_name = self.get_catalog_perm( + target.database.database_name, + target.catalog, + ) + dataset_schema_name = self.get_schema_perm( + target.database.database_name, + target.catalog, + target.schema, ) - self._update_dataset_schema_perm( + self._update_dataset_catalog_schema_perm( mapper, connection, - new_dataset_schema_name, + dataset_catalog_name, + dataset_schema_name, target, ) - def _update_dataset_schema_perm( + # pylint: disable=invalid-name, too-many-arguments + def _update_dataset_catalog_schema_perm( self, mapper: Mapper, connection: Connection, - new_schema_permission_name: Optional[str], + catalog_permission_name: Optional[str], + schema_permission_name: Optional[str], target: "SqlaTable", ) -> None: """ @@ -1529,11 +1743,11 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods :param mapper: The SQLA event mapper :param connection: The SQLA connection - :param new_schema_permission_name: The new schema permission name that changed + :param catalog_permission_name: The new catalog permission name that changed + :param schema_permission_name: The new schema permission name that changed :param target: Dataset that was updated :return: """ - logger.info("Updating schema perm, new: %s", new_schema_permission_name) from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel SqlaTable, ) @@ -1541,21 +1755,39 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods Slice, ) + logger.info( + "Updating catalog/schema permissions to: %s.%s", + catalog_permission_name, + schema_permission_name, + ) + sqlatable_table = SqlaTable.__table__ # pylint: disable=no-member 
chart_table = Slice.__table__ # pylint: disable=no-member - # insert new schema PVM if it does not exist + # insert new PVMs if they don't exist self._insert_pvm_on_sqla_event( - mapper, connection, "schema_access", new_schema_permission_name + mapper, + connection, + "catalog_access", + catalog_permission_name, + ) + self._insert_pvm_on_sqla_event( + mapper, + connection, + "schema_access", + schema_permission_name, ) - # Update dataset (SqlaTable schema_perm field) + # Update dataset connection.execute( sqlatable_table.update() .where( sqlatable_table.c.id == target.id, ) - .values(schema_perm=new_schema_permission_name) + .values( + catalog_perm=catalog_permission_name, + schema_perm=schema_permission_name, + ) ) # Update charts (Slice schema_perm field) @@ -1565,7 +1797,10 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods chart_table.c.datasource_id == target.id, chart_table.c.datasource_type == DatasourceType.TABLE, ) - .values(schema_perm=new_schema_permission_name) + .values( + catalog_perm=catalog_permission_name, + schema_perm=schema_permission_name, + ) ) def _update_dataset_perm( # pylint: disable=too-many-arguments @@ -1922,7 +2157,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods table: Optional["Table"] = None, viz: Optional["BaseViz"] = None, sql: Optional[str] = None, - catalog: Optional[str] = None, # pylint: disable=unused-argument + catalog: Optional[str] = None, schema: Optional[str] = None, ) -> None: """ @@ -1946,7 +2181,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.models.sql_lab import Query - from superset.sql_parse import Table from superset.utils.core import shortid if sql and database: @@ -1954,6 +2188,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods database=database, sql=sql, schema=schema, + catalog=catalog,
client_id=shortid()[:10], user_id=get_user_id(), ) @@ -1969,9 +2204,23 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods return if query: + # Getting the default schema for a query is hard. Users can select the + # schema in SQL Lab, but there's no guarantee that the query actually + # will run in that schema. Each DB engine spec needs to implement the + # necessary logic to enforce that the query runs in the selected schema. + # If the DB engine spec doesn't implement the logic the schema is read + # from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy + # inspector to read it. default_schema = database.get_default_schema_for_query(query) + # Determining the default catalog is much easier, because DB engine + # specs need explicit support for catalogs. + default_catalog = database.get_default_catalog() tables = { - Table(table_.table, table_.schema or default_schema) + Table( + table_.table, + table_.schema or default_schema, + table_.catalog or default_catalog, + ) for table_ in extract_tables_from_jinja_sql(query.sql, database) } elif table: @@ -1980,21 +2229,36 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods denied = set() for table_ in tables: - schema_perm = self.get_schema_perm(database, schema=table_.schema) - - if not (schema_perm and self.can_access("schema_access", schema_perm)): - datasources = SqlaTable.query_datasources_by_name( - database, table_.table, schema=table_.schema - ) + catalog_perm = self.get_catalog_perm( + database, + table_.catalog, + ) + if catalog_perm and self.can_access("catalog_access", catalog_perm): + continue - # Access to any datasource is suffice. 
- for datasource_ in datasources: - if self.can_access( - "datasource_access", datasource_.perm - ) or self.is_owner(datasource_): - break - else: - denied.add(table_) + schema_perm = self.get_schema_perm( + database, + table_.catalog, + table_.schema, + ) + if schema_perm and self.can_access("schema_access", schema_perm): + continue + + datasources = SqlaTable.query_datasources_by_name( + database, + table_.table, + schema=table_.schema, + catalog=table_.catalog, + ) + for datasource_ in datasources: + if self.can_access( + "datasource_access", + datasource_.perm, + ) or self.is_owner(datasource_): + # access to any datasource is sufficient + break + else: + denied.add(table_) if denied: raise SupersetSecurityException( diff --git a/superset/utils/cache.py b/superset/utils/cache.py index 48e283e7c1..00216fc4b1 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -119,7 +119,11 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[..., def wrap(f: Callable[..., Any]) -> Callable[..., Any]: def wrapped_f(*args: Any, **kwargs: Any) -> Any: - if not kwargs.get("cache", True): + should_cache = kwargs.pop("cache", True) + force = kwargs.pop("force", False) + cache_timeout = kwargs.pop("cache_timeout", 0) + + if not should_cache: return f(*args, **kwargs) # format the key using args/kwargs passed to the decorated function @@ -129,10 +133,10 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[..., cache_key = key.format(**bound_args.arguments) obj = cache.get(cache_key) - if not kwargs.get("force") and obj is not None: + if not force and obj is not None: return obj obj = f(*args, **kwargs) - cache.set(cache_key, obj, timeout=kwargs.get("cache_timeout", 0)) + cache.set(cache_key, obj, timeout=cache_timeout) return obj return wrapped_f diff --git a/superset/utils/core.py b/superset/utils/core.py index f02b004432..50aebdd972 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -679,11 +679,13 
@@ def generic_find_uq_constraint_name( def get_datasource_full_name( - database_name: str, datasource_name: str, schema: str | None = None + database_name: str, + datasource_name: str, + catalog: str | None = None, + schema: str | None = None, ) -> str: - if not schema: - return f"[{database_name}].[{datasource_name}]" - return f"[{database_name}].[{schema}].[{datasource_name}]" + parts = [database_name, catalog, schema, datasource_name] + return ".".join([f"[{part}]" for part in parts if part]) def validate_json(obj: bytes | bytearray | str) -> None: @@ -1051,8 +1053,8 @@ def merge_extra_form_data(form_data: dict[str, Any]) -> None: "adhoc_filters", [] ) adhoc_filters.extend( - {"isExtra": True, **fltr} # type: ignore - for fltr in append_adhoc_filters + {"isExtra": True, **adhoc_filter} # type: ignore + for adhoc_filter in append_adhoc_filters ) if append_filters: for key, value in form_data.items(): @@ -1502,6 +1504,7 @@ def shortid() -> str: class DatasourceName(NamedTuple): table: str schema: str + catalog: str | None def get_stacktrace() -> str | None: diff --git a/superset/utils/filters.py b/superset/utils/filters.py index 88154a40b3..8c4a079949 100644 --- a/superset/utils/filters.py +++ b/superset/utils/filters.py @@ -32,10 +32,12 @@ def get_dataset_access_filters( database_ids = security_manager.get_accessible_databases() perms = security_manager.user_view_menu_names("datasource_access") schema_perms = security_manager.user_view_menu_names("schema_access") + catalog_perms = security_manager.user_view_menu_names("catalog_access") return or_( Database.id.in_(database_ids), base_model.perm.in_(perms), + base_model.catalog_perm.in_(catalog_perms), base_model.schema_perm.in_(schema_perms), *args, ) diff --git a/superset/views/database/mixins.py b/superset/views/database/mixins.py index c6e799e6d4..0d104aad5f 100644 --- a/superset/views/database/mixins.py +++ b/superset/views/database/mixins.py @@ -211,11 +211,29 @@ class DatabaseMixin: 
utils.parse_ssl_cert(database.server_cert) database.set_sqlalchemy_uri(database.sqlalchemy_uri) security_manager.add_permission_view_menu("database_access", database.perm) - # adding a new database we always want to force refresh schema list - for schema in database.get_all_schema_names(): - security_manager.add_permission_view_menu( - "schema_access", security_manager.get_schema_perm(database, schema) - ) + + # add catalog/schema permissions + if database.db_engine_spec.supports_catalog: + catalogs = database.get_all_catalog_names() + for catalog in catalogs: + security_manager.add_permission_view_menu( + "catalog_access", + security_manager.get_catalog_perm(database.database_name, catalog), + ) + else: + # add a dummy catalog for DBs that don't support them + catalogs = [None] + + for catalog in catalogs: + for schema in database.get_all_schema_names(catalog=catalog): + security_manager.add_permission_view_menu( + "schema_access", + security_manager.get_schema_perm( + database.database_name, + catalog, + schema, + ), + ) def pre_add(self, database: Database) -> None: self._pre_add_update(database)
