This is an automated email from the ASF dual-hosted git repository. vincbeck pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 2e35854a05 Make the method `BaseAuthManager.is_authorized_custom_view` abstract (#37915) 2e35854a05 is described below commit 2e35854a052a13206cb1475973e039fbe394254c Author: Vincent <97131062+vincb...@users.noreply.github.com> AuthorDate: Fri Mar 15 09:57:22 2024 -0400 Make the method `BaseAuthManager.is_authorized_custom_view` abstract (#37915) --- airflow/api_connexion/security.py | 6 +- airflow/auth/managers/base_auth_manager.py | 17 +- .../amazon/aws/auth_manager/avp/entities.py | 2 + .../amazon/aws/auth_manager/aws_auth_manager.py | 176 ++------------------- .../amazon/aws/auth_manager/cli/schema.json | 32 ++++ .../api_endpoints/role_and_permission_endpoint.py | 12 +- .../auth_manager/api_endpoints/user_endpoint.py | 10 +- .../providers/fab/auth_manager/fab_auth_manager.py | 5 +- airflow/www/security_manager.py | 8 +- docs/apache-airflow/core-concepts/auth-manager.rst | 2 + newsfragments/37915.significant.rst | 1 + tests/auth/managers/test_base_auth_manager.py | 13 +- .../aws/auth_manager/test_aws_auth_manager.py | 164 +++++++++++++------ .../fab/auth_manager/test_fab_auth_manager.py | 35 ++++ tests/providers/fab/auth_manager/test_security.py | 3 +- 15 files changed, 241 insertions(+), 245 deletions(-) diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index c99c05a5d3..1cc044d9dd 100644 --- a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -248,15 +248,15 @@ def requires_access_view(access_view: AccessView) -> Callable[[T], T]: def requires_access_custom_view( - fab_action_name: str, - fab_resource_name: str, + method: ResourceMethod, + resource_name: str, ) -> Callable[[T], T]: def requires_access_decorator(func: T): @wraps(func) def decorated(*args, **kwargs): return _requires_access( is_authorized_callback=lambda: get_auth_manager().is_authorized_custom_view( - fab_action_name=fab_action_name, fab_resource_name=fab_resource_name + method=method, resource_name=resource_name ), func=func, args=args, diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index e378e5e10b..4d5c249235 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -235,24 +235,21 @@ class BaseAuthManager(LoggingMixin): :param user: the user to perform the action on. If not provided (or None), it uses the current user """ + @abstractmethod def is_authorized_custom_view( - self, *, fab_action_name: str, fab_resource_name: str, user: BaseUser | None = None + self, *, method: ResourceMethod, resource_name: str, user: BaseUser | None = None ): """ Return whether the user is authorized to perform a given action on a custom view. - A custom view is a view defined as part of the auth manager. This view is then only available when - the auth manager is used as part of the environment. - - By default, it throws an exception because auth managers do not define custom views by default. - If an auth manager defines some custom views, it needs to override this method. + A custom view can be a view defined as part of the auth manager. This view is then only available when + the auth manager is used as part of the environment. It can also be a view defined as part of a + plugin defined by a user. - :param fab_action_name: the name of the FAB action defined in the view in ``base_permissions`` - :param fab_resource_name: the name of the FAB resource defined in the view in - ``class_permission_name`` + :param method: the method to perform + :param resource_name: the name of the resource :param user: the user to perform the action on. If not provided (or None), it uses the current user """ - raise AirflowException(f"The resource `{fab_resource_name}` does not exist in the environment.") def batch_is_authorized_connection( self, diff --git a/airflow/providers/amazon/aws/auth_manager/avp/entities.py b/airflow/providers/amazon/aws/auth_manager/avp/entities.py index 129b670c26..f2c6376729 100644 --- a/airflow/providers/amazon/aws/auth_manager/avp/entities.py +++ b/airflow/providers/amazon/aws/auth_manager/avp/entities.py @@ -35,8 +35,10 @@ class AvpEntities(Enum): # Resource types CONFIGURATION = "Configuration" CONNECTION = "Connection" + CUSTOM = "Custom" DAG = "Dag" DATASET = "Dataset" + MENU = "Menu" POOL = "Pool" VARIABLE = "Variable" VIEW = "View" diff --git a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index c17234c047..0dd3774246 100644 --- a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -25,7 +25,7 @@ from flask import session, url_for from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException +from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities from airflow.providers.amazon.aws.auth_manager.avp.facade import ( AwsAuthManagerAmazonVerifiedPermissionsFacade, @@ -41,28 +41,6 @@ from airflow.providers.amazon.aws.auth_manager.constants import ( from airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override import ( AwsSecurityManagerOverride, ) -from airflow.security.permissions import ( - RESOURCE_AUDIT_LOG, - RESOURCE_CLUSTER_ACTIVITY, - RESOURCE_CONFIG, - RESOURCE_CONNECTION, - RESOURCE_DAG, - RESOURCE_DAG_CODE, - RESOURCE_DAG_DEPENDENCIES, - RESOURCE_DAG_RUN, - RESOURCE_DATASET, - RESOURCE_DOCS, - RESOURCE_JOB, - RESOURCE_PLUGIN, - RESOURCE_POOL, - RESOURCE_PROVIDER, - RESOURCE_SLA_MISS, - RESOURCE_TASK_INSTANCE, - RESOURCE_TASK_RESCHEDULE, - RESOURCE_TRIGGER, - RESOURCE_VARIABLE, - RESOURCE_XCOM, -) try: from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod @@ -97,136 +75,6 @@ if TYPE_CHECKING: from airflow.www.extensions.init_appbuilder import AirflowAppBuilder -_MENU_ITEM_REQUESTS: dict[str, IsAuthorizedRequest] = { - RESOURCE_AUDIT_LOG: { - "method": "GET", - "entity_type": AvpEntities.DAG, - "context": { - "dag_entity": { - "string": DagAccessEntity.AUDIT_LOG.value, - }, - }, - }, - RESOURCE_CLUSTER_ACTIVITY: { - "method": "GET", - "entity_type": AvpEntities.VIEW, - "entity_id": AccessView.CLUSTER_ACTIVITY.value, - }, - RESOURCE_CONFIG: { - "method": "GET", - "entity_type": AvpEntities.CONFIGURATION, - }, - RESOURCE_CONNECTION: { - "method": "GET", - "entity_type": AvpEntities.CONNECTION, - }, - RESOURCE_DAG: { - "method": "GET", - "entity_type": AvpEntities.DAG, - }, - RESOURCE_DAG_CODE: { - "method": "GET", - "entity_type": AvpEntities.DAG, - "context": { - "dag_entity": { - "string": DagAccessEntity.CODE.value, - }, - }, - }, - RESOURCE_DAG_DEPENDENCIES: { - "method": "GET", - "entity_type": AvpEntities.DAG, - "context": { - "dag_entity": { - "string": DagAccessEntity.DEPENDENCIES.value, - }, - }, - }, - RESOURCE_DAG_RUN: { - "method": "GET", - "entity_type": AvpEntities.DAG, - "context": { - "dag_entity": { - "string": DagAccessEntity.RUN.value, - }, - }, - }, - RESOURCE_DATASET: { - "method": "GET", - "entity_type": AvpEntities.DATASET, - }, - RESOURCE_DOCS: { - "method": "GET", - "entity_type": AvpEntities.VIEW, - "entity_id": AccessView.DOCS.value, - }, - RESOURCE_PLUGIN: { - "method": "GET", - "entity_type": AvpEntities.VIEW, - "entity_id": AccessView.PLUGINS.value, - }, - RESOURCE_JOB: { - "method": "GET", - "entity_type": AvpEntities.VIEW, - "entity_id": AccessView.JOBS.value, - }, - RESOURCE_POOL: { - "method": "GET", - "entity_type": AvpEntities.POOL, - }, - RESOURCE_PROVIDER: { - "method": "GET", - "entity_type": AvpEntities.VIEW, - "entity_id": AccessView.PROVIDERS.value, - }, - RESOURCE_SLA_MISS: { - "method": "GET", - "entity_type": AvpEntities.DAG, - "context": { - "dag_entity": { - "string": DagAccessEntity.SLA_MISS.value, - }, - }, - }, - RESOURCE_TASK_INSTANCE: { - "method": "GET", - "entity_type": AvpEntities.DAG, - "context": { - "dag_entity": { - "string": DagAccessEntity.TASK_INSTANCE.value, - }, - }, - }, - RESOURCE_TASK_RESCHEDULE: { - "method": "GET", - "entity_type": AvpEntities.DAG, - "context": { - "dag_entity": { - "string": DagAccessEntity.TASK_RESCHEDULE.value, - }, - }, - }, - RESOURCE_TRIGGER: { - "method": "GET", - "entity_type": AvpEntities.VIEW, - "entity_id": AccessView.TRIGGERS.value, - }, - RESOURCE_VARIABLE: { - "method": "GET", - "entity_type": AvpEntities.VARIABLE, - }, - RESOURCE_XCOM: { - "method": "GET", - "entity_type": AvpEntities.DAG, - "context": { - "dag_entity": { - "string": DagAccessEntity.XCOM.value, - }, - }, - }, -} - - class AwsAuthManager(BaseAuthManager): """ AWS auth manager. @@ -357,6 +205,16 @@ class AwsAuthManager(BaseAuthManager): entity_id=access_view.value, ) + def is_authorized_custom_view( + self, *, method: ResourceMethod, resource_name: str, user: BaseUser | None = None + ): + return self.avp_facade.is_authorized( + method=method, + entity_type=AvpEntities.CUSTOM, + user=user or self.get_user(), + entity_id=resource_name, + ) + def batch_is_authorized_connection( self, requests: Sequence[IsAuthorizedConnectionRequest], @@ -565,12 +423,12 @@ class AwsAuthManager(BaseAuthManager): ] @staticmethod - def _get_menu_item_request(fab_resource_name: str) -> IsAuthorizedRequest: - menu_item_request = _MENU_ITEM_REQUESTS.get(fab_resource_name) - if menu_item_request: - return menu_item_request - else: - raise AirflowException(f"Unknown resource name {fab_resource_name}") + def _get_menu_item_request(resource_name: str) -> IsAuthorizedRequest: + return { + "method": "MENU", + "entity_type": AvpEntities.MENU, + "entity_id": resource_name, + } def get_parser() -> argparse.ArgumentParser: diff --git a/airflow/providers/amazon/aws/auth_manager/cli/schema.json b/airflow/providers/amazon/aws/auth_manager/cli/schema.json index 6817e725ba..2b0cbdacaa 100644 --- a/airflow/providers/amazon/aws/auth_manager/cli/schema.json +++ b/airflow/providers/amazon/aws/auth_manager/cli/schema.json @@ -25,6 +25,30 @@ "resourceTypes": ["Connection"] } }, + "Custom.DELETE": { + "appliesTo": { + "principalTypes": ["User"], + "resourceTypes": ["Custom"] + } + }, + "Custom.GET": { + "appliesTo": { + "principalTypes": ["User"], + "resourceTypes": ["Custom"] + } + }, + "Custom.POST": { + "appliesTo": { + "principalTypes": ["User"], + "resourceTypes": ["Custom"] + } + }, + "Custom.PUT": { + "appliesTo": { + "principalTypes": ["User"], + "resourceTypes": ["Custom"] + } + }, "Configuration.GET": { "appliesTo": { "principalTypes": ["User"], @@ -97,6 +121,12 @@ "resourceTypes": ["Dataset"] } }, + "Menu.MENU": { + "appliesTo": { + "principalTypes": ["User"], + "resourceTypes": ["Menu"] + } + }, "Pool.DELETE": { "appliesTo": { "principalTypes": ["User"], @@ -155,8 +185,10 @@ "entityTypes": { "Configuration": {}, "Connection": {}, + "Custom": {}, "Dag": {}, "Dataset": {}, + "Menu": {}, "Pool": {}, "Group": {}, "User": { diff --git a/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py b/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py index f6d3ecff82..c291fe6997 100644 --- a/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py +++ b/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py @@ -56,7 +56,7 @@ def _check_action_and_resource(sm: FabAirflowSecurityManagerOverride, perms: lis raise BadRequest(detail=f"The specified resource: {resource!r} was not found") -@requires_access_custom_view(permissions.ACTION_CAN_READ, permissions.RESOURCE_ROLE) +@requires_access_custom_view("GET", permissions.RESOURCE_ROLE) def get_role(*, role_name: str) -> APIResponse: """Get role.""" security_manager = cast(FabAirflowSecurityManagerOverride, get_auth_manager().security_manager) @@ -66,7 +66,7 @@ def get_role(*, role_name: str) -> APIResponse: return role_schema.dump(role) -@requires_access_custom_view(permissions.ACTION_CAN_READ, permissions.RESOURCE_ROLE) +@requires_access_custom_view("GET", permissions.RESOURCE_ROLE) @format_parameters({"limit": check_limit}) def get_roles(*, order_by: str = "name", limit: int, offset: int | None = None) -> APIResponse: """Get roles.""" @@ -94,7 +94,7 @@ def get_roles(*, order_by: str = "name", limit: int, offset: int | None = None) return role_collection_schema.dump(RoleCollection(roles=roles, total_entries=total_entries)) -@requires_access_custom_view(permissions.ACTION_CAN_READ, permissions.RESOURCE_ACTION) +@requires_access_custom_view("GET", permissions.RESOURCE_ACTION) @format_parameters({"limit": check_limit}) def get_permissions(*, limit: int, offset: int | None = None) -> APIResponse: """Get permissions.""" @@ -106,7 +106,7 @@ def get_permissions(*, limit: int, offset: int | None = None) -> APIResponse: return action_collection_schema.dump(ActionCollection(actions=actions, total_entries=total_entries)) -@requires_access_custom_view(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ROLE) +@requires_access_custom_view("DELETE", permissions.RESOURCE_ROLE) def delete_role(*, role_name: str) -> APIResponse: """Delete a role.""" security_manager = cast(FabAirflowSecurityManagerOverride, get_auth_manager().security_manager) @@ -118,7 +118,7 @@ def delete_role(*, role_name: str) -> APIResponse: return NoContent, HTTPStatus.NO_CONTENT -@requires_access_custom_view(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_ROLE) +@requires_access_custom_view("PUT", permissions.RESOURCE_ROLE) def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse: """Update a role.""" security_manager = cast(FabAirflowSecurityManagerOverride, get_auth_manager().security_manager) @@ -151,7 +151,7 @@ def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse return role_schema.dump(role) -@requires_access_custom_view(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_ROLE) +@requires_access_custom_view("POST", permissions.RESOURCE_ROLE) def post_role() -> APIResponse: """Create a new role.""" security_manager = cast(FabAirflowSecurityManagerOverride, get_auth_manager().security_manager) diff --git a/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py b/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py index 9e2726d2bb..61bf8e35e6 100644 --- a/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py +++ b/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py @@ -44,7 +44,7 @@ if TYPE_CHECKING: from airflow.providers.fab.auth_manager.models import Role -@requires_access_custom_view(permissions.ACTION_CAN_READ, permissions.RESOURCE_USER) +@requires_access_custom_view("GET", permissions.RESOURCE_USER) def get_user(*, username: str) -> APIResponse: """Get a user.""" security_manager = cast(FabAirflowSecurityManagerOverride, get_auth_manager().security_manager) @@ -54,7 +54,7 @@ def get_user(*, username: str) -> APIResponse: return user_collection_item_schema.dump(user) -@requires_access_custom_view(permissions.ACTION_CAN_READ, permissions.RESOURCE_USER) +@requires_access_custom_view("GET", permissions.RESOURCE_USER) @format_parameters({"limit": check_limit}) def get_users(*, limit: int, order_by: str = "id", offset: str | None = None) -> APIResponse: """Get users.""" @@ -86,7 +86,7 @@ def get_users(*, limit: int, order_by: str = "id", offset: str | None = None) -> return user_collection_schema.dump(UserCollection(users=users, total_entries=total_entries)) -@requires_access_custom_view(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_USER) +@requires_access_custom_view("POST", permissions.RESOURCE_USER) def post_user() -> APIResponse: """Create a new user.""" try: @@ -129,7 +129,7 @@ def post_user() -> APIResponse: return user_schema.dump(user) -@requires_access_custom_view(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_USER) +@requires_access_custom_view("PUT", permissions.RESOURCE_USER) def patch_user(*, username: str, update_mask: UpdateMask = None) -> APIResponse: """Update a user.""" try: @@ -198,7 +198,7 @@ def patch_user(*, username: str, update_mask: UpdateMask = None) -> APIResponse: return user_schema.dump(user) -@requires_access_custom_view(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_USER) +@requires_access_custom_view("DELETE", permissions.RESOURCE_USER) def delete_user(*, username: str) -> APIResponse: """Delete a user.""" security_manager = cast(FabAirflowSecurityManagerOverride, get_auth_manager().security_manager) diff --git a/airflow/providers/fab/auth_manager/fab_auth_manager.py b/airflow/providers/fab/auth_manager/fab_auth_manager.py index 60d95698b6..7533a214aa 100644 --- a/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -268,11 +268,12 @@ class FabAuthManager(BaseAuthManager): ) def is_authorized_custom_view( - self, *, fab_action_name: str, fab_resource_name: str, user: BaseUser | None = None + self, *, method: ResourceMethod, resource_name: str, user: BaseUser | None = None ): if not user: user = self.get_user() - return (fab_action_name, fab_resource_name) in self._get_user_permissions(user) + fab_action_name = get_fab_action_from_method_map()[method] + return (fab_action_name, resource_name) in self._get_user_permissions(user) @provide_session def get_permitted_dag_ids( diff --git a/airflow/www/security_manager.py b/airflow/www/security_manager.py index e70843ae32..c63de34068 100644 --- a/airflow/www/security_manager.py +++ b/airflow/www/security_manager.py @@ -334,11 +334,11 @@ class AirflowSecurityManagerV2(LoggingMixin): # least one dropdown child return self._is_authorized_category_menu(fab_resource_name) else: - # This means the page the user is trying to access is specific to the auth manager used - # Example: the user list view in FabAuthManager + # The user is trying to access a page specific to the auth manager + # (e.g. the user list view in FabAuthManager) or a page defined in a plugin return lambda action, resource_pk, user: get_auth_manager().is_authorized_custom_view( - fab_action_name=action, - fab_resource_name=fab_resource_name, + method=get_method_from_fab_action_map()[action], + resource_name=fab_resource_name, user=user, ) diff --git a/docs/apache-airflow/core-concepts/auth-manager.rst b/docs/apache-airflow/core-concepts/auth-manager.rst index 4e3446acaa..aaead4a2b3 100644 --- a/docs/apache-airflow/core-concepts/auth-manager.rst +++ b/docs/apache-airflow/core-concepts/auth-manager.rst @@ -97,6 +97,7 @@ Let's go over the different parameters used by most of these methods. * ``POST``: Can the user create a resource? * ``PUT``: Can the user modify the resource? * ``DELETE``: Can the user delete the resource? + * ``MENU``: Can the user see the resource in the menu? * ``details``: Optional details about the resource being accessed. * ``user``: The user trying to access the resource. @@ -113,6 +114,7 @@ These authorization methods are: * ``is_authorized_pool``: Return whether the user is authorized to access Airflow pools. Some details about the pool can be provided (e.g. the pool name). * ``is_authorized_variable``: Return whether the user is authorized to access Airflow variables. Some details about the variable can be provided (e.g. the variable key). * ``is_authorized_view``: Return whether the user is authorized to access a specific view in Airflow. The view is specified through ``access_view`` (e.g. ``AccessView.CLUSTER_ACTIVITY``). +* ``is_authorized_custom_view``: Return whether the user is authorized to access a specific view not defined in Airflow. This view can be provided by the auth manager itself or a plugin defined by the user. Optional methods recommended to override for optimization ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/newsfragments/37915.significant.rst b/newsfragments/37915.significant.rst new file mode 100644 index 0000000000..27702fd77f --- /dev/null +++ b/newsfragments/37915.significant.rst @@ -0,0 +1 @@ +The method ``is_authorized_custom_view`` from ``BaseAuthManager`` is now abstract. All sub classes must implement this method. diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index d05bb50dd4..04191c4838 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -31,7 +31,6 @@ from airflow.auth.managers.models.resource_details import ( VariableDetails, ) from airflow.exceptions import AirflowException -from airflow.security import permissions from airflow.www.extensions.init_appbuilder import init_appbuilder from airflow.www.security_manager import AirflowSecurityManagerV2 @@ -95,6 +94,11 @@ class EmptyAuthManager(BaseAuthManager): def is_authorized_view(self, *, access_view: AccessView, user: BaseUser | None = None) -> bool: raise NotImplementedError() + def is_authorized_custom_view( + self, *, method: ResourceMethod, resource_name: str, user: BaseUser | None = None + ): + raise NotImplementedError() + def is_logged_in(self) -> bool: raise NotImplementedError() @@ -154,13 +158,6 @@ class TestBaseAuthManager: def test_get_url_user_profile_return_none(self, auth_manager): assert auth_manager.get_url_user_profile() is None - def test_is_authorized_custom_view_raise_exception(self, auth_manager): - with pytest.raises(AirflowException, match="The resource `.*` does not exist in the environment."): - auth_manager.is_authorized_custom_view( - fab_action_name=permissions.ACTION_CAN_READ, - fab_resource_name=permissions.RESOURCE_MY_PASSWORD, - ) - @pytest.mark.parametrize( "return_values, expected", [ diff --git a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py index 9b654199c3..cecd7deb42 100644 --- a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py +++ b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -33,8 +33,8 @@ from airflow.auth.managers.models.resource_details import ( PoolDetails, VariableDetails, ) -from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities +from airflow.providers.amazon.aws.auth_manager.avp.facade import AwsAuthManagerAmazonVerifiedPermissionsFacade from airflow.providers.amazon.aws.auth_manager.aws_auth_manager import AwsAuthManager from airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override import ( AwsSecurityManagerOverride, @@ -47,14 +47,33 @@ from airflow.security.permissions import ( RESOURCE_DATASET, RESOURCE_VARIABLE, ) +from airflow.www import app as application from airflow.www.extensions.init_appbuilder import init_appbuilder from tests.test_utils.config import conf_vars +from tests.test_utils.www import check_content_in_response if TYPE_CHECKING: from airflow.auth.managers.base_auth_manager import ResourceMethod mock = Mock() +SAML_METADATA_PARSED = { + "idp": { + "entityId": "https://portal.sso.us-east-1.amazonaws.com/saml/assertion/<assertion>", + "singleSignOnService": { + "url": "https://portal.sso.us-east-1.amazonaws.com/saml/assertion/<assertion>", + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect", + }, + "singleLogoutService": { + "url": "https://portal.sso.us-east-1.amazonaws.com/saml/logout/<assertion>", + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect", + }, + "x509cert": "<cert>", + }, + "security": {"authnRequestsSigned": False}, + "sp": {"NameIDFormat": "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"}, +} + @pytest.fixture def auth_manager(): @@ -91,6 +110,39 @@ def test_user(): return AwsAuthManagerUser(user_id="test_user_id", groups=[], username="test_username") +@pytest.fixture +def client_admin(): + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager", + ("aws_auth_manager", "enable"): "True", + ("aws_auth_manager", "region_name"): "us-east-1", + ("aws_auth_manager", "saml_metadata_url"): "/saml/metadata", + ("aws_auth_manager", "avp_policy_store_id"): "avp_policy_store_id", + } + ): + with patch( + "airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser" + ) as mock_parser, patch( + "airflow.providers.amazon.aws.auth_manager.views.auth.AwsAuthManagerAuthenticationViews._init_saml_auth" + ) as mock_init_saml_auth: + mock_parser.parse_remote.return_value = SAML_METADATA_PARSED + + auth = Mock() + auth.is_authenticated.return_value = True + auth.get_nameid.return_value = "user_admin_permissions" + auth.get_attributes.return_value = { + "id": ["user_admin_permissions"], + "groups": ["Admin"], + "email": ["email"], + } + mock_init_saml_auth.return_value = auth + yield application.create_app(testing=True) + + class TestAwsAuthManager: def test_avp_facade(self, auth_manager): assert hasattr(auth_manager, "avp_facade") @@ -524,47 +576,48 @@ class TestAwsAuthManager: { "request": { "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, - "action": {"actionType": "Airflow::Action", "actionId": "Connection.GET"}, - "resource": {"entityType": "Airflow::Connection", "entityId": "*"}, + "action": {"actionType": "Airflow::Action", "actionId": "Menu.MENU"}, + "resource": {"entityType": "Airflow::Menu", "entityId": "Connections"}, }, "decision": "DENY", }, { "request": { "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, - "action": {"actionType": "Airflow::Action", "actionId": "Variable.GET"}, - "resource": {"entityType": "Airflow::Variable", "entityId": "*"}, + "action": {"actionType": "Airflow::Action", "actionId": "Menu.MENU"}, + "resource": {"entityType": "Airflow::Menu", "entityId": "Variables"}, }, "decision": "ALLOW", }, { "request": { "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, - "action": {"actionType": "Airflow::Action", "actionId": "Dataset.GET"}, - "resource": {"entityType": "Airflow::Dataset", "entityId": "*"}, + "action": {"actionType": "Airflow::Action", "actionId": "Menu.MENU"}, + "resource": {"entityType": "Airflow::Menu", "entityId": "Datasets"}, }, "decision": "DENY", }, { "request": { "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, - "action": {"actionType": "Airflow::Action", "actionId": "View.GET"}, - "resource": {"entityType": "Airflow::View", "entityId": "CLUSTER_ACTIVITY"}, + "action": {"actionType": "Airflow::Action", "actionId": "Menu.MENU"}, + "resource": {"entityType": "Airflow::Menu", "entityId": "Cluster Activity"}, }, "decision": "DENY", }, { "request": { "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, - "action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"}, - "resource": {"entityType": "Airflow::Dag", "entityId": "*"}, - "context": { - "contextMap": { - "dag_entity": { - "string": "AUDIT_LOG", - } - } - }, + "action": {"actionType": "Airflow::Action", "actionId": "Menu.MENU"}, + "resource": {"entityType": "Airflow::Menu", "entityId": "Audit Logs"}, + }, + "decision": "ALLOW", + }, + { + "request": { + "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, + "action": {"actionType": "Airflow::Action", "actionId": "Menu.MENU"}, + "resource": {"entityType": "Airflow::Menu", "entityId": "CustomPage"}, }, "decision": "ALLOW", }, @@ -581,45 +634,51 @@ class TestAwsAuthManager: MenuItem("Category2", childs=[MenuItem(RESOURCE_DATASET)]), MenuItem(RESOURCE_CLUSTER_ACTIVITY), MenuItem(RESOURCE_AUDIT_LOG), + MenuItem("CustomPage"), ] ) + """ + return { + "method": "MENU", + "entity_type": AvpEntities.MENU, + "entity_id": resource_name, + } + """ + auth_manager.avp_facade.get_batch_is_authorized_results.assert_called_once_with( requests=[ { - "method": "GET", - "entity_type": AvpEntities.CONNECTION, + "method": "MENU", + "entity_type": AvpEntities.MENU, + "entity_id": "Connections", }, { - "method": "GET", - "entity_type": AvpEntities.VARIABLE, + "method": "MENU", + "entity_type": AvpEntities.MENU, + "entity_id": "Variables", }, { - "method": "GET", - "entity_type": AvpEntities.DATASET, + "method": "MENU", + "entity_type": AvpEntities.MENU, + "entity_id": "Datasets", }, + {"method": "MENU", "entity_type": AvpEntities.MENU, "entity_id": "Cluster Activity"}, + {"method": "MENU", "entity_type": AvpEntities.MENU, "entity_id": "Audit Logs"}, { - "method": "GET", - "entity_type": AvpEntities.VIEW, - "entity_id": AccessView.CLUSTER_ACTIVITY.value, - }, - { - "method": "GET", - "entity_type": AvpEntities.DAG, - "context": { - "dag_entity": { - "string": DagAccessEntity.AUDIT_LOG.value, - }, - }, + "method": "MENU", + "entity_type": AvpEntities.MENU, + "entity_id": "CustomPage", }, ], user=test_user, ) - assert len(result) == 2 + assert len(result) == 3 assert result[0].name == "Category1" assert len(result[0].childs) == 1 assert result[0].childs[0].name == RESOURCE_VARIABLE assert result[1].name == RESOURCE_AUDIT_LOG + assert result[2].name == "CustomPage" @patch.object(AwsAuthManager, "get_user") def test_filter_permitted_menu_items_logged_out(self, mock_get_user, auth_manager): @@ -632,16 +691,6 @@ class TestAwsAuthManager: assert result == [] - @patch.object(AwsAuthManager, "get_user") - def test_filter_permitted_menu_items_wrong_menu_item(self, mock_get_user, auth_manager, test_user): - mock_get_user.return_value = test_user - with pytest.raises(AirflowException, match="Unknown resource name"): - auth_manager.filter_permitted_menu_items( - [ - MenuItem("Test"), - ] - ) - @pytest.mark.parametrize( "methods, user", [ @@ -717,3 +766,24 @@ class TestAwsAuthManager: def test_get_cli_commands_return_cli_commands(self, auth_manager): assert len(auth_manager.get_cli_commands()) > 0 + + @pytest.mark.db_test + @patch.object(AwsAuthManagerAmazonVerifiedPermissionsFacade, "get_batch_is_authorized_single_result") + @patch.object(AwsAuthManagerAmazonVerifiedPermissionsFacade, "get_batch_is_authorized_results") + @patch.object(AwsAuthManagerAmazonVerifiedPermissionsFacade, "is_authorized") + def test_aws_auth_manager_index( + self, + mock_is_authorized, + mock_get_batch_is_authorized_results, + mock_get_batch_is_authorized_single_result, + client_admin, + ): + """ + Load the index page using AWS auth manager. Mock all interactions with Amazon Verified Permissions. + """ + mock_is_authorized.return_value = True + mock_get_batch_is_authorized_results.return_value = [] + mock_get_batch_is_authorized_single_result.return_value = {"decision": "ALLOW"} + with client_admin.test_client() as client: + response = client.get("/login_callback", follow_redirects=True) + check_content_in_response("<h2>DAGs</h2>", response, 200) diff --git a/tests/providers/fab/auth_manager/test_fab_auth_manager.py b/tests/providers/fab/auth_manager/test_fab_auth_manager.py index 17b912e252..773651fafe 100644 --- a/tests/providers/fab/auth_manager/test_fab_auth_manager.py +++ b/tests/providers/fab/auth_manager/test_fab_auth_manager.py @@ -17,6 +17,7 @@ from __future__ import annotations from itertools import chain +from typing import TYPE_CHECKING from unittest import mock from unittest.mock import Mock @@ -50,6 +51,9 @@ from airflow.security.permissions import ( ) from airflow.www.extensions.init_appbuilder import init_appbuilder +if TYPE_CHECKING: + from airflow.auth.managers.base_auth_manager import ResourceMethod + IS_AUTHORIZED_METHODS_SIMPLE = { "is_authorized_configuration": RESOURCE_CONFIG, "is_authorized_connection": RESOURCE_CONNECTION, @@ -367,6 +371,37 @@ class TestFabAuthManager: result = auth_manager.is_authorized_view(access_view=access_view, user=user) assert result == expected_result + @pytest.mark.parametrize( + "method, resource_name, user_permissions, expected_result", + [ + ( + "GET", + "custom_resource", + [(ACTION_CAN_READ, "custom_resource")], + True, + ), + ( + "GET", + "custom_resource", + [(ACTION_CAN_EDIT, "custom_resource")], + False, + ), + ( + "GET", + "custom_resource", + [(ACTION_CAN_READ, "custom_resource2")], + False, + ), + ], + ) + def test_is_authorized_custom_view( + self, method: ResourceMethod, resource_name: str, user_permissions, expected_result, auth_manager + ): + user = Mock() + user.perms = user_permissions + result = auth_manager.is_authorized_custom_view(method=method, resource_name=resource_name, user=user) + assert result == expected_result + @pytest.mark.db_test def test_security_manager_return_fab_security_manager_override(self, auth_manager_with_appbuilder): assert isinstance(auth_manager_with_appbuilder.security_manager, FabAirflowSecurityManagerOverride) diff --git a/tests/providers/fab/auth_manager/test_security.py b/tests/providers/fab/auth_manager/test_security.py index a40db5da66..fecd5c4428 100644 --- a/tests/providers/fab/auth_manager/test_security.py +++ b/tests/providers/fab/auth_manager/test_security.py @@ -41,6 +41,7 @@ from airflow.providers.fab.auth_manager.fab_auth_manager import FabAuthManager from airflow.providers.fab.auth_manager.models import User, assoc_permission_role from airflow.providers.fab.auth_manager.models.anonymous_user import AnonymousUser from airflow.security import permissions +from airflow.security.permissions import ACTION_CAN_READ from airflow.www import app as application from airflow.www.auth import get_access_denied_message from airflow.www.extensions.init_auth_manager import get_auth_manager @@ -547,7 +548,7 @@ def test_dont_get_inaccessible_dag_ids_for_dag_resource_permission( def test_has_access(security_manager): user = mock.MagicMock() - action_name = "action" + action_name = ACTION_CAN_READ resource_name = "resource" user.perms = [(action_name, resource_name)] assert security_manager.has_access(action_name, resource_name, user)