This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch new-dar in repository https://gitbox.apache.org/repos/asf/superset.git
commit 355432510471e85051960a71af653fabecb2512f Author: Beto Dealmeida <[email protected]> AuthorDate: Wed Dec 17 16:19:42 2025 -0500 Testing --- superset/data_access_rules/api.py | 75 ++++++++++++++++++- superset/data_access_rules/schemas.py | 13 +++- superset/data_access_rules/utils.py | 135 ++++++++++++++++++++++++++++++++++ superset/databases/filters.py | 18 ++++- superset/security/manager.py | 45 ++++++++++++ 5 files changed, 282 insertions(+), 4 deletions(-) diff --git a/superset/data_access_rules/api.py b/superset/data_access_rules/api.py index f87932378a..c7c9e27f5b 100644 --- a/superset/data_access_rules/api.py +++ b/superset/data_access_rules/api.py @@ -23,9 +23,10 @@ including CRUD operations and a group_keys discovery endpoint. import logging -from flask import Response +from flask import request, Response from flask_appbuilder.api import expose, protect, safe from flask_appbuilder.models.sqla.interface import SQLAInterface +from marshmallow import ValidationError from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod from superset.data_access_rules.models import DataAccessRule @@ -39,6 +40,7 @@ from superset.data_access_rules.utils import get_all_group_keys from superset.extensions import event_logger from superset.views.base_api import ( BaseSupersetModelRestApi, + requires_json, statsd_metrics, ) from superset.views.filters import BaseFilterRelatedRoles, BaseFilterRelatedUsers @@ -118,6 +120,77 @@ class DataAccessRulesRestApi(BaseSupersetModelRestApi): "delete": {"delete": {"summary": "Delete a data access rule"}}, } + @expose("/<int:pk>", methods=("PUT",)) + @protect() + @safe + @statsd_metrics + @requires_json + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put", + log_to_statsd=False, + ) + def put(self, pk: int) -> Response: + """Update a data access rule. + --- + put: + summary: Update a data access rule + parameters: + - in: path + schema: + type: integer + name: pk + description: The rule pk + requestBody: + description: Data access rule schema + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/DataAccessRulePutSchema' + responses: + 200: + description: Rule updated + content: + application/json: + schema: + type: object + properties: + id: + type: number + result: + $ref: '#/components/schemas/DataAccessRulePutSchema' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + try: + item = self.edit_model_schema.load(request.json) + except ValidationError as error: + return self.response_400(message=error.messages) + + # Get existing rule + existing = self.datamodel.get(pk) + if not existing: + return self.response_404() + + # Update fields + for key, value in item.items(): + setattr(existing, key, value) + + try: + self.datamodel.edit(existing) + return self.response(200, id=existing.id, result=item) + except Exception as ex: + logger.error("Error updating data access rule: %s", str(ex), exc_info=True) + return self.response_422(message=str(ex)) + @expose("/group_keys/", methods=("GET",)) @protect() @safe diff --git a/superset/data_access_rules/schemas.py b/superset/data_access_rules/schemas.py index 8a87bde8c5..a1f7f1137a 100644 --- a/superset/data_access_rules/schemas.py +++ b/superset/data_access_rules/schemas.py @@ -18,9 +18,10 @@ Data Access Rules schemas for API serialization/deserialization. """ -from marshmallow import fields, Schema, validates_schema, ValidationError +from marshmallow import fields, post_load, Schema, validates_schema, ValidationError from superset.dashboards.schemas import UserSchema +from superset.data_access_rules.models import DataAccessRule # Field descriptions for OpenAPI documentation rule_description = """ @@ -68,9 +69,12 @@ class DataAccessRuleListSchema(Schema): role_id = fields.Integer(metadata={"description": "ID of the associated role"}) role = fields.Nested(RoleSchema) rule = fields.String(metadata={"description": rule_description}) - changed_on_delta_humanized = fields.String() + changed_on_delta_humanized = fields.Method("get_changed_on_delta_humanized") changed_by = fields.Nested(UserSchema(exclude=["username"])) + def get_changed_on_delta_humanized(self, obj: DataAccessRule) -> str: + return obj.changed_on_delta_humanized() + class DataAccessRuleShowSchema(Schema): """Schema for showing a single data access rule.""" @@ -146,6 +150,11 @@ class DataAccessRulePostSchema(Schema): except json.JSONDecodeError as ex: raise ValidationError(f"Invalid JSON: {ex}", field_name="rule") from ex + @post_load + def make_object(self, data: dict, **kwargs: dict) -> DataAccessRule: + """Convert validated data to a DataAccessRule instance.""" + return DataAccessRule(**data) + class DataAccessRulePutSchema(Schema): """Schema for updating a data access rule.""" diff --git a/superset/data_access_rules/utils.py b/superset/data_access_rules/utils.py index 7ad55b6eff..31689c8b4f 100644 --- a/superset/data_access_rules/utils.py +++ b/superset/data_access_rules/utils.py @@ -560,6 +560,141 @@ def apply_data_access_rules( parsed_statement.apply_cls(cls_rules) +def get_allowed_tables( + database_name: str, + schema: str | None = None, + catalog: str | None = None, +) -> tuple[set[str], bool]: + """ + Get all table names that the current user has access to via Data Access Rules + for a specific database and schema. + + Args: + database_name: The database name to check + schema: Optional schema name to filter by + catalog: Optional catalog name to filter by + + Returns: + Tuple of (set of table names, bool indicating if schema-level access is granted). + If schema-level access is granted, the set may be empty but all tables are allowed. + """ + if not is_feature_enabled("DATA_ACCESS_RULES"): + return set(), False + + rules = get_user_rules() + if not rules: + return set(), False + + table_names: set[str] = set() + schema_level_access = False + + for rule in rules: + rule_dict = rule.rule_dict + + # Collect tables from allowed entries + for entry in rule_dict.get("allowed", []): + if entry.get("database") != database_name: + continue + + # If catalog is specified in the entry, it must match + entry_catalog = entry.get("catalog") + if catalog is not None and entry_catalog is not None: + if entry_catalog != catalog: + continue + + # If schema is specified, check if it matches + entry_schema = entry.get("schema") + if schema is not None and entry_schema is not None: + if entry_schema != schema: + continue + + # If entry has a table, add it to the set + if table := entry.get("table"): + table_names.add(table) + elif entry_schema == schema or (entry_schema is None and schema is None): + # Schema-level or database-level access without table means all tables + schema_level_access = True + + return table_names, schema_level_access + + +def get_allowed_schemas(database_name: str, catalog: str | None = None) -> set[str]: + """ + Get all schema names that the current user has access to via Data Access Rules + for a specific database. + + Args: + database_name: The database name to check + catalog: Optional catalog name to filter by + + Returns: + Set of schema names the user has access to. + """ + if not is_feature_enabled("DATA_ACCESS_RULES"): + return set() + + rules = get_user_rules() + if not rules: + return set() + + schema_names: set[str] = set() + + for rule in rules: + rule_dict = rule.rule_dict + + # Collect schemas from allowed entries + for entry in rule_dict.get("allowed", []): + if entry.get("database") != database_name: + continue + + # If catalog is specified in the entry, it must match + entry_catalog = entry.get("catalog") + if catalog is not None and entry_catalog is not None: + if entry_catalog != catalog: + continue + + # If the entry grants database-level access (no schema specified), + # we return an empty set to indicate "all schemas" should be allowed + # This will be handled by the caller + if schema := entry.get("schema"): + schema_names.add(schema) + elif entry.get("database") == database_name: + # Database-level access without schema means all schemas + # Return a special marker that caller can check + schema_names.add("*") + + return schema_names + + +def get_allowed_databases() -> set[str]: + """ + Get all database names that the current user has access to via Data Access Rules. + + This function is used to populate database selectors in SQL Lab and elsewhere. + + Returns: + Set of database names the user has access to. + """ + if not is_feature_enabled("DATA_ACCESS_RULES"): + return set() + + rules = get_user_rules() + if not rules: + return set() + + database_names: set[str] = set() + + for rule in rules: + rule_dict = rule.rule_dict + + # Collect databases from allowed entries + for entry in rule_dict.get("allowed", []): + if database := entry.get("database"): + database_names.add(database) + + return database_names + + def get_all_group_keys( database_name: str | None = None, table: Table | None = None, diff --git a/superset/databases/filters.py b/superset/databases/filters.py index 321eb62100..4f0920319e 100644 --- a/superset/databases/filters.py +++ b/superset/databases/filters.py @@ -23,11 +23,22 @@ from sqlalchemy.orm import Query from sqlalchemy.sql.expression import cast from sqlalchemy.sql.sqltypes import JSON -from superset import security_manager +from superset import is_feature_enabled, security_manager from superset.models.core import Database from superset.views.base import BaseFilter +def get_dar_allowed_databases() -> set[str]: + """Get databases allowed by Data Access Rules for the current user.""" + if not is_feature_enabled("DATA_ACCESS_RULES"): + return set() + + # Lazy import to avoid circular dependency + from superset.data_access_rules.utils import get_allowed_databases + + return get_allowed_databases() + + def can_access_databases(view_menu_name: str) -> set[str]: """ Return names of databases available in `view_menu_name`. @@ -62,10 +73,15 @@ class DatabaseFilter(BaseFilter): # pylint: disable=too-few-public-methods catalog_access_databases = can_access_databases("catalog_access") schema_access_databases = can_access_databases("schema_access") datasource_access_databases = can_access_databases("datasource_access") + + # Include databases from Data Access Rules + dar_databases = get_dar_allowed_databases() + database_names = sorted( catalog_access_databases | schema_access_databases | datasource_access_databases + | dar_databases ) return query.filter( diff --git a/superset/security/manager.py b/superset/security/manager.py index 8490c049e3..f377832194 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -976,6 +976,19 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods } ) + # Data Access Rules + # pylint: disable=import-outside-toplevel + from superset import is_feature_enabled + + if is_feature_enabled("DATA_ACCESS_RULES"): + from superset.data_access_rules.utils import get_allowed_schemas + + dar_schemas = get_allowed_schemas(database.database_name, catalog) + if "*" in dar_schemas: + # Database-level access means all schemas + return schemas + accessible_schemas.update(dar_schemas) + return schemas & accessible_schemas def get_catalogs_accessible_by_user( @@ -1091,6 +1104,24 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ) } + # Check Data Access Rules + # pylint: disable=import-outside-toplevel + from superset import is_feature_enabled + + if is_feature_enabled("DATA_ACCESS_RULES"): + from superset.data_access_rules.utils import get_allowed_tables + + dar_tables, schema_level_access = get_allowed_tables( + database.database_name, schema, catalog + ) + if schema_level_access: + # Schema-level access means all tables in the schema + return datasource_names + + # Add DAR tables to accessible datasources + for table_name in dar_tables: + user_datasources.add(DatasourceName(table_name, schema, catalog)) + return [ datasource for datasource in datasource_names @@ -2410,6 +2441,20 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods # access to any datasource is sufficient break else: + # Check Data Access Rules before denying + if is_feature_enabled("DATA_ACCESS_RULES"): + from superset.data_access_rules.utils import ( + AccessCheckResult, + check_table_access, + ) + + access_info = check_table_access( + database_name=database.database_name, + table=table_, + ) + if access_info.access == AccessCheckResult.ALLOWED: + continue + denied.add(table_) if denied:
