This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch new-dar
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 355432510471e85051960a71af653fabecb2512f
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed Dec 17 16:19:42 2025 -0500

    Testing
---
 superset/data_access_rules/api.py     |  75 ++++++++++++++++++-
 superset/data_access_rules/schemas.py |  13 +++-
 superset/data_access_rules/utils.py   | 135 ++++++++++++++++++++++++++++++++++
 superset/databases/filters.py         |  18 ++++-
 superset/security/manager.py          |  45 ++++++++++++
 5 files changed, 282 insertions(+), 4 deletions(-)

diff --git a/superset/data_access_rules/api.py 
b/superset/data_access_rules/api.py
index f87932378a..c7c9e27f5b 100644
--- a/superset/data_access_rules/api.py
+++ b/superset/data_access_rules/api.py
@@ -23,9 +23,10 @@ including CRUD operations and a group_keys discovery 
endpoint.
 
 import logging
 
-from flask import Response
+from flask import request, Response
 from flask_appbuilder.api import expose, protect, safe
 from flask_appbuilder.models.sqla.interface import SQLAInterface
+from marshmallow import ValidationError
 
 from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
 from superset.data_access_rules.models import DataAccessRule
@@ -39,6 +40,7 @@ from superset.data_access_rules.utils import 
get_all_group_keys
 from superset.extensions import event_logger
 from superset.views.base_api import (
     BaseSupersetModelRestApi,
+    requires_json,
     statsd_metrics,
 )
 from superset.views.filters import BaseFilterRelatedRoles, 
BaseFilterRelatedUsers
@@ -118,6 +120,77 @@ class DataAccessRulesRestApi(BaseSupersetModelRestApi):
         "delete": {"delete": {"summary": "Delete a data access rule"}},
     }
 
+    @expose("/<int:pk>", methods=("PUT",))
+    @protect()
+    @safe
+    @statsd_metrics
+    @requires_json
+    @event_logger.log_this_with_context(
+        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put",
+        log_to_statsd=False,
+    )
+    def put(self, pk: int) -> Response:
+        """Update a data access rule.
+        ---
+        put:
+          summary: Update a data access rule
+          parameters:
+          - in: path
+            schema:
+              type: integer
+            name: pk
+            description: The rule pk
+          requestBody:
+            description: Data access rule schema
+            required: true
+            content:
+              application/json:
+                schema:
+                  $ref: '#/components/schemas/DataAccessRulePutSchema'
+          responses:
+            200:
+              description: Rule updated
+              content:
+                application/json:
+                  schema:
+                    type: object
+                    properties:
+                      id:
+                        type: number
+                      result:
+                        $ref: '#/components/schemas/DataAccessRulePutSchema'
+            400:
+              $ref: '#/components/responses/400'
+            401:
+              $ref: '#/components/responses/401'
+            404:
+              $ref: '#/components/responses/404'
+            422:
+              $ref: '#/components/responses/422'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        try:
+            item = self.edit_model_schema.load(request.json)
+        except ValidationError as error:
+            return self.response_400(message=error.messages)
+
+        # Get existing rule
+        existing = self.datamodel.get(pk)
+        if not existing:
+            return self.response_404()
+
+        # Update fields
+        for key, value in item.items():
+            setattr(existing, key, value)
+
+        try:
+            self.datamodel.edit(existing)
+            return self.response(200, id=existing.id, result=item)
+        except Exception as ex:
+            logger.error("Error updating data access rule: %s", str(ex), 
exc_info=True)
+            return self.response_422(message=str(ex))
+
     @expose("/group_keys/", methods=("GET",))
     @protect()
     @safe
diff --git a/superset/data_access_rules/schemas.py 
b/superset/data_access_rules/schemas.py
index 8a87bde8c5..a1f7f1137a 100644
--- a/superset/data_access_rules/schemas.py
+++ b/superset/data_access_rules/schemas.py
@@ -18,9 +18,10 @@
 Data Access Rules schemas for API serialization/deserialization.
 """
 
-from marshmallow import fields, Schema, validates_schema, ValidationError
+from marshmallow import fields, post_load, Schema, validates_schema, 
ValidationError
 
 from superset.dashboards.schemas import UserSchema
+from superset.data_access_rules.models import DataAccessRule
 
 # Field descriptions for OpenAPI documentation
 rule_description = """
@@ -68,9 +69,12 @@ class DataAccessRuleListSchema(Schema):
     role_id = fields.Integer(metadata={"description": "ID of the associated 
role"})
     role = fields.Nested(RoleSchema)
     rule = fields.String(metadata={"description": rule_description})
-    changed_on_delta_humanized = fields.String()
+    changed_on_delta_humanized = 
fields.Method("get_changed_on_delta_humanized")
     changed_by = fields.Nested(UserSchema(exclude=["username"]))
 
+    def get_changed_on_delta_humanized(self, obj: DataAccessRule) -> str:
+        return obj.changed_on_delta_humanized()
+
 
 class DataAccessRuleShowSchema(Schema):
     """Schema for showing a single data access rule."""
@@ -146,6 +150,11 @@ class DataAccessRulePostSchema(Schema):
             except json.JSONDecodeError as ex:
                 raise ValidationError(f"Invalid JSON: {ex}", 
field_name="rule") from ex
 
+    @post_load
+    def make_object(self, data: dict, **kwargs: dict) -> DataAccessRule:
+        """Convert validated data to a DataAccessRule instance."""
+        return DataAccessRule(**data)
+
 
 class DataAccessRulePutSchema(Schema):
     """Schema for updating a data access rule."""
diff --git a/superset/data_access_rules/utils.py 
b/superset/data_access_rules/utils.py
index 7ad55b6eff..31689c8b4f 100644
--- a/superset/data_access_rules/utils.py
+++ b/superset/data_access_rules/utils.py
@@ -560,6 +560,141 @@ def apply_data_access_rules(
         parsed_statement.apply_cls(cls_rules)
 
 
+def get_allowed_tables(
+    database_name: str,
+    schema: str | None = None,
+    catalog: str | None = None,
+) -> tuple[set[str], bool]:
+    """
+    Get all table names that the current user has access to via Data Access 
Rules
+    for a specific database and schema.
+
+    Args:
+        database_name: The database name to check
+        schema: Optional schema name to filter by
+        catalog: Optional catalog name to filter by
+
+    Returns:
+        Tuple of (set of table names, bool indicating if schema-level access 
is granted).
+        If schema-level access is granted, the set may be empty but all tables 
are allowed.
+    """
+    if not is_feature_enabled("DATA_ACCESS_RULES"):
+        return set(), False
+
+    rules = get_user_rules()
+    if not rules:
+        return set(), False
+
+    table_names: set[str] = set()
+    schema_level_access = False
+
+    for rule in rules:
+        rule_dict = rule.rule_dict
+
+        # Collect tables from allowed entries
+        for entry in rule_dict.get("allowed", []):
+            if entry.get("database") != database_name:
+                continue
+
+            # If catalog is specified in the entry, it must match
+            entry_catalog = entry.get("catalog")
+            if catalog is not None and entry_catalog is not None:
+                if entry_catalog != catalog:
+                    continue
+
+            # If schema is specified, check if it matches
+            entry_schema = entry.get("schema")
+            if schema is not None and entry_schema is not None:
+                if entry_schema != schema:
+                    continue
+
+            # If entry has a table, add it to the set
+            if table := entry.get("table"):
+                table_names.add(table)
+            elif entry_schema == schema or (entry_schema is None and schema is 
None):
+                # Schema-level or database-level access without table means 
all tables
+                schema_level_access = True
+
+    return table_names, schema_level_access
+
+
+def get_allowed_schemas(database_name: str, catalog: str | None = None) -> 
set[str]:
+    """
+    Get all schema names that the current user has access to via Data Access 
Rules
+    for a specific database.
+
+    Args:
+        database_name: The database name to check
+        catalog: Optional catalog name to filter by
+
+    Returns:
+        Set of schema names the user has access to.
+    """
+    if not is_feature_enabled("DATA_ACCESS_RULES"):
+        return set()
+
+    rules = get_user_rules()
+    if not rules:
+        return set()
+
+    schema_names: set[str] = set()
+
+    for rule in rules:
+        rule_dict = rule.rule_dict
+
+        # Collect schemas from allowed entries
+        for entry in rule_dict.get("allowed", []):
+            if entry.get("database") != database_name:
+                continue
+
+            # If catalog is specified in the entry, it must match
+            entry_catalog = entry.get("catalog")
+            if catalog is not None and entry_catalog is not None:
+                if entry_catalog != catalog:
+                    continue
+
+            # If the entry grants database-level access (no schema specified),
+            # we return an empty set to indicate "all schemas" should be 
allowed
+            # This will be handled by the caller
+            if schema := entry.get("schema"):
+                schema_names.add(schema)
+            elif entry.get("database") == database_name:
+                # Database-level access without schema means all schemas
+                # Return a special marker that caller can check
+                schema_names.add("*")
+
+    return schema_names
+
+
+def get_allowed_databases() -> set[str]:
+    """
+    Get all database names that the current user has access to via Data Access 
Rules.
+
+    This function is used to populate database selectors in SQL Lab and 
elsewhere.
+
+    Returns:
+        Set of database names the user has access to.
+    """
+    if not is_feature_enabled("DATA_ACCESS_RULES"):
+        return set()
+
+    rules = get_user_rules()
+    if not rules:
+        return set()
+
+    database_names: set[str] = set()
+
+    for rule in rules:
+        rule_dict = rule.rule_dict
+
+        # Collect databases from allowed entries
+        for entry in rule_dict.get("allowed", []):
+            if database := entry.get("database"):
+                database_names.add(database)
+
+    return database_names
+
+
 def get_all_group_keys(
     database_name: str | None = None,
     table: Table | None = None,
diff --git a/superset/databases/filters.py b/superset/databases/filters.py
index 321eb62100..4f0920319e 100644
--- a/superset/databases/filters.py
+++ b/superset/databases/filters.py
@@ -23,11 +23,22 @@ from sqlalchemy.orm import Query
 from sqlalchemy.sql.expression import cast
 from sqlalchemy.sql.sqltypes import JSON
 
-from superset import security_manager
+from superset import is_feature_enabled, security_manager
 from superset.models.core import Database
 from superset.views.base import BaseFilter
 
 
+def get_dar_allowed_databases() -> set[str]:
+    """Get databases allowed by Data Access Rules for the current user."""
+    if not is_feature_enabled("DATA_ACCESS_RULES"):
+        return set()
+
+    # Lazy import to avoid circular dependency
+    from superset.data_access_rules.utils import get_allowed_databases
+
+    return get_allowed_databases()
+
+
 def can_access_databases(view_menu_name: str) -> set[str]:
     """
     Return names of databases available in `view_menu_name`.
@@ -62,10 +73,15 @@ class DatabaseFilter(BaseFilter):  # pylint: 
disable=too-few-public-methods
         catalog_access_databases = can_access_databases("catalog_access")
         schema_access_databases = can_access_databases("schema_access")
         datasource_access_databases = can_access_databases("datasource_access")
+
+        # Include databases from Data Access Rules
+        dar_databases = get_dar_allowed_databases()
+
         database_names = sorted(
             catalog_access_databases
             | schema_access_databases
             | datasource_access_databases
+            | dar_databases
         )
 
         return query.filter(
diff --git a/superset/security/manager.py b/superset/security/manager.py
index 8490c049e3..f377832194 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -976,6 +976,19 @@ class SupersetSecurityManager(  # pylint: 
disable=too-many-public-methods
                 }
             )
 
+        # Data Access Rules
+        # pylint: disable=import-outside-toplevel
+        from superset import is_feature_enabled
+
+        if is_feature_enabled("DATA_ACCESS_RULES"):
+            from superset.data_access_rules.utils import get_allowed_schemas
+
+            dar_schemas = get_allowed_schemas(database.database_name, catalog)
+            if "*" in dar_schemas:
+                # Database-level access means all schemas
+                return schemas
+            accessible_schemas.update(dar_schemas)
+
         return schemas & accessible_schemas
 
     def get_catalogs_accessible_by_user(
@@ -1091,6 +1104,24 @@ class SupersetSecurityManager(  # pylint: 
disable=too-many-public-methods
             )
         }
 
+        # Check Data Access Rules
+        # pylint: disable=import-outside-toplevel
+        from superset import is_feature_enabled
+
+        if is_feature_enabled("DATA_ACCESS_RULES"):
+            from superset.data_access_rules.utils import get_allowed_tables
+
+            dar_tables, schema_level_access = get_allowed_tables(
+                database.database_name, schema, catalog
+            )
+            if schema_level_access:
+                # Schema-level access means all tables in the schema
+                return datasource_names
+
+            # Add DAR tables to accessible datasources
+            for table_name in dar_tables:
+                user_datasources.add(DatasourceName(table_name, schema, 
catalog))
+
         return [
             datasource
             for datasource in datasource_names
@@ -2410,6 +2441,20 @@ class SupersetSecurityManager(  # pylint: 
disable=too-many-public-methods
                         # access to any datasource is sufficient
                         break
                 else:
+                    # Check Data Access Rules before denying
+                    if is_feature_enabled("DATA_ACCESS_RULES"):
+                        from superset.data_access_rules.utils import (
+                            AccessCheckResult,
+                            check_table_access,
+                        )
+
+                        access_info = check_table_access(
+                            database_name=database.database_name,
+                            table=table_,
+                        )
+                        if access_info.access == AccessCheckResult.ALLOWED:
+                            continue
+
                     denied.add(table_)
 
             if denied:

Reply via email to the mailing list.