This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-8-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit e2e89668d46363bc0ecfa31058d02bc3e041de88 Author: Vincent <97131062+vincb...@users.noreply.github.com> AuthorDate: Tue Nov 21 12:35:10 2023 -0500 Implement login and logout in AWS auth manager (#35488) --- CONTRIBUTING.rst | 6 +- Dockerfile | 4 +- Dockerfile.ci | 4 +- INSTALL | 6 +- airflow/auth/managers/base_auth_manager.py | 32 +++-- airflow/auth/managers/fab/fab_auth_manager.py | 16 +-- airflow/auth/managers/fab/models/__init__.py | 3 + airflow/auth/managers/models/base_user.py | 7 +- .../amazon/aws/auth_manager/__init__.py} | 17 --- .../amazon/aws/auth_manager/aws_auth_manager.py | 143 +++++++++++++++++++ .../amazon/aws/auth_manager/constants.py} | 20 +-- .../aws/auth_manager/security_manager/__init__.py} | 17 --- .../aws_security_manager_override.py} | 25 ++-- airflow/providers/amazon/aws/auth_manager/user.py | 51 +++++++ .../amazon/aws/auth_manager/views/__init__.py} | 17 --- .../amazon/aws/auth_manager/views/auth.py | 149 ++++++++++++++++++++ airflow/providers/amazon/provider.yaml | 25 ++++ airflow/www/extensions/init_appbuilder.py | 2 +- airflow/www/extensions/init_auth_manager.py | 6 +- airflow/www/extensions/init_security.py | 4 +- airflow/www/security_manager.py | 33 +++-- airflow/www/utils.py | 5 +- docs/apache-airflow/extra-packages-ref.rst | 2 + docs/docker-stack/changelog.rst | 3 + docs/spelling_wordlist.txt | 1 + scripts/docker/install_os_dependencies.sh | 4 +- setup.py | 7 + .../endpoints/test_forward_to_fab_endpoint.py | 2 +- tests/auth/managers/fab/test_fab_auth_manager.py | 38 +---- tests/auth/managers/test_base_auth_manager.py | 48 +++++-- .../providers/amazon/aws/auth_manager/__init__.py | 17 --- .../aws/auth_manager/security_manager/__init__.py | 17 --- .../test_aws_security_manager_override.py | 56 ++++++++ .../aws/auth_manager/test_aws_auth_manager.py | 114 +++++++++++++++ .../amazon/aws/auth_manager/test_constants.py | 23 +-- .../providers/amazon/aws/auth_manager/test_user.py | 32 +++-- .../amazon/aws/auth_manager/views/__init__.py | 17 --- .../amazon/aws/auth_manager/views/test_auth.py | 156 +++++++++++++++++++++ 38 files changed, 865 insertions(+), 264 deletions(-) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 5a5fae65ca..1114c16074 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -679,9 +679,9 @@ github_enterprise, google, google_auth, grpc, hashicorp, hdfs, hive, http, imap, jenkins, kerberos, kubernetes, ldap, leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openai, openfaas, openlineage, opensearch, opsgenie, oracle, otel, pagerduty, pandas, papermill, password, pgvector, pinecone, pinot, plexus, -postgres, presto, rabbitmq, redis, s3, s3fs, salesforce, samba, segment, sendgrid, sentry, sftp, -singularity, slack, smtp, snowflake, spark, sqlite, ssh, statsd, tableau, tabular, telegram, trino, -vertica, virtualenv, weaviate, webhdfs, winrm, yandex, zendesk +postgres, presto, rabbitmq, redis, s3, s3fs, salesforce, samba, saml, segment, sendgrid, sentry, +sftp, singularity, slack, smtp, snowflake, spark, sqlite, ssh, statsd, tableau, tabular, telegram, +trino, vertica, virtualenv, weaviate, webhdfs, winrm, yandex, zendesk .. END EXTRAS HERE Provider packages diff --git a/Dockerfile b/Dockerfile index 94794b3735..b9b358d0c7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -97,7 +97,7 @@ function get_dev_apt_deps() { DEV_APT_DEPS="apt-transport-https apt-utils build-essential ca-certificates dirmngr \ freetds-bin freetds-dev git gosu graphviz graphviz-dev krb5-user ldap-utils libffi-dev libgeos-dev \ libkrb5-dev libldap2-dev libleveldb1d libleveldb-dev libsasl2-2 libsasl2-dev libsasl2-modules \ -libssl-dev locales lsb-release openssh-client pkgconf sasl2-bin \ +libssl-dev libxmlsec1 libxmlsec1-dev locales lsb-release openssh-client pkgconf sasl2-bin \ software-properties-common sqlite3 sudo unixodbc unixodbc-dev" export DEV_APT_DEPS fi @@ -123,7 +123,7 @@ function get_runtime_apt_deps() { if [[ "${RUNTIME_APT_DEPS=}" == "" ]]; then RUNTIME_APT_DEPS="apt-transport-https apt-utils ca-certificates \ curl dumb-init freetds-bin gosu krb5-user libgeos-dev \ -ldap-utils libsasl2-2 libsasl2-modules locales ${debian_version_apt_deps} \ +ldap-utils libsasl2-2 libsasl2-modules libxmlsec1 locales ${debian_version_apt_deps} \ lsb-release openssh-client python3-selinux rsync sasl2-bin sqlite3 sudo unixodbc" export RUNTIME_APT_DEPS fi diff --git a/Dockerfile.ci b/Dockerfile.ci index 81a9d424a0..c3e5e2e3a0 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -57,7 +57,7 @@ function get_dev_apt_deps() { DEV_APT_DEPS="apt-transport-https apt-utils build-essential ca-certificates dirmngr \ freetds-bin freetds-dev git gosu graphviz graphviz-dev krb5-user ldap-utils libffi-dev libgeos-dev \ libkrb5-dev libldap2-dev libleveldb1d libleveldb-dev libsasl2-2 libsasl2-dev libsasl2-modules \ -libssl-dev locales lsb-release openssh-client pkgconf sasl2-bin \ +libssl-dev libxmlsec1 libxmlsec1-dev locales lsb-release openssh-client pkgconf sasl2-bin \ software-properties-common sqlite3 sudo unixodbc unixodbc-dev" export DEV_APT_DEPS fi @@ -83,7 +83,7 @@ function get_runtime_apt_deps() { if [[ "${RUNTIME_APT_DEPS=}" == "" ]]; then RUNTIME_APT_DEPS="apt-transport-https apt-utils ca-certificates \ curl dumb-init freetds-bin gosu krb5-user libgeos-dev \ -ldap-utils libsasl2-2 libsasl2-modules locales ${debian_version_apt_deps} \ +ldap-utils libsasl2-2 libsasl2-modules libxmlsec1 locales ${debian_version_apt_deps} \ lsb-release openssh-client python3-selinux rsync sasl2-bin sqlite3 sudo unixodbc" export RUNTIME_APT_DEPS fi diff --git a/INSTALL b/INSTALL index 7742c006bc..0ac0ac33d0 100644 --- a/INSTALL +++ b/INSTALL @@ -105,9 +105,9 @@ github_enterprise, google, google_auth, grpc, hashicorp, hdfs, hive, http, imap, jenkins, kerberos, kubernetes, ldap, leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openai, openfaas, openlineage, opensearch, opsgenie, oracle, otel, pagerduty, pandas, papermill, password, pgvector, pinecone, pinot, plexus, -postgres, presto, rabbitmq, redis, s3, s3fs, salesforce, samba, segment, sendgrid, sentry, sftp, -singularity, slack, smtp, snowflake, spark, sqlite, ssh, statsd, tableau, tabular, telegram, trino, -vertica, virtualenv, weaviate, webhdfs, winrm, yandex, zendesk +postgres, presto, rabbitmq, redis, s3, s3fs, salesforce, samba, saml, segment, sendgrid, sentry, +sftp, singularity, slack, smtp, snowflake, spark, sqlite, ssh, statsd, tableau, tabular, telegram, +trino, vertica, virtualenv, weaviate, webhdfs, winrm, yandex, zendesk # END EXTRAS HERE # For installing Airflow in development environments - see CONTRIBUTING.rst diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index c158ea8481..f50e40082b 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -33,7 +33,6 @@ from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: from connexion import FlaskApi - from flask import Flask from sqlalchemy.orm import Session from airflow.auth.managers.models.base_user import BaseUser @@ -64,11 +63,12 @@ class BaseAuthManager(LoggingMixin): Class to derive in order to implement concrete auth managers. Auth managers are responsible for any user management related operation such as login, logout, authz, ... + + :param appbuilder: the flask app builder """ - def __init__(self, app: Flask, appbuilder: AirflowAppBuilder) -> None: + def __init__(self, appbuilder: AirflowAppBuilder) -> None: super().__init__() - self.app = app self.appbuilder = appbuilder @staticmethod @@ -83,21 +83,29 @@ class BaseAuthManager(LoggingMixin): """Return API endpoint(s) definition for the auth manager.""" return None - @abstractmethod def get_user_name(self) -> str: """Return the username associated to the user in session.""" + user = self.get_user() + if not user: + self.log.error("Calling 'get_user_name()' but the user is not signed in.") + raise AirflowException("The user must be signed in.") + return user.get_name() - @abstractmethod def get_user_display_name(self) -> str: """Return the user's display name associated to the user in session.""" + return self.get_user_name() @abstractmethod - def get_user(self) -> BaseUser: + def get_user(self) -> BaseUser | None: """Return the user associated to the user in session.""" - @abstractmethod def get_user_id(self) -> str: """Return the user ID associated to the user in session.""" + user = self.get_user() + if not user: + self.log.error("Calling 'get_user_id()' but the user is not signed in.") + raise AirflowException("The user must be signed in.") + return str(user.get_id()) def init(self) -> None: """ @@ -384,9 +392,13 @@ class BaseAuthManager(LoggingMixin): def get_url_logout(self) -> str: """Return the logout page url.""" - @abstractmethod def get_url_user_profile(self) -> str | None: - """Return the url to a page displaying info about the current user.""" + """ + Return the url to a page displaying info about the current user. + + By default, return None. + """ + return None @cached_property def security_manager(self) -> AirflowSecurityManagerV2: @@ -394,7 +406,7 @@ class BaseAuthManager(LoggingMixin): Return the security manager. By default, Airflow comes with the default security manager - airflow.www.security_manager.AirflowSecurityManagerV2. The auth manager might need to extend this + ``airflow.www.security_manager.AirflowSecurityManagerV2``. The auth manager might need to extend this default security manager for its own purposes. By default, return the default AirflowSecurityManagerV2. diff --git a/airflow/auth/managers/fab/fab_auth_manager.py b/airflow/auth/managers/fab/fab_auth_manager.py index 6305406d2a..5224509039 100644 --- a/airflow/auth/managers/fab/fab_auth_manager.py +++ b/airflow/auth/managers/fab/fab_auth_manager.py @@ -170,26 +170,12 @@ class FabAuthManager(BaseAuthManager): last_name = user.last_name.strip() if isinstance(user.last_name, str) else "" return f"{first_name} {last_name}".strip() - def get_user_name(self) -> str: - """ - Return the username associated to the user in session. - - For backward compatibility reasons, the username in FAB auth manager can be any of username, - email, or the database user ID. - """ - user = self.get_user() - return user.username or user.email or self.get_user_id() - def get_user(self) -> User: """Return the user associated to the user in session.""" from flask_login import current_user return current_user - def get_user_id(self) -> str: - """Return the user ID associated to the user in session.""" - return str(self.get_user().get_id()) - def init(self) -> None: """Run operations when Airflow is initializing.""" self._sync_appbuilder_roles() @@ -351,7 +337,7 @@ class FabAuthManager(BaseAuthManager): from airflow.auth.managers.fab.security_manager.override import FabAirflowSecurityManagerOverride from airflow.www.security import AirflowSecurityManager - sm_from_config = self.app.config.get("SECURITY_MANAGER_CLASS") + sm_from_config = self.appbuilder.get_app.config.get("SECURITY_MANAGER_CLASS") if sm_from_config: if not issubclass(sm_from_config, AirflowSecurityManager): raise Exception( diff --git a/airflow/auth/managers/fab/models/__init__.py b/airflow/auth/managers/fab/models/__init__.py index 28de7840ec..1fe967d8c7 100644 --- a/airflow/auth/managers/fab/models/__init__.py +++ b/airflow/auth/managers/fab/models/__init__.py @@ -228,6 +228,9 @@ class User(Model, BaseUser): def get_id(self): return self.id + def get_name(self) -> str: + return self.username or self.email or self.user_id + def get_full_name(self): return f"{self.first_name} {self.last_name}" diff --git a/airflow/auth/managers/models/base_user.py b/airflow/auth/managers/models/base_user.py index 7cb45d8389..93163ee06d 100644 --- a/airflow/auth/managers/models/base_user.py +++ b/airflow/auth/managers/models/base_user.py @@ -24,10 +24,13 @@ class BaseUser: """User model interface.""" @property - @abstractmethod def is_active(self) -> bool: - ... + return True @abstractmethod def get_id(self) -> str: ... + + @abstractmethod + def get_name(self) -> str: + ... diff --git a/airflow/auth/managers/models/base_user.py b/airflow/providers/amazon/aws/auth_manager/__init__.py similarity index 75% copy from airflow/auth/managers/models/base_user.py copy to airflow/providers/amazon/aws/auth_manager/__init__.py index 7cb45d8389..13a83393a9 100644 --- a/airflow/auth/managers/models/base_user.py +++ b/airflow/providers/amazon/aws/auth_manager/__init__.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,19 +14,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import annotations - -from abc import abstractmethod - - -class BaseUser: - """User model interface.""" - - @property - @abstractmethod - def is_active(self) -> bool: - ... - - @abstractmethod - def get_id(self) -> str: - ... diff --git a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py new file mode 100644 index 0000000000..fcf3caacb9 --- /dev/null +++ b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from functools import cached_property +from typing import TYPE_CHECKING + +from flask import session, url_for + +from airflow.configuration import conf +from airflow.exceptions import AirflowOptionalProviderFeatureException +from airflow.providers.amazon.aws.auth_manager.constants import ( + CONF_ENABLE_KEY, + CONF_SECTION_NAME, +) +from airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override import ( + AwsSecurityManagerOverride, +) + +try: + from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod +except ImportError: + raise AirflowOptionalProviderFeatureException( + "Failed to import BaseUser. This feature is only available in Airflow versions >= 2.8.0" + ) + +if TYPE_CHECKING: + from airflow.auth.managers.models.base_user import BaseUser + from airflow.auth.managers.models.resource_details import ( + AccessView, + ConfigurationDetails, + ConnectionDetails, + DagAccessEntity, + DagDetails, + DatasetDetails, + PoolDetails, + VariableDetails, + ) + from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser + from airflow.www.extensions.init_appbuilder import AirflowAppBuilder + + +class AwsAuthManager(BaseAuthManager): + """ + AWS auth manager. + + Leverages AWS services such as Amazon Identity Center and Amazon Verified Permissions to perform + authentication and authorization in Airflow. + + :param appbuilder: the flask app builder + """ + + def __init__(self, appbuilder: AirflowAppBuilder) -> None: + super().__init__(appbuilder) + enable = conf.getboolean(CONF_SECTION_NAME, CONF_ENABLE_KEY) + if not enable: + raise NotImplementedError( + "The AWS auth manager is currently being built. It is not finalized. It is not intended to be used yet." + ) + + def get_user(self) -> AwsAuthManagerUser | None: + return session["aws_user"] if self.is_logged_in() else None + + def is_logged_in(self) -> bool: + return "aws_user" in session + + def is_authorized_configuration( + self, + *, + method: ResourceMethod, + details: ConfigurationDetails | None = None, + user: BaseUser | None = None, + ) -> bool: + return self.is_logged_in() + + def is_authorized_cluster_activity(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + return self.is_logged_in() + + def is_authorized_connection( + self, + *, + method: ResourceMethod, + details: ConnectionDetails | None = None, + user: BaseUser | None = None, + ) -> bool: + return self.is_logged_in() + + def is_authorized_dag( + self, + *, + method: ResourceMethod, + access_entity: DagAccessEntity | None = None, + details: DagDetails | None = None, + user: BaseUser | None = None, + ) -> bool: + return self.is_logged_in() + + def is_authorized_dataset( + self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None + ) -> bool: + return self.is_logged_in() + + def is_authorized_pool( + self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None + ) -> bool: + return self.is_logged_in() + + def is_authorized_variable( + self, *, method: ResourceMethod, details: VariableDetails | None = None, user: BaseUser | None = None + ) -> bool: + return self.is_logged_in() + + def is_authorized_view( + self, + *, + access_view: AccessView, + user: BaseUser | None = None, + ) -> bool: + return self.is_logged_in() + + def get_url_login(self, **kwargs) -> str: + return url_for("AwsAuthManagerAuthenticationViews.login") + + def get_url_logout(self) -> str: + return url_for("AwsAuthManagerAuthenticationViews.logout") + + @cached_property + def security_manager(self) -> AwsSecurityManagerOverride: + return AwsSecurityManagerOverride(self.appbuilder) diff --git a/airflow/auth/managers/models/base_user.py b/airflow/providers/amazon/aws/auth_manager/constants.py similarity index 78% copy from airflow/auth/managers/models/base_user.py copy to airflow/providers/amazon/aws/auth_manager/constants.py index 7cb45d8389..f2f9c1da07 100644 --- a/airflow/auth/managers/models/base_user.py +++ b/airflow/providers/amazon/aws/auth_manager/constants.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,19 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import annotations - -from abc import abstractmethod +# Configuration keys +from __future__ import annotations -class BaseUser: - """User model interface.""" - - @property - @abstractmethod - def is_active(self) -> bool: - ... - - @abstractmethod - def get_id(self) -> str: - ... +CONF_SECTION_NAME = "aws_auth_manager" +CONF_SAML_METADATA_URL_KEY = "saml_metadata_url" +CONF_ENABLE_KEY = "enable" diff --git a/airflow/auth/managers/models/base_user.py b/airflow/providers/amazon/aws/auth_manager/security_manager/__init__.py similarity index 75% copy from airflow/auth/managers/models/base_user.py copy to airflow/providers/amazon/aws/auth_manager/security_manager/__init__.py index 7cb45d8389..13a83393a9 100644 --- a/airflow/auth/managers/models/base_user.py +++ b/airflow/providers/amazon/aws/auth_manager/security_manager/__init__.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,19 +14,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import annotations - -from abc import abstractmethod - - -class BaseUser: - """User model interface.""" - - @property - @abstractmethod - def is_active(self) -> bool: - ... - - @abstractmethod - def get_id(self) -> str: - ... diff --git a/airflow/auth/managers/models/base_user.py b/airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py similarity index 52% copy from airflow/auth/managers/models/base_user.py copy to airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py index 7cb45d8389..9d6432f1c6 100644 --- a/airflow/auth/managers/models/base_user.py +++ b/airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -17,17 +16,21 @@ # under the License. from __future__ import annotations -from abc import abstractmethod +from airflow.exceptions import AirflowOptionalProviderFeatureException + +try: + from airflow.www.security_manager import AirflowSecurityManagerV2 +except ImportError: + raise AirflowOptionalProviderFeatureException( + "Failed to import AirflowSecurityManagerV2. This feature is only available in Airflow versions >= 2.8.0" + ) -class BaseUser: - """User model interface.""" +class AwsSecurityManagerOverride(AirflowSecurityManagerV2): + """The security manager override specific to AWS auth manager.""" - @property - @abstractmethod - def is_active(self) -> bool: - ... + def register_views(self): + """Register views specific to AWS auth manager.""" + from airflow.providers.amazon.aws.auth_manager.views.auth import AwsAuthManagerAuthenticationViews - @abstractmethod - def get_id(self) -> str: - ... + self.appbuilder.add_view_no_menu(AwsAuthManagerAuthenticationViews()) diff --git a/airflow/providers/amazon/aws/auth_manager/user.py b/airflow/providers/amazon/aws/auth_manager/user.py new file mode 100644 index 0000000000..68699783d6 --- /dev/null +++ b/airflow/providers/amazon/aws/auth_manager/user.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.exceptions import AirflowOptionalProviderFeatureException + +try: + from airflow.auth.managers.models.base_user import BaseUser +except ImportError: + raise AirflowOptionalProviderFeatureException( + "Failed to import BaseUser. This feature is only available in Airflow versions >= 2.8.0" + ) + + +class AwsAuthManagerUser(BaseUser): + """ + User model for users managed by the AWS Auth Manager. + + :param user_id: The user ID. + :param groups: The groups the user belongs to. + :param username: The username of the user. + :param email: The email of the user. + """ + + def __init__( + self, *, user_id: str, groups: list[str], username: str | None = None, email: str | None = None + ) -> None: + self.user_id = user_id + self.groups = groups + self.username = username + self.email = email + + def get_id(self) -> str: + return self.user_id + + def get_name(self) -> str: + return self.username or self.email or self.user_id diff --git a/airflow/auth/managers/models/base_user.py b/airflow/providers/amazon/aws/auth_manager/views/__init__.py similarity index 75% copy from airflow/auth/managers/models/base_user.py copy to airflow/providers/amazon/aws/auth_manager/views/__init__.py index 7cb45d8389..13a83393a9 100644 --- a/airflow/auth/managers/models/base_user.py +++ b/airflow/providers/amazon/aws/auth_manager/views/__init__.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,19 +14,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import annotations - -from abc import abstractmethod - - -class BaseUser: - """User model interface.""" - - @property - @abstractmethod - def is_active(self) -> bool: - ... - - @abstractmethod - def get_id(self) -> str: - ... diff --git a/airflow/providers/amazon/aws/auth_manager/views/auth.py b/airflow/providers/amazon/aws/auth_manager/views/auth.py new file mode 100644 index 0000000000..eb73a3805b --- /dev/null +++ b/airflow/providers/amazon/aws/auth_manager/views/auth.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from functools import cached_property + +from flask import make_response, redirect, request, session, url_for +from flask_appbuilder import expose + +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.auth_manager.constants import CONF_SAML_METADATA_URL_KEY, CONF_SECTION_NAME +from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser +from airflow.www.app import csrf +from airflow.www.views import AirflowBaseView + +try: + from onelogin.saml2.auth import OneLogin_Saml2_Auth + from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser +except ImportError: + raise ImportError( + "AWS auth manager requires the python3-saml library but it is not installed by default. " + "Please install the python3-saml library by running: " + "pip install apache-airflow-providers-amazon[python3-saml]" + ) + + +class AwsAuthManagerAuthenticationViews(AirflowBaseView): + """ + Views specific to AWS auth manager authentication mechanism. + + Some code below is inspired from + https://github.com/SAML-Toolkits/python3-saml/blob/6988bdab7a203abfe8dc264992f7e350c67aef3d/demo-flask/index.py + """ + + @cached_property + def idp_data(self) -> dict: + saml_metadata_url = conf.get_mandatory_value(CONF_SECTION_NAME, CONF_SAML_METADATA_URL_KEY) + return OneLogin_Saml2_IdPMetadataParser.parse_remote(saml_metadata_url) + + @expose("/login") + def login(self): + """Start login process.""" + saml_auth = self._init_saml_auth() + return redirect(saml_auth.login()) + + @expose("/logout") + def logout(self): + """Start logout process.""" + session.clear() + saml_auth = self._init_saml_auth() + + return redirect(saml_auth.logout()) + + @csrf.exempt + @expose("/login_callback", methods=("GET", "POST")) + def login_callback(self): + """ + Callback where the user is redirected to after successful login. + + CSRF protection needs to be disabled otherwise the callback won't work. + """ + saml_auth = self._init_saml_auth() + saml_auth.process_response() + errors = saml_auth.get_errors() + is_authenticated = saml_auth.is_authenticated() + if not is_authenticated: + error_reason = saml_auth.get_last_error_reason() + logging.error("Failed to authenticate") + logging.error("Errors: %s", errors) + logging.error("Error reason: %s", error_reason) + raise AirflowException(f"Failed to authenticate: {error_reason}") + + attributes = saml_auth.get_attributes() + user = AwsAuthManagerUser( + user_id=attributes["id"][0], + groups=attributes["groups"], + username=saml_auth.get_nameid(), + email=attributes["email"][0], + ) + session["aws_user"] = user + + return redirect(url_for("Airflow.index")) + + @csrf.exempt + @expose("/logout_callback", methods=("GET", "POST")) + def logout_callback(self): + raise NotImplementedError("AWS Identity center does not support SLO (Single Logout Service)") + + @expose("/login_metadata") + def login_metadata(self): + saml_auth = self._init_saml_auth() + settings = saml_auth.get_settings() + metadata = settings.get_sp_metadata() + errors = settings.validate_metadata(metadata) + + if len(errors) == 0: + resp = make_response(metadata, 200) + resp.headers["Content-Type"] = "text/xml" + else: + resp = make_response(", ".join(errors), 500) + return resp + + @staticmethod + def _prepare_flask_request() -> dict: + return { + "https": "on" if request.scheme == "https" else "off", + "http_host": request.host, + "script_name": request.path, + "get_data": request.args.copy(), + "post_data": request.form.copy(), + } + + def _init_saml_auth(self) -> OneLogin_Saml2_Auth: + request_data = self._prepare_flask_request() + base_url = conf.get(section="webserver", key="base_url") + settings = { + # We want to keep this flag on in case of errors. + # It provides an error reasons, if turned off, it does not + "debug": True, + "sp": { + "entityId": f"{base_url}/login_metadata", + "assertionConsumerService": { + "url": f"{base_url}/login_callback", + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST", + }, + "singleLogoutService": { + "url": f"{base_url}/logout_callback", + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect", + }, + }, + } + merged_settings = OneLogin_Saml2_IdPMetadataParser.merge_settings(settings, self.idp_data) + return OneLogin_Saml2_Auth(request_data, merged_settings) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 8c0c13fbc6..6399961543 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -746,6 +746,9 @@ additional-extras: - name: s3fs dependencies: - s3fs>=2023.9.2 + - name: python3-saml + dependencies: + - python3-saml>=1.16.0 config: aws: @@ -897,6 +900,28 @@ config: type: boolean example: "True" default: "True" + aws_auth_manager: + description: | + This section only applies if you are using the AwsAuthManager. In other words, if you set + ``[core] auth_manager = airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager`` in + Airflow's configuration. + options: + enable: + description: | + AWS auth manager is not ready to be used. Turn on this flag to use it anyway. + Do that at your own risk since the AWS auth manager is not in an usable state. + version_added: 8.12.0 + type: boolean + example: "True" + default: "False" + saml_metadata_url: + description: | + SAML metadata XML file provided by AWS Identity Center. + This URL can be found in the AWS Identity Center console. Required. + version_added: 8.12.0 + type: string + example: "https://portal.sso.<region>.amazonaws.com/saml/metadata/XXXXXXXXXX" + default: ~ executors: - airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor diff --git a/airflow/www/extensions/init_appbuilder.py b/airflow/www/extensions/init_appbuilder.py index ce87fdc76c..ddb44200f9 100644 --- a/airflow/www/extensions/init_appbuilder.py +++ b/airflow/www/extensions/init_appbuilder.py @@ -214,7 +214,7 @@ class AirflowAppBuilder: self._addon_managers = app.config["ADDON_MANAGERS"] self.session = session - auth_manager = init_auth_manager(app, self) + auth_manager = init_auth_manager(self) self.sm = auth_manager.security_manager self.bm = BabelManager(self) self._add_global_static() diff --git a/airflow/www/extensions/init_auth_manager.py b/airflow/www/extensions/init_auth_manager.py index 63ec043c10..d29fb6944e 100644 --- a/airflow/www/extensions/init_auth_manager.py +++ b/airflow/www/extensions/init_auth_manager.py @@ -22,8 +22,6 @@ from airflow.configuration import conf from airflow.exceptions import AirflowConfigException if TYPE_CHECKING: - from flask import Flask - from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.www.extensions.init_appbuilder import AirflowAppBuilder @@ -47,7 +45,7 @@ def get_auth_manager_cls() -> type[BaseAuthManager]: return auth_manager_cls -def init_auth_manager(app: Flask, appbuilder: AirflowAppBuilder) -> BaseAuthManager: +def init_auth_manager(appbuilder: AirflowAppBuilder) -> BaseAuthManager: """ Initialize the auth manager. @@ -55,7 +53,7 @@ def init_auth_manager(app: Flask, appbuilder: AirflowAppBuilder) -> BaseAuthMana """ global auth_manager auth_manager_cls = get_auth_manager_cls() - auth_manager = auth_manager_cls(app, appbuilder) + auth_manager = auth_manager_cls(appbuilder) return auth_manager diff --git a/airflow/www/extensions/init_security.py b/airflow/www/extensions/init_security.py index 21b71e8f84..a7739e3231 100644 --- a/airflow/www/extensions/init_security.py +++ b/airflow/www/extensions/init_security.py @@ -19,7 +19,7 @@ from __future__ import annotations import logging from importlib import import_module -from flask import g, redirect, request +from flask import redirect, request from airflow.configuration import conf from airflow.exceptions import AirflowConfigException, AirflowException @@ -72,5 +72,5 @@ def init_check_user_active(app): url_logout = get_auth_manager().get_url_logout() if request.path == url_logout: return - if get_auth_manager().is_logged_in() and not g.user.is_active: + if get_auth_manager().is_logged_in() and not get_auth_manager().get_user().is_active: return redirect(url_logout) diff --git a/airflow/www/security_manager.py b/airflow/www/security_manager.py index 35e2a63e64..78d8ef00fe 100644 --- a/airflow/www/security_manager.py +++ b/airflow/www/security_manager.py @@ -112,6 +112,10 @@ class AirflowSecurityManagerV2(LoggingMixin): limiter.init_app(self.appbuilder.get_app) return limiter + def register_views(self): + """Allow auth managers to register their own views. By default, do nothing.""" + pass + def has_access( self, action_name: str, resource_name: str, user=None, resource_pk: str | None = None ) -> bool: @@ -168,18 +172,6 @@ class AirflowSecurityManagerV2(LoggingMixin): cost=limit.cost, )(baseview.blueprint) - def add_permissions_view(self, base_action_names, resource_name): - raise NotImplementedError("Sync FAB permissions is only available with the FAB auth manager") - - def add_permissions_menu(self, resource_name): - raise NotImplementedError("Sync FAB permissions is only available with the FAB auth manager") - - def get_action(self, name: str) -> Action: - raise NotImplementedError("Only available when FAB auth manager is used") - - def get_resource(self, name: str) -> Resource: - raise NotImplementedError("Only available when FAB auth manager is used") - @cached_property @provide_session def _auth_manager_is_authorized_map( @@ -355,3 +347,20 @@ class AirflowSecurityManagerV2(LoggingMixin): return lambda action, resource_pk, user: any( self._get_auth_manager_is_authorized_method(fab_resource_name=item) for item in items ) + + """ + The following methods are specific to FAB auth manager. They still need to be "present" in the main + security manager class, but they do nothing. + """ + + def get_action(self, name: str) -> Action: + raise NotImplementedError("Only available if FAB auth manager is used") + + def get_resource(self, name: str) -> Resource: + raise NotImplementedError("Only available if FAB auth manager is used") + + def add_permissions_view(self, base_action_names, resource_name): + pass + + def add_permissions_menu(self, resource_name): + pass diff --git a/airflow/www/utils.py b/airflow/www/utils.py index 5fe8a7ccb7..16d0950989 100644 --- a/airflow/www/utils.py +++ b/airflow/www/utils.py @@ -940,10 +940,7 @@ class UIAlert: if self.roles: current_user = get_auth_manager().get_user() if current_user is not None: - if not hasattr(current_user, "roles"): - # If the user does not contain "roles" in its model, return False - return False - user_roles = {r.name for r in current_user.roles} + user_roles = {r.name for r in getattr(current_user, "roles", [])} elif "AUTH_ROLE_PUBLIC" in appbuilder.get_app.config: # If the current_user is anonymous, assign AUTH_ROLE_PUBLIC role (if it exists) to them user_roles = {appbuilder.get_app.config["AUTH_ROLE_PUBLIC"]} diff --git a/docs/apache-airflow/extra-packages-ref.rst b/docs/apache-airflow/extra-packages-ref.rst index 3060d7c2db..32e6d892bc 100644 --- a/docs/apache-airflow/extra-packages-ref.rst +++ b/docs/apache-airflow/extra-packages-ref.rst @@ -70,6 +70,8 @@ python dependencies for the provided package. +---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+ | s3fs | ``pip install 'apache-airflow[s3fs]'`` | Support for S3 as Airflow FS | +---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+ +| saml | ``pip install 'apache-airflow[saml]'`` | Support for SAML authentication in Airflow | ++---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+ | statsd | ``pip install 'apache-airflow[statsd]'`` | Needed by StatsD metrics | +---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+ | virtualenv | ``pip install 'apache-airflow[virtualenv]'`` | Running python tasks in local virtualenv | diff --git a/docs/docker-stack/changelog.rst b/docs/docker-stack/changelog.rst index 0f752cb189..ad71d49221 100644 --- a/docs/docker-stack/changelog.rst +++ b/docs/docker-stack/changelog.rst @@ -39,6 +39,9 @@ Airflow 2.8 * 2.8.0 + * Add ``libxmlsec1`` and ``libxmlsec1-dev`` libraries to dev PROD image and ``libxmlsec1`` library to runtime PROD + image as it is required by ``python3-saml`` library. + * The image is based on ``Debian Bookworm`` in 2.8.0 rather than ``Debian Bullseye``. This might cause some problems when building custom images. You are advised to make sure your system level dependencies are working with ``Debian Bookworm``. While all reference images of Airflow 2.8.0 are built on ``Debian Bookworm``, diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 02380d9884..56198798b4 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1356,6 +1356,7 @@ Sagemaker sagemaker salesforce samesite +saml sanitization sas Sasl diff --git a/scripts/docker/install_os_dependencies.sh b/scripts/docker/install_os_dependencies.sh index cb6cb692fb..318add1435 100644 --- a/scripts/docker/install_os_dependencies.sh +++ b/scripts/docker/install_os_dependencies.sh @@ -39,7 +39,7 @@ function get_dev_apt_deps() { DEV_APT_DEPS="apt-transport-https apt-utils build-essential ca-certificates dirmngr \ freetds-bin freetds-dev git gosu graphviz graphviz-dev krb5-user ldap-utils libffi-dev libgeos-dev \ libkrb5-dev libldap2-dev libleveldb1d libleveldb-dev libsasl2-2 libsasl2-dev libsasl2-modules \ -libssl-dev locales lsb-release openssh-client pkgconf sasl2-bin \ +libssl-dev libxmlsec1 libxmlsec1-dev locales lsb-release openssh-client pkgconf sasl2-bin \ software-properties-common sqlite3 sudo unixodbc unixodbc-dev" export DEV_APT_DEPS fi @@ -65,7 +65,7 @@ function get_runtime_apt_deps() { if [[ "${RUNTIME_APT_DEPS=}" == "" ]]; then RUNTIME_APT_DEPS="apt-transport-https apt-utils ca-certificates \ curl dumb-init freetds-bin gosu krb5-user libgeos-dev \ -ldap-utils libsasl2-2 libsasl2-modules locales ${debian_version_apt_deps} \ +ldap-utils libsasl2-2 libsasl2-modules libxmlsec1 locales ${debian_version_apt_deps} \ lsb-release openssh-client python3-selinux rsync sasl2-bin sqlite3 sudo unixodbc" export RUNTIME_APT_DEPS fi diff --git a/setup.py b/setup.py index 46a7e68c5a..bc624e21fa 100644 --- a/setup.py +++ b/setup.py @@ -527,6 +527,11 @@ s3fs = [ "s3fs>=2023.9.2", ] +saml = [ + # This is required for support of SAML which might be used by some providers (e.g. Amazon) + "python3-saml>=1.16.0", +] + def get_provider_dependencies(provider_name: str) -> list[str]: if provider_name not in PROVIDER_DEPENDENCIES: @@ -554,6 +559,7 @@ devel = get_unique_dependency_list( pandas, password, s3fs, + saml, ] ) @@ -600,6 +606,7 @@ CORE_EXTRAS_DEPENDENCIES: dict[str, list[str]] = { "password": password, "rabbitmq": rabbitmq, "s3fs": s3fs, + "saml": saml, "sentry": sentry, "statsd": statsd, "virtualenv": virtualenv, diff --git a/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py b/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py index 5560a4d877..4b8294eb6d 100644 --- a/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py +++ b/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py @@ -130,7 +130,7 @@ class TestFABforwarding: class TestFABRoleForwarding(TestFABforwarding): @mock.patch("airflow.api_connexion.endpoints.forward_to_fab_endpoint.get_auth_manager") def test_raises_400_if_manager_is_not_fab(self, mock_get_auth_manager): - mock_get_auth_manager.return_value = BaseAuthManager(self.app, self.app.appbuilder) + mock_get_auth_manager.return_value = BaseAuthManager(self.app.appbuilder) response = self.client.get("api/v1/roles", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 400 assert ( diff --git a/tests/auth/managers/fab/test_fab_auth_manager.py b/tests/auth/managers/fab/test_fab_auth_manager.py index 934293606a..74d20ab13b 100644 --- a/tests/auth/managers/fab/test_fab_auth_manager.py +++ b/tests/auth/managers/fab/test_fab_auth_manager.py @@ -61,43 +61,18 @@ IS_AUTHORIZED_METHODS_SIMPLE = { @pytest.fixture def auth_manager(): - return FabAuthManager(None, None) + return FabAuthManager(None) @pytest.fixture def auth_manager_with_appbuilder(): flask_app = Flask(__name__) appbuilder = init_appbuilder(flask_app) - return FabAuthManager(flask_app, appbuilder) + return FabAuthManager(appbuilder) @pytest.mark.db_test class TestFabAuthManager: - @pytest.mark.parametrize( - "id,first_name,last_name,username,email,expected", - [ - (1, "First", "Last", None, None, "1"), - (1, None, None, None, None, "1"), - (1, "First", "Last", "user", None, "user"), - (1, "First", "Last", "user", "email", "user"), - (1, None, None, None, "email", "email"), - (1, "First", "Last", None, "email", "email"), - ], - ) - @mock.patch.object(FabAuthManager, "get_user") - def test_get_user_name( - self, mock_get_user, id, first_name, last_name, username, email, expected, auth_manager - ): - user = User() - user.id = id - user.first_name = first_name - user.last_name = last_name - user.username = username - user.email = email - mock_get_user.return_value = user - - assert auth_manager.get_user_name() == expected - @pytest.mark.parametrize( "id,first_name,last_name,username,email,expected", [ @@ -130,15 +105,6 @@ class TestFabAuthManager: assert auth_manager.get_user() == user - @mock.patch.object(FabAuthManager, "get_user") - def test_get_user_id(self, mock_get_user, auth_manager): - user_id = "test" - user = Mock() - user.get_id.return_value = user_id - mock_get_user.return_value = user - - assert auth_manager.get_user_id() == user_id - @mock.patch.object(FabAuthManager, "get_user") def test_is_logged_in(self, mock_get_user, auth_manager): user = Mock() diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index 9b1b5659f4..832ae50d2a 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -45,18 +45,9 @@ if TYPE_CHECKING: class EmptyAuthManager(BaseAuthManager): - def get_user_display_name(self) -> str: - raise NotImplementedError() - - def get_user_name(self) -> str: - raise NotImplementedError() - def get_user(self) -> BaseUser: raise NotImplementedError() - def get_user_id(self) -> str: - raise NotImplementedError() - def is_authorized_configuration( self, *, @@ -115,20 +106,17 @@ class EmptyAuthManager(BaseAuthManager): def get_url_logout(self) -> str: raise NotImplementedError() - def get_url_user_profile(self) -> str | None: - raise NotImplementedError() - @pytest.fixture def auth_manager(): - return EmptyAuthManager(None, None) + return EmptyAuthManager(None) @pytest.fixture def auth_manager_with_appbuilder(): flask_app = Flask(__name__) appbuilder = init_appbuilder(flask_app) - return EmptyAuthManager(flask_app, appbuilder) + return EmptyAuthManager(appbuilder) class TestBaseAuthManager: @@ -138,7 +126,37 @@ class TestBaseAuthManager: def test_get_api_endpoints_return_none(self, auth_manager): assert auth_manager.get_api_endpoints() is None - def test_is_authorized_custom_view_throws_exception(self, auth_manager): + def test_get_user_name(self, auth_manager): + user = Mock() + user.get_name.return_value = "test_username" + auth_manager.get_user = MagicMock(return_value=user) + result = auth_manager.get_user_name() + assert result == "test_username" + + def test_get_user_name_when_not_logged_in(self, auth_manager): + auth_manager.get_user = MagicMock(return_value=None) + with pytest.raises(AirflowException): + auth_manager.get_user_name() + + def test_get_user_display_name_return_user_name(self, auth_manager): + auth_manager.get_user_name = MagicMock(return_value="test_user") + assert auth_manager.get_user_display_name() == "test_user" + + def test_get_user_id_return_user_id(self, auth_manager): + user = Mock() + user.get_id = MagicMock(return_value="test_user") + auth_manager.get_user = MagicMock(return_value=user) + assert auth_manager.get_user_id() == "test_user" + + def test_get_user_id_raise_exception_when_no_user(self, auth_manager): + auth_manager.get_user = MagicMock(return_value=None) + with pytest.raises(AirflowException, match="The user must be signed in."): + auth_manager.get_user_id() + + def test_get_url_user_profile_return_none(self, auth_manager): + assert auth_manager.get_url_user_profile() is None + + def test_is_authorized_custom_view_raise_exception(self, auth_manager): with pytest.raises(AirflowException, match="The resource `.*` does not exist in the environment."): auth_manager.is_authorized_custom_view( fab_action_name=permissions.ACTION_CAN_READ, diff --git a/airflow/auth/managers/models/base_user.py b/tests/providers/amazon/aws/auth_manager/__init__.py similarity index 75% copy from airflow/auth/managers/models/base_user.py copy to tests/providers/amazon/aws/auth_manager/__init__.py index 7cb45d8389..13a83393a9 100644 --- a/airflow/auth/managers/models/base_user.py +++ b/tests/providers/amazon/aws/auth_manager/__init__.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,19 +14,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import annotations - -from abc import abstractmethod - - -class BaseUser: - """User model interface.""" - - @property - @abstractmethod - def is_active(self) -> bool: - ... - - @abstractmethod - def get_id(self) -> str: - ... diff --git a/airflow/auth/managers/models/base_user.py b/tests/providers/amazon/aws/auth_manager/security_manager/__init__.py similarity index 75% copy from airflow/auth/managers/models/base_user.py copy to tests/providers/amazon/aws/auth_manager/security_manager/__init__.py index 7cb45d8389..13a83393a9 100644 --- a/airflow/auth/managers/models/base_user.py +++ b/tests/providers/amazon/aws/auth_manager/security_manager/__init__.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,19 +14,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import annotations - -from abc import abstractmethod - - -class BaseUser: - """User model interface.""" - - @property - @abstractmethod - def is_active(self) -> bool: - ... - - @abstractmethod - def get_id(self) -> str: - ... diff --git a/tests/providers/amazon/aws/auth_manager/security_manager/test_aws_security_manager_override.py b/tests/providers/amazon/aws/auth_manager/security_manager/test_aws_security_manager_override.py new file mode 100644 index 0000000000..cfceb3c96c --- /dev/null +++ b/tests/providers/amazon/aws/auth_manager/security_manager/test_aws_security_manager_override.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import Mock, patch + +import pytest +from flask import Flask + +from airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override import ( + AwsSecurityManagerOverride, +) +from airflow.www.extensions.init_appbuilder import init_appbuilder + + +@pytest.fixture +def appbuilder(): + flask_app = Flask(__name__) + return init_appbuilder(flask_app) + + +@pytest.fixture +def override(appbuilder): + return AwsSecurityManagerOverride(appbuilder) + + +@pytest.mark.db_test +class TestAwsSecurityManagerOverride: + @patch( + "airflow.providers.amazon.aws.auth_manager.views.auth.conf.get_mandatory_value", return_value="test" + ) + def test_register_views(self, mock_get_mandatory_value, override, appbuilder): + pytest.importorskip("onelogin") + from airflow.providers.amazon.aws.auth_manager.views.auth import AwsAuthManagerAuthenticationViews + + with patch.object(AwsAuthManagerAuthenticationViews, "idp_data"): + appbuilder.add_view_no_menu = Mock() + override.register_views() + appbuilder.add_view_no_menu.assert_called_once() + assert isinstance( + appbuilder.add_view_no_menu.call_args.args[0], AwsAuthManagerAuthenticationViews + ) diff --git a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py new file mode 100644 index 0000000000..56f23eb36a --- /dev/null +++ b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import patch + +import pytest +from flask import Flask, session + +from airflow.providers.amazon.aws.auth_manager.aws_auth_manager import AwsAuthManager +from airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override import ( + AwsSecurityManagerOverride, +) +from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser +from airflow.www.extensions.init_appbuilder import init_appbuilder +from tests.test_utils.config import conf_vars + + +@pytest.fixture +def auth_manager(): + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager", + ("aws_auth_manager", "enable"): "True", + } + ): + return AwsAuthManager(None) + + +@pytest.fixture +def auth_manager_with_appbuilder(): + flask_app = Flask(__name__) + appbuilder = init_appbuilder(flask_app) + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager", + ("aws_auth_manager", "enable"): "True", + } + ): + return AwsAuthManager(appbuilder) + + +@pytest.fixture +def test_user(): + return AwsAuthManagerUser(user_id="test_user_id", groups=[], username="test_username") + + +class TestAwsAuthManager: + @pytest.mark.db_test + @patch.object(AwsAuthManager, "is_logged_in") + def test_get_user(self, mock_is_logged_in, auth_manager, app, test_user): + mock_is_logged_in.return_value = True + + with app.test_request_context(): + session["aws_user"] = test_user + result = auth_manager.get_user() + + assert result == test_user + + @patch.object(AwsAuthManager, "is_logged_in") + def test_get_user_return_none_when_not_logged_in(self, mock_is_logged_in, auth_manager): + mock_is_logged_in.return_value = False + result = auth_manager.get_user() + + assert result is None + + @pytest.mark.db_test + def test_is_logged_in(self, auth_manager, app, test_user): + with app.test_request_context(): + session["aws_user"] = test_user + result = auth_manager.is_logged_in() + + assert result + + @pytest.mark.db_test + def test_is_logged_in_return_false_when_no_user_in_session(self, auth_manager, app, test_user): + with app.test_request_context(): + result = auth_manager.is_logged_in() + + assert result is False + + @patch("airflow.providers.amazon.aws.auth_manager.aws_auth_manager.url_for") + def test_get_url_login(self, mock_url_for, auth_manager): + auth_manager.get_url_login() + mock_url_for.assert_called_once_with("AwsAuthManagerAuthenticationViews.login") + + @patch("airflow.providers.amazon.aws.auth_manager.aws_auth_manager.url_for") + def test_get_url_logout(self, mock_url_for, auth_manager): + auth_manager.get_url_logout() + mock_url_for.assert_called_once_with("AwsAuthManagerAuthenticationViews.logout") + + @pytest.mark.db_test + def test_security_manager_return_default_security_manager(self, auth_manager_with_appbuilder): + assert isinstance(auth_manager_with_appbuilder.security_manager, AwsSecurityManagerOverride) diff --git a/airflow/auth/managers/models/base_user.py b/tests/providers/amazon/aws/auth_manager/test_constants.py similarity index 64% copy from airflow/auth/managers/models/base_user.py copy to tests/providers/amazon/aws/auth_manager/test_constants.py index 7cb45d8389..3b273d9672 100644 --- a/airflow/auth/managers/models/base_user.py +++ b/tests/providers/amazon/aws/auth_manager/test_constants.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -17,17 +16,19 @@ # under the License. from __future__ import annotations -from abc import abstractmethod +from airflow.providers.amazon.aws.auth_manager.constants import ( + CONF_ENABLE_KEY, + CONF_SAML_METADATA_URL_KEY, + CONF_SECTION_NAME, +) -class BaseUser: - """User model interface.""" +class TestAwsAuthManagerConstants: + def test_conf_section_name(self): + assert CONF_SECTION_NAME == "aws_auth_manager" - @property - @abstractmethod - def is_active(self) -> bool: - ... + def test_conf_saml_metadata_url_key(self): + assert CONF_SAML_METADATA_URL_KEY == "saml_metadata_url" - @abstractmethod - def get_id(self) -> str: - ... + def test_conf_enable_key(self): + assert CONF_ENABLE_KEY == "enable" diff --git a/airflow/auth/managers/models/base_user.py b/tests/providers/amazon/aws/auth_manager/test_user.py similarity index 55% copy from airflow/auth/managers/models/base_user.py copy to tests/providers/amazon/aws/auth_manager/test_user.py index 7cb45d8389..b518c2a831 100644 --- a/airflow/auth/managers/models/base_user.py +++ b/tests/providers/amazon/aws/auth_manager/test_user.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -17,17 +16,28 @@ # under the License. from __future__ import annotations -from abc import abstractmethod +import pytest + +from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser + + +@pytest.fixture() +def user(): + return AwsAuthManagerUser(user_id="user_id", groups=[]) + +class TestAwsAuthManagerUser: + def test_get_id(self, user): + assert user.get_id() == "user_id" -class BaseUser: - """User model interface.""" + def test_get_name_with_username(self, user): + user.username = "username" + assert user.get_name() == "username" - @property - @abstractmethod - def is_active(self) -> bool: - ... + def test_get_name_with_email(self, user): + user.email = "email" + assert user.get_name() == "email" - @abstractmethod - def get_id(self) -> str: - ... + def test_get_name_with_user_id(self, user): + user.user_id = "user_id" + assert user.get_name() == "user_id" diff --git a/airflow/auth/managers/models/base_user.py b/tests/providers/amazon/aws/auth_manager/views/__init__.py similarity index 75% copy from airflow/auth/managers/models/base_user.py copy to tests/providers/amazon/aws/auth_manager/views/__init__.py index 7cb45d8389..13a83393a9 100644 --- a/airflow/auth/managers/models/base_user.py +++ b/tests/providers/amazon/aws/auth_manager/views/__init__.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,19 +14,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import annotations - -from abc import abstractmethod - - -class BaseUser: - """User model interface.""" - - @property - @abstractmethod - def is_active(self) -> bool: - ... - - @abstractmethod - def get_id(self) -> str: - ... diff --git a/tests/providers/amazon/aws/auth_manager/views/test_auth.py b/tests/providers/amazon/aws/auth_manager/views/test_auth.py new file mode 100644 index 0000000000..84fdbc373c --- /dev/null +++ b/tests/providers/amazon/aws/auth_manager/views/test_auth.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import Mock, patch + +import pytest +from flask import session, url_for + +from airflow.exceptions import AirflowException +from airflow.www import app as application +from tests.test_utils.config import conf_vars + +pytest.importorskip("onelogin") + +SAML_METADATA_URL = "/saml/metadata" +SAML_METADATA_PARSED = { + "idp": { + "entityId": "https://portal.sso.us-east-1.amazonaws.com/saml/assertion/<assertion>", + "singleSignOnService": { + "url": "https://portal.sso.us-east-1.amazonaws.com/saml/assertion/<assertion>", + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect", + }, + "singleLogoutService": { + "url": "https://portal.sso.us-east-1.amazonaws.com/saml/logout/<assertion>", + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect", + }, + "x509cert": "<cert>", + }, + "security": {"authnRequestsSigned": False}, + "sp": {"NameIDFormat": "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"}, +} + + +@pytest.fixture() +def aws_app(): + def factory(): + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager", + ("aws_auth_manager", "enable"): "True", + ("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL, + } + ): + with patch( + "airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser" + ) as mock_parser: + mock_parser.parse_remote.return_value = SAML_METADATA_PARSED + return application.create_app(testing=True) + + return factory() + + +@pytest.mark.db_test +class TestAwsAuthManagerAuthenticationViews: + def test_login_redirect_to_identity_center(self, aws_app): + with aws_app.test_client() as client: + response = client.get("/login") + assert response.status_code == 302 + assert response.location.startswith("https://portal.sso.us-east-1.amazonaws.com/saml/assertion/") + + def test_logout_redirect_to_identity_center(self, aws_app): + with aws_app.test_client() as client: + response = client.get("/logout") + assert response.status_code == 302 + assert response.location.startswith("https://portal.sso.us-east-1.amazonaws.com/saml/logout/") + + def test_login_metadata_return_xml_file(self, aws_app): + with aws_app.test_client() as client: + response = client.get("/login_metadata") + assert response.status_code == 200 + assert response.headers["Content-Type"] == "text/xml" + + def test_login_callback_set_user_in_session(self): + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager", + ("aws_auth_manager", "enable"): "True", + ("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL, + } + ): + with patch( + "airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser" + ) as mock_parser, patch( + "airflow.providers.amazon.aws.auth_manager.views.auth.AwsAuthManagerAuthenticationViews._init_saml_auth" + ) as mock_init_saml_auth: + mock_parser.parse_remote.return_value = SAML_METADATA_PARSED + + auth = Mock() + auth.is_authenticated.return_value = True + auth.get_nameid.return_value = "user_id" + auth.get_attributes.return_value = { + "id": ["1"], + "groups": ["group_1", "group_2"], + "email": ["email"], + } + mock_init_saml_auth.return_value = auth + app = application.create_app(testing=True) + with app.test_client() as client: + response = client.get("/login_callback") + assert response.status_code == 302 + assert response.location == url_for("Airflow.index") + assert session["aws_user"] is not None + assert session["aws_user"].get_id() == "1" + assert session["aws_user"].get_name() == "user_id" + + def test_login_callback_raise_exception_if_errors(self): + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager", + ("aws_auth_manager", "enable"): "True", + ("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL, + } + ): + with patch( + "airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser" + ) as mock_parser, patch( + "airflow.providers.amazon.aws.auth_manager.views.auth.AwsAuthManagerAuthenticationViews._init_saml_auth" + ) as mock_init_saml_auth: + mock_parser.parse_remote.return_value = SAML_METADATA_PARSED + + auth = Mock() + auth.is_authenticated.return_value = False + mock_init_saml_auth.return_value = auth + app = application.create_app(testing=True) + with app.test_client() as client: + with pytest.raises(AirflowException): + client.get("/login_callback") + + def test_logout_callback_raise_not_implemented_error(self, aws_app): + with aws_app.test_client() as client: + with pytest.raises(NotImplementedError): + client.get("/logout_callback")