This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 379b7c09d1 Implement login and logout in AWS auth manager (#35488)
379b7c09d1 is described below
commit 379b7c09d17772b9afb16998bb9cd8c1622df960
Author: Vincent <[email protected]>
AuthorDate: Tue Nov 21 12:35:10 2023 -0500
Implement login and logout in AWS auth manager (#35488)
---
CONTRIBUTING.rst | 6 +-
Dockerfile | 4 +-
Dockerfile.ci | 4 +-
INSTALL | 6 +-
airflow/auth/managers/base_auth_manager.py | 32 +++--
airflow/auth/managers/fab/fab_auth_manager.py | 16 +--
airflow/auth/managers/fab/models/__init__.py | 3 +
airflow/auth/managers/models/base_user.py | 7 +-
.../amazon/aws/auth_manager/__init__.py} | 17 ---
.../amazon/aws/auth_manager/aws_auth_manager.py | 143 +++++++++++++++++++
.../amazon/aws/auth_manager/constants.py} | 20 +--
.../aws/auth_manager/security_manager/__init__.py} | 17 ---
.../aws_security_manager_override.py} | 25 ++--
airflow/providers/amazon/aws/auth_manager/user.py | 51 +++++++
.../amazon/aws/auth_manager/views/__init__.py} | 17 ---
.../amazon/aws/auth_manager/views/auth.py | 149 ++++++++++++++++++++
airflow/providers/amazon/provider.yaml | 25 ++++
airflow/www/extensions/init_appbuilder.py | 2 +-
airflow/www/extensions/init_auth_manager.py | 6 +-
airflow/www/extensions/init_security.py | 4 +-
airflow/www/security_manager.py | 33 +++--
airflow/www/utils.py | 5 +-
docs/apache-airflow/extra-packages-ref.rst | 2 +
docs/docker-stack/changelog.rst | 3 +
docs/spelling_wordlist.txt | 1 +
scripts/docker/install_os_dependencies.sh | 4 +-
setup.py | 7 +
.../endpoints/test_forward_to_fab_endpoint.py | 2 +-
tests/auth/managers/fab/test_fab_auth_manager.py | 38 +----
tests/auth/managers/test_base_auth_manager.py | 48 +++++--
.../providers/amazon/aws/auth_manager/__init__.py | 17 ---
.../aws/auth_manager/security_manager/__init__.py | 17 ---
.../test_aws_security_manager_override.py | 56 ++++++++
.../aws/auth_manager/test_aws_auth_manager.py | 114 +++++++++++++++
.../amazon/aws/auth_manager/test_constants.py | 23 +--
.../providers/amazon/aws/auth_manager/test_user.py | 32 +++--
.../amazon/aws/auth_manager/views/__init__.py | 17 ---
.../amazon/aws/auth_manager/views/test_auth.py | 156 +++++++++++++++++++++
38 files changed, 865 insertions(+), 264 deletions(-)
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 5a5fae65ca..1114c16074 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -679,9 +679,9 @@ github_enterprise, google, google_auth, grpc, hashicorp,
hdfs, hive, http, imap,
jenkins, kerberos, kubernetes, ldap, leveldb, microsoft.azure,
microsoft.mssql, microsoft.psrp,
microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openai, openfaas,
openlineage, opensearch,
opsgenie, oracle, otel, pagerduty, pandas, papermill, password, pgvector,
pinecone, pinot, plexus,
-postgres, presto, rabbitmq, redis, s3, s3fs, salesforce, samba, segment,
sendgrid, sentry, sftp,
-singularity, slack, smtp, snowflake, spark, sqlite, ssh, statsd, tableau,
tabular, telegram, trino,
-vertica, virtualenv, weaviate, webhdfs, winrm, yandex, zendesk
+postgres, presto, rabbitmq, redis, s3, s3fs, salesforce, samba, saml, segment,
sendgrid, sentry,
+sftp, singularity, slack, smtp, snowflake, spark, sqlite, ssh, statsd,
tableau, tabular, telegram,
+trino, vertica, virtualenv, weaviate, webhdfs, winrm, yandex, zendesk
.. END EXTRAS HERE
Provider packages
diff --git a/Dockerfile b/Dockerfile
index 94794b3735..b9b358d0c7 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -97,7 +97,7 @@ function get_dev_apt_deps() {
DEV_APT_DEPS="apt-transport-https apt-utils build-essential
ca-certificates dirmngr \
freetds-bin freetds-dev git gosu graphviz graphviz-dev krb5-user ldap-utils
libffi-dev libgeos-dev \
libkrb5-dev libldap2-dev libleveldb1d libleveldb-dev libsasl2-2 libsasl2-dev
libsasl2-modules \
-libssl-dev locales lsb-release openssh-client pkgconf sasl2-bin \
+libssl-dev libxmlsec1 libxmlsec1-dev locales lsb-release openssh-client
pkgconf sasl2-bin \
software-properties-common sqlite3 sudo unixodbc unixodbc-dev"
export DEV_APT_DEPS
fi
@@ -123,7 +123,7 @@ function get_runtime_apt_deps() {
if [[ "${RUNTIME_APT_DEPS=}" == "" ]]; then
RUNTIME_APT_DEPS="apt-transport-https apt-utils ca-certificates \
curl dumb-init freetds-bin gosu krb5-user libgeos-dev \
-ldap-utils libsasl2-2 libsasl2-modules locales ${debian_version_apt_deps} \
+ldap-utils libsasl2-2 libsasl2-modules libxmlsec1 locales
${debian_version_apt_deps} \
lsb-release openssh-client python3-selinux rsync sasl2-bin sqlite3 sudo
unixodbc"
export RUNTIME_APT_DEPS
fi
diff --git a/Dockerfile.ci b/Dockerfile.ci
index 81a9d424a0..c3e5e2e3a0 100644
--- a/Dockerfile.ci
+++ b/Dockerfile.ci
@@ -57,7 +57,7 @@ function get_dev_apt_deps() {
DEV_APT_DEPS="apt-transport-https apt-utils build-essential
ca-certificates dirmngr \
freetds-bin freetds-dev git gosu graphviz graphviz-dev krb5-user ldap-utils
libffi-dev libgeos-dev \
libkrb5-dev libldap2-dev libleveldb1d libleveldb-dev libsasl2-2 libsasl2-dev
libsasl2-modules \
-libssl-dev locales lsb-release openssh-client pkgconf sasl2-bin \
+libssl-dev libxmlsec1 libxmlsec1-dev locales lsb-release openssh-client
pkgconf sasl2-bin \
software-properties-common sqlite3 sudo unixodbc unixodbc-dev"
export DEV_APT_DEPS
fi
@@ -83,7 +83,7 @@ function get_runtime_apt_deps() {
if [[ "${RUNTIME_APT_DEPS=}" == "" ]]; then
RUNTIME_APT_DEPS="apt-transport-https apt-utils ca-certificates \
curl dumb-init freetds-bin gosu krb5-user libgeos-dev \
-ldap-utils libsasl2-2 libsasl2-modules locales ${debian_version_apt_deps} \
+ldap-utils libsasl2-2 libsasl2-modules libxmlsec1 locales
${debian_version_apt_deps} \
lsb-release openssh-client python3-selinux rsync sasl2-bin sqlite3 sudo
unixodbc"
export RUNTIME_APT_DEPS
fi
diff --git a/INSTALL b/INSTALL
index 7742c006bc..0ac0ac33d0 100644
--- a/INSTALL
+++ b/INSTALL
@@ -105,9 +105,9 @@ github_enterprise, google, google_auth, grpc, hashicorp,
hdfs, hive, http, imap,
jenkins, kerberos, kubernetes, ldap, leveldb, microsoft.azure,
microsoft.mssql, microsoft.psrp,
microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openai, openfaas,
openlineage, opensearch,
opsgenie, oracle, otel, pagerduty, pandas, papermill, password, pgvector,
pinecone, pinot, plexus,
-postgres, presto, rabbitmq, redis, s3, s3fs, salesforce, samba, segment,
sendgrid, sentry, sftp,
-singularity, slack, smtp, snowflake, spark, sqlite, ssh, statsd, tableau,
tabular, telegram, trino,
-vertica, virtualenv, weaviate, webhdfs, winrm, yandex, zendesk
+postgres, presto, rabbitmq, redis, s3, s3fs, salesforce, samba, saml, segment,
sendgrid, sentry,
+sftp, singularity, slack, smtp, snowflake, spark, sqlite, ssh, statsd,
tableau, tabular, telegram,
+trino, vertica, virtualenv, weaviate, webhdfs, winrm, yandex, zendesk
# END EXTRAS HERE
# For installing Airflow in development environments - see CONTRIBUTING.rst
diff --git a/airflow/auth/managers/base_auth_manager.py
b/airflow/auth/managers/base_auth_manager.py
index c158ea8481..f50e40082b 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -33,7 +33,6 @@ from airflow.utils.session import NEW_SESSION, provide_session
if TYPE_CHECKING:
from connexion import FlaskApi
- from flask import Flask
from sqlalchemy.orm import Session
from airflow.auth.managers.models.base_user import BaseUser
@@ -64,11 +63,12 @@ class BaseAuthManager(LoggingMixin):
Class to derive in order to implement concrete auth managers.
Auth managers are responsible for any user management related operation
such as login, logout, authz, ...
+
+ :param appbuilder: the flask app builder
"""
- def __init__(self, app: Flask, appbuilder: AirflowAppBuilder) -> None:
+ def __init__(self, appbuilder: AirflowAppBuilder) -> None:
super().__init__()
- self.app = app
self.appbuilder = appbuilder
@staticmethod
@@ -83,21 +83,29 @@ class BaseAuthManager(LoggingMixin):
"""Return API endpoint(s) definition for the auth manager."""
return None
- @abstractmethod
def get_user_name(self) -> str:
"""Return the username associated to the user in session."""
+ user = self.get_user()
+ if not user:
+ self.log.error("Calling 'get_user_name()' but the user is not
signed in.")
+ raise AirflowException("The user must be signed in.")
+ return user.get_name()
- @abstractmethod
def get_user_display_name(self) -> str:
"""Return the user's display name associated to the user in session."""
+ return self.get_user_name()
@abstractmethod
- def get_user(self) -> BaseUser:
+ def get_user(self) -> BaseUser | None:
"""Return the user associated to the user in session."""
- @abstractmethod
def get_user_id(self) -> str:
"""Return the user ID associated to the user in session."""
+ user = self.get_user()
+ if not user:
+ self.log.error("Calling 'get_user_id()' but the user is not signed
in.")
+ raise AirflowException("The user must be signed in.")
+ return str(user.get_id())
def init(self) -> None:
"""
@@ -384,9 +392,13 @@ class BaseAuthManager(LoggingMixin):
def get_url_logout(self) -> str:
"""Return the logout page url."""
- @abstractmethod
def get_url_user_profile(self) -> str | None:
- """Return the url to a page displaying info about the current user."""
+ """
+ Return the url to a page displaying info about the current user.
+
+ By default, return None.
+ """
+ return None
@cached_property
def security_manager(self) -> AirflowSecurityManagerV2:
@@ -394,7 +406,7 @@ class BaseAuthManager(LoggingMixin):
Return the security manager.
By default, Airflow comes with the default security manager
- airflow.www.security_manager.AirflowSecurityManagerV2. The auth
manager might need to extend this
+ ``airflow.www.security_manager.AirflowSecurityManagerV2``. The auth
manager might need to extend this
default security manager for its own purposes.
By default, return the default AirflowSecurityManagerV2.
diff --git a/airflow/auth/managers/fab/fab_auth_manager.py
b/airflow/auth/managers/fab/fab_auth_manager.py
index 6305406d2a..5224509039 100644
--- a/airflow/auth/managers/fab/fab_auth_manager.py
+++ b/airflow/auth/managers/fab/fab_auth_manager.py
@@ -170,26 +170,12 @@ class FabAuthManager(BaseAuthManager):
last_name = user.last_name.strip() if isinstance(user.last_name, str)
else ""
return f"{first_name} {last_name}".strip()
- def get_user_name(self) -> str:
- """
- Return the username associated to the user in session.
-
- For backward compatibility reasons, the username in FAB auth manager
can be any of username,
- email, or the database user ID.
- """
- user = self.get_user()
- return user.username or user.email or self.get_user_id()
-
def get_user(self) -> User:
"""Return the user associated to the user in session."""
from flask_login import current_user
return current_user
- def get_user_id(self) -> str:
- """Return the user ID associated to the user in session."""
- return str(self.get_user().get_id())
-
def init(self) -> None:
"""Run operations when Airflow is initializing."""
self._sync_appbuilder_roles()
@@ -351,7 +337,7 @@ class FabAuthManager(BaseAuthManager):
from airflow.auth.managers.fab.security_manager.override import
FabAirflowSecurityManagerOverride
from airflow.www.security import AirflowSecurityManager
- sm_from_config = self.app.config.get("SECURITY_MANAGER_CLASS")
+ sm_from_config =
self.appbuilder.get_app.config.get("SECURITY_MANAGER_CLASS")
if sm_from_config:
if not issubclass(sm_from_config, AirflowSecurityManager):
raise Exception(
diff --git a/airflow/auth/managers/fab/models/__init__.py
b/airflow/auth/managers/fab/models/__init__.py
index 28de7840ec..1fe967d8c7 100644
--- a/airflow/auth/managers/fab/models/__init__.py
+++ b/airflow/auth/managers/fab/models/__init__.py
@@ -228,6 +228,9 @@ class User(Model, BaseUser):
def get_id(self):
return self.id
+    def get_name(self) -> str:
+        """
+        Return the name of the user.
+
+        For backward compatibility reasons, this is the username if set, otherwise the email,
+        otherwise the database user ID converted to a string.
+        """
+        # ``User`` has no ``user_id`` attribute; the database ID is exposed through ``get_id()``.
+        return self.username or self.email or str(self.get_id())
+
def get_full_name(self):
return f"{self.first_name} {self.last_name}"
diff --git a/airflow/auth/managers/models/base_user.py
b/airflow/auth/managers/models/base_user.py
index 7cb45d8389..93163ee06d 100644
--- a/airflow/auth/managers/models/base_user.py
+++ b/airflow/auth/managers/models/base_user.py
@@ -24,10 +24,13 @@ class BaseUser:
"""User model interface."""
@property
- @abstractmethod
def is_active(self) -> bool:
- ...
+ return True
@abstractmethod
def get_id(self) -> str:
...
+
+ @abstractmethod
+ def get_name(self) -> str:
+ ...
diff --git a/airflow/auth/managers/models/base_user.py
b/airflow/providers/amazon/aws/auth_manager/__init__.py
similarity index 75%
copy from airflow/auth/managers/models/base_user.py
copy to airflow/providers/amazon/aws/auth_manager/__init__.py
index 7cb45d8389..13a83393a9 100644
--- a/airflow/auth/managers/models/base_user.py
+++ b/airflow/providers/amazon/aws/auth_manager/__init__.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,19 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-from abc import abstractmethod
-
-
-class BaseUser:
- """User model interface."""
-
- @property
- @abstractmethod
- def is_active(self) -> bool:
- ...
-
- @abstractmethod
- def get_id(self) -> str:
- ...
diff --git a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
new file mode 100644
index 0000000000..fcf3caacb9
--- /dev/null
+++ b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -0,0 +1,143 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING
+
+from flask import session, url_for
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowOptionalProviderFeatureException
+from airflow.providers.amazon.aws.auth_manager.constants import (
+ CONF_ENABLE_KEY,
+ CONF_SECTION_NAME,
+)
+from
airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override
import (
+ AwsSecurityManagerOverride,
+)
+
+try:
+    from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod
+except ImportError:
+    # The auth manager interface only exists in recent Airflow versions.
+    raise AirflowOptionalProviderFeatureException(
+        "Failed to import BaseAuthManager. This feature is only available in Airflow versions >= 2.8.0"
+    )
+
+if TYPE_CHECKING:
+ from airflow.auth.managers.models.base_user import BaseUser
+ from airflow.auth.managers.models.resource_details import (
+ AccessView,
+ ConfigurationDetails,
+ ConnectionDetails,
+ DagAccessEntity,
+ DagDetails,
+ DatasetDetails,
+ PoolDetails,
+ VariableDetails,
+ )
+ from airflow.providers.amazon.aws.auth_manager.user import
AwsAuthManagerUser
+ from airflow.www.extensions.init_appbuilder import AirflowAppBuilder
+
+
+class AwsAuthManager(BaseAuthManager):
+    """
+    AWS auth manager.
+
+    Leverages AWS services such as AWS IAM Identity Center and Amazon Verified Permissions to perform
+    authentication and authorization in Airflow.
+
+    Note: for now, every ``is_authorized_*`` check only verifies that a user is signed in; the ``method``,
+    ``details``, ``access_entity``, ``access_view`` and ``user`` arguments are not inspected yet.
+
+    :param appbuilder: the flask app builder
+    """
+
+    def __init__(self, appbuilder: AirflowAppBuilder) -> None:
+        super().__init__(appbuilder)
+        # Refuse to start unless ``[aws_auth_manager] enable`` is explicitly turned on, because this
+        # auth manager is not finalized yet.
+        enable = conf.getboolean(CONF_SECTION_NAME, CONF_ENABLE_KEY)
+        if not enable:
+            raise NotImplementedError(
+                "The AWS auth manager is currently being built. It is not finalized. It is not intended to be used yet."
+            )
+
+    def get_user(self) -> AwsAuthManagerUser | None:
+        """Return the user stored in the Flask session under ``aws_user``, or ``None`` if not signed in."""
+        return session.get("aws_user")
+
+    def is_logged_in(self) -> bool:
+        """Return whether a user is stored in the Flask session under ``aws_user``."""
+        return "aws_user" in session
+
+    def is_authorized_configuration(
+        self,
+        *,
+        method: ResourceMethod,
+        details: ConfigurationDetails | None = None,
+        user: BaseUser | None = None,
+    ) -> bool:
+        return self.is_logged_in()
+
+    def is_authorized_cluster_activity(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool:
+        return self.is_logged_in()
+
+    def is_authorized_connection(
+        self,
+        *,
+        method: ResourceMethod,
+        details: ConnectionDetails | None = None,
+        user: BaseUser | None = None,
+    ) -> bool:
+        return self.is_logged_in()
+
+    def is_authorized_dag(
+        self,
+        *,
+        method: ResourceMethod,
+        access_entity: DagAccessEntity | None = None,
+        details: DagDetails | None = None,
+        user: BaseUser | None = None,
+    ) -> bool:
+        return self.is_logged_in()
+
+    def is_authorized_dataset(
+        self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None
+    ) -> bool:
+        return self.is_logged_in()
+
+    def is_authorized_pool(
+        self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None
+    ) -> bool:
+        return self.is_logged_in()
+
+    def is_authorized_variable(
+        self, *, method: ResourceMethod, details: VariableDetails | None = None, user: BaseUser | None = None
+    ) -> bool:
+        return self.is_logged_in()
+
+    def is_authorized_view(
+        self,
+        *,
+        access_view: AccessView,
+        user: BaseUser | None = None,
+    ) -> bool:
+        return self.is_logged_in()
+
+    def get_url_login(self, **kwargs) -> str:
+        """
+        Return the URL of the AWS auth manager login view.
+
+        Extra keyword arguments are accepted but currently ignored.
+        """
+        return url_for("AwsAuthManagerAuthenticationViews.login")
+
+    def get_url_logout(self) -> str:
+        """Return the URL of the AWS auth manager logout view."""
+        return url_for("AwsAuthManagerAuthenticationViews.logout")
+
+    @cached_property
+    def security_manager(self) -> AwsSecurityManagerOverride:
+        """Return the AWS-specific security manager, which registers the AWS authentication views."""
+        return AwsSecurityManagerOverride(self.appbuilder)
diff --git a/airflow/auth/managers/models/base_user.py
b/airflow/providers/amazon/aws/auth_manager/constants.py
similarity index 78%
copy from airflow/auth/managers/models/base_user.py
copy to airflow/providers/amazon/aws/auth_manager/constants.py
index 7cb45d8389..f2f9c1da07 100644
--- a/airflow/auth/managers/models/base_user.py
+++ b/airflow/providers/amazon/aws/auth_manager/constants.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,19 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-from abc import abstractmethod
+# Configuration keys
+from __future__ import annotations
-class BaseUser:
- """User model interface."""
-
- @property
- @abstractmethod
- def is_active(self) -> bool:
- ...
-
- @abstractmethod
- def get_id(self) -> str:
- ...
+# Name of the Airflow configuration section holding the AWS auth manager options
+CONF_SECTION_NAME = "aws_auth_manager"
+# Option holding the URL of the SAML metadata provided by AWS IAM Identity Center
+CONF_SAML_METADATA_URL_KEY = "saml_metadata_url"
+# Boolean option that must be turned on to use the (not yet finalized) AWS auth manager
+CONF_ENABLE_KEY = "enable"
diff --git a/airflow/auth/managers/models/base_user.py
b/airflow/providers/amazon/aws/auth_manager/security_manager/__init__.py
similarity index 75%
copy from airflow/auth/managers/models/base_user.py
copy to airflow/providers/amazon/aws/auth_manager/security_manager/__init__.py
index 7cb45d8389..13a83393a9 100644
--- a/airflow/auth/managers/models/base_user.py
+++ b/airflow/providers/amazon/aws/auth_manager/security_manager/__init__.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,19 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-from abc import abstractmethod
-
-
-class BaseUser:
- """User model interface."""
-
- @property
- @abstractmethod
- def is_active(self) -> bool:
- ...
-
- @abstractmethod
- def get_id(self) -> str:
- ...
diff --git a/airflow/auth/managers/models/base_user.py
b/airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py
similarity index 52%
copy from airflow/auth/managers/models/base_user.py
copy to
airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py
index 7cb45d8389..9d6432f1c6 100644
--- a/airflow/auth/managers/models/base_user.py
+++
b/airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -17,17 +16,21 @@
# under the License.
from __future__ import annotations
-from abc import abstractmethod
+from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+try:
+ from airflow.www.security_manager import AirflowSecurityManagerV2
+except ImportError:
+ raise AirflowOptionalProviderFeatureException(
+ "Failed to import AirflowSecurityManagerV2. This feature is only
available in Airflow versions >= 2.8.0"
+ )
-class BaseUser:
- """User model interface."""
+class AwsSecurityManagerOverride(AirflowSecurityManagerV2):
+ """Security manager override specific to the AWS auth manager; registers its authentication views."""
- @property
- @abstractmethod
- def is_active(self) -> bool:
- ...
+ def register_views(self):
+ """Register views specific to AWS auth manager."""
+ # Imported here rather than at module level, presumably to defer loading the views module
+ # (which requires python3-saml) until views are registered — TODO confirm.
+ from airflow.providers.amazon.aws.auth_manager.views.auth import
AwsAuthManagerAuthenticationViews
- @abstractmethod
- def get_id(self) -> str:
- ...
+ # Register the AWS login/logout/callback views without adding a menu entry.
+ self.appbuilder.add_view_no_menu(AwsAuthManagerAuthenticationViews())
diff --git a/airflow/providers/amazon/aws/auth_manager/user.py
b/airflow/providers/amazon/aws/auth_manager/user.py
new file mode 100644
index 0000000000..68699783d6
--- /dev/null
+++ b/airflow/providers/amazon/aws/auth_manager/user.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+try:
+ from airflow.auth.managers.models.base_user import BaseUser
+except ImportError:
+ raise AirflowOptionalProviderFeatureException(
+ "Failed to import BaseUser. This feature is only available in Airflow
versions >= 2.8.0"
+ )
+
+
+class AwsAuthManagerUser(BaseUser):
+    """
+    A user authenticated through the AWS auth manager.
+
+    :param user_id: identifier of the user.
+    :param groups: names of the groups the user is a member of.
+    :param username: optional username of the user.
+    :param email: optional email address of the user.
+    """
+
+    def __init__(
+        self,
+        *,
+        user_id: str,
+        groups: list[str],
+        username: str | None = None,
+        email: str | None = None,
+    ) -> None:
+        self.user_id, self.groups = user_id, groups
+        self.username, self.email = username, email
+
+    def get_id(self) -> str:
+        """Return the identifier of the user."""
+        return self.user_id
+
+    def get_name(self) -> str:
+        """Return the first non-empty value among username and email, falling back to the user ID."""
+        candidates = (self.username, self.email)
+        return next((value for value in candidates if value), self.user_id)
diff --git a/airflow/auth/managers/models/base_user.py
b/airflow/providers/amazon/aws/auth_manager/views/__init__.py
similarity index 75%
copy from airflow/auth/managers/models/base_user.py
copy to airflow/providers/amazon/aws/auth_manager/views/__init__.py
index 7cb45d8389..13a83393a9 100644
--- a/airflow/auth/managers/models/base_user.py
+++ b/airflow/providers/amazon/aws/auth_manager/views/__init__.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,19 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-from abc import abstractmethod
-
-
-class BaseUser:
- """User model interface."""
-
- @property
- @abstractmethod
- def is_active(self) -> bool:
- ...
-
- @abstractmethod
- def get_id(self) -> str:
- ...
diff --git a/airflow/providers/amazon/aws/auth_manager/views/auth.py
b/airflow/providers/amazon/aws/auth_manager/views/auth.py
new file mode 100644
index 0000000000..eb73a3805b
--- /dev/null
+++ b/airflow/providers/amazon/aws/auth_manager/views/auth.py
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from functools import cached_property
+
+from flask import make_response, redirect, request, session, url_for
+from flask_appbuilder import expose
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.auth_manager.constants import
CONF_SAML_METADATA_URL_KEY, CONF_SECTION_NAME
+from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
+from airflow.www.app import csrf
+from airflow.www.views import AirflowBaseView
+
+try:
+ from onelogin.saml2.auth import OneLogin_Saml2_Auth
+ from onelogin.saml2.idp_metadata_parser import
OneLogin_Saml2_IdPMetadataParser
+except ImportError:
+ raise ImportError(
+ "AWS auth manager requires the python3-saml library but it is not
installed by default. "
+ "Please install the python3-saml library by running: "
+ "pip install apache-airflow-providers-amazon[python3-saml]"
+ )
+
+
+class AwsAuthManagerAuthenticationViews(AirflowBaseView):
+    """
+    Views specific to AWS auth manager authentication mechanism.
+
+    Some code below is inspired from
+    https://github.com/SAML-Toolkits/python3-saml/blob/6988bdab7a203abfe8dc264992f7e350c67aef3d/demo-flask/index.py
+    """
+
+    @cached_property
+    def idp_data(self) -> dict:
+        """
+        Return the identity provider settings parsed from the SAML metadata.
+
+        The metadata is fetched from ``[aws_auth_manager] saml_metadata_url`` on first access and cached
+        on the view instance afterwards.
+        """
+        saml_metadata_url = conf.get_mandatory_value(CONF_SECTION_NAME, CONF_SAML_METADATA_URL_KEY)
+        return OneLogin_Saml2_IdPMetadataParser.parse_remote(saml_metadata_url)
+
+    @expose("/login")
+    def login(self):
+        """Start login process by redirecting the user to the identity provider."""
+        saml_auth = self._init_saml_auth()
+        return redirect(saml_auth.login())
+
+    @expose("/logout")
+    def logout(self):
+        """Start logout process: clear the Airflow session, then redirect to the identity provider."""
+        session.clear()
+        saml_auth = self._init_saml_auth()
+
+        return redirect(saml_auth.logout())
+
+    @csrf.exempt
+    @expose("/login_callback", methods=("GET", "POST"))
+    def login_callback(self):
+        """
+        Callback where the user is redirected to after successful login.
+
+        Validates the SAML response, stores the authenticated user in the session under ``aws_user``
+        and redirects to the Airflow home page.
+
+        CSRF protection needs to be disabled otherwise the callback won't work.
+
+        :raises AirflowException: if the SAML response does not authenticate the user.
+        """
+        saml_auth = self._init_saml_auth()
+        saml_auth.process_response()
+        errors = saml_auth.get_errors()
+        if not saml_auth.is_authenticated():
+            error_reason = saml_auth.get_last_error_reason()
+            # Use a named logger instead of the root logger (``logging.error`` also implicitly
+            # configures the root logger when it has no handlers).
+            log = logging.getLogger(__name__)
+            log.error("Failed to authenticate")
+            log.error("Errors: %s", errors)
+            log.error("Error reason: %s", error_reason)
+            raise AirflowException(f"Failed to authenticate: {error_reason}")
+
+        attributes = saml_auth.get_attributes()
+        user = AwsAuthManagerUser(
+            user_id=attributes["id"][0],
+            groups=attributes["groups"],
+            username=saml_auth.get_nameid(),
+            email=attributes["email"][0],
+        )
+        session["aws_user"] = user
+
+        return redirect(url_for("Airflow.index"))
+
+    @csrf.exempt
+    @expose("/logout_callback", methods=("GET", "POST"))
+    def logout_callback(self):
+        """Single Logout callback; always raises because the identity provider does not support SLO."""
+        raise NotImplementedError("AWS Identity center does not support SLO (Single Logout Service)")
+
+    @expose("/login_metadata")
+    def login_metadata(self):
+        """
+        Return the service provider SAML metadata as XML.
+
+        Responds with HTTP 500 and the comma-separated validation errors if the metadata is invalid.
+        """
+        saml_auth = self._init_saml_auth()
+        settings = saml_auth.get_settings()
+        metadata = settings.get_sp_metadata()
+        errors = settings.validate_metadata(metadata)
+
+        if errors:
+            return make_response(", ".join(errors), 500)
+        resp = make_response(metadata, 200)
+        resp.headers["Content-Type"] = "text/xml"
+        return resp
+
+    @staticmethod
+    def _prepare_flask_request() -> dict:
+        """Translate the current Flask request into the dict format passed to ``OneLogin_Saml2_Auth``."""
+        return {
+            "https": "on" if request.scheme == "https" else "off",
+            "http_host": request.host,
+            "script_name": request.path,
+            "get_data": request.args.copy(),
+            "post_data": request.form.copy(),
+        }
+
+    def _init_saml_auth(self) -> OneLogin_Saml2_Auth:
+        """Build a SAML auth object for the current request, merging SP settings with the IdP metadata."""
+        request_data = self._prepare_flask_request()
+        base_url = conf.get(section="webserver", key="base_url")
+        settings = {
+            # Keep debug mode on: it makes python3-saml report the reason of an error;
+            # when turned off, no reason is provided.
+            "debug": True,
+            "sp": {
+                "entityId": f"{base_url}/login_metadata",
+                "assertionConsumerService": {
+                    "url": f"{base_url}/login_callback",
+                    "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
+                },
+                "singleLogoutService": {
+                    "url": f"{base_url}/logout_callback",
+                    "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
+                },
+            },
+        }
+        merged_settings = OneLogin_Saml2_IdPMetadataParser.merge_settings(settings, self.idp_data)
+        return OneLogin_Saml2_Auth(request_data, merged_settings)
diff --git a/airflow/providers/amazon/provider.yaml
b/airflow/providers/amazon/provider.yaml
index 8c0c13fbc6..6399961543 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -746,6 +746,9 @@ additional-extras:
- name: s3fs
dependencies:
- s3fs>=2023.9.2
+ - name: python3-saml
+ dependencies:
+ - python3-saml>=1.16.0
config:
aws:
@@ -897,6 +900,28 @@ config:
type: boolean
example: "True"
default: "True"
+ aws_auth_manager:
+ description: |
+ This section only applies if you are using the AwsAuthManager. In other
words, if you set
+ ``[core] auth_manager =
airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager`` in
+ Airflow's configuration.
+ options:
+ enable:
+ description: |
+ AWS auth manager is not ready to be used. Turn on this flag to use
it anyway.
+ Do that at your own risk since the AWS auth manager is not in a usable state.
+ version_added: 8.12.0
+ type: boolean
+ example: "True"
+ default: "False"
+ saml_metadata_url:
+ description: |
+ URL of the SAML metadata XML file provided by AWS IAM Identity Center.
+ This URL can be found in the AWS Identity Center console. Required.
+ version_added: 8.12.0
+ type: string
+ example:
"https://portal.sso.<region>.amazonaws.com/saml/metadata/XXXXXXXXXX"
+ default: ~
executors:
- airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor
diff --git a/airflow/www/extensions/init_appbuilder.py
b/airflow/www/extensions/init_appbuilder.py
index ce87fdc76c..ddb44200f9 100644
--- a/airflow/www/extensions/init_appbuilder.py
+++ b/airflow/www/extensions/init_appbuilder.py
@@ -214,7 +214,7 @@ class AirflowAppBuilder:
self._addon_managers = app.config["ADDON_MANAGERS"]
self.session = session
- auth_manager = init_auth_manager(app, self)
+ auth_manager = init_auth_manager(self)
self.sm = auth_manager.security_manager
self.bm = BabelManager(self)
self._add_global_static()
diff --git a/airflow/www/extensions/init_auth_manager.py
b/airflow/www/extensions/init_auth_manager.py
index 63ec043c10..d29fb6944e 100644
--- a/airflow/www/extensions/init_auth_manager.py
+++ b/airflow/www/extensions/init_auth_manager.py
@@ -22,8 +22,6 @@ from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException
if TYPE_CHECKING:
- from flask import Flask
-
from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.www.extensions.init_appbuilder import AirflowAppBuilder
@@ -47,7 +45,7 @@ def get_auth_manager_cls() -> type[BaseAuthManager]:
return auth_manager_cls
-def init_auth_manager(app: Flask, appbuilder: AirflowAppBuilder) ->
BaseAuthManager:
+def init_auth_manager(appbuilder: AirflowAppBuilder) -> BaseAuthManager:
"""
Initialize the auth manager.
@@ -55,7 +53,7 @@ def init_auth_manager(app: Flask, appbuilder:
AirflowAppBuilder) -> BaseAuthMana
"""
global auth_manager
auth_manager_cls = get_auth_manager_cls()
- auth_manager = auth_manager_cls(app, appbuilder)
+ auth_manager = auth_manager_cls(appbuilder)
return auth_manager
diff --git a/airflow/www/extensions/init_security.py
b/airflow/www/extensions/init_security.py
index 21b71e8f84..a7739e3231 100644
--- a/airflow/www/extensions/init_security.py
+++ b/airflow/www/extensions/init_security.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import logging
from importlib import import_module
-from flask import g, redirect, request
+from flask import redirect, request
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException, AirflowException
@@ -72,5 +72,5 @@ def init_check_user_active(app):
url_logout = get_auth_manager().get_url_logout()
if request.path == url_logout:
return
- if get_auth_manager().is_logged_in() and not g.user.is_active:
+ if get_auth_manager().is_logged_in() and not
get_auth_manager().get_user().is_active:
return redirect(url_logout)
diff --git a/airflow/www/security_manager.py b/airflow/www/security_manager.py
index 35e2a63e64..78d8ef00fe 100644
--- a/airflow/www/security_manager.py
+++ b/airflow/www/security_manager.py
@@ -112,6 +112,10 @@ class AirflowSecurityManagerV2(LoggingMixin):
limiter.init_app(self.appbuilder.get_app)
return limiter
+ def register_views(self):
+ """Allow auth managers to register their own views. By default, do
nothing."""
+ pass
+
def has_access(
self, action_name: str, resource_name: str, user=None, resource_pk:
str | None = None
) -> bool:
@@ -168,18 +172,6 @@ class AirflowSecurityManagerV2(LoggingMixin):
cost=limit.cost,
)(baseview.blueprint)
- def add_permissions_view(self, base_action_names, resource_name):
- raise NotImplementedError("Sync FAB permissions is only available with
the FAB auth manager")
-
- def add_permissions_menu(self, resource_name):
- raise NotImplementedError("Sync FAB permissions is only available with
the FAB auth manager")
-
- def get_action(self, name: str) -> Action:
- raise NotImplementedError("Only available when FAB auth manager is
used")
-
- def get_resource(self, name: str) -> Resource:
- raise NotImplementedError("Only available when FAB auth manager is
used")
-
@cached_property
@provide_session
def _auth_manager_is_authorized_map(
@@ -355,3 +347,20 @@ class AirflowSecurityManagerV2(LoggingMixin):
return lambda action, resource_pk, user: any(
self._get_auth_manager_is_authorized_method(fab_resource_name=item) for item in
items
)
+
+ """
+ The following methods are specific to the FAB auth manager. They still need to
be "present" in the main
+ security manager class, but here they either do nothing or raise NotImplementedError.
+ """
+
+ def get_action(self, name: str) -> Action:
+ raise NotImplementedError("Only available if FAB auth manager is used")
+
+ def get_resource(self, name: str) -> Resource:
+ raise NotImplementedError("Only available if FAB auth manager is used")
+
+ def add_permissions_view(self, base_action_names, resource_name):
+ pass
+
+ def add_permissions_menu(self, resource_name):
+ pass
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 5fe8a7ccb7..16d0950989 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -940,10 +940,7 @@ class UIAlert:
if self.roles:
current_user = get_auth_manager().get_user()
if current_user is not None:
- if not hasattr(current_user, "roles"):
- # If the user does not contain "roles" in its model,
return False
- return False
- user_roles = {r.name for r in current_user.roles}
+ user_roles = {r.name for r in getattr(current_user, "roles",
[])}
elif "AUTH_ROLE_PUBLIC" in appbuilder.get_app.config:
# If the current_user is anonymous, assign AUTH_ROLE_PUBLIC
role (if it exists) to them
user_roles = {appbuilder.get_app.config["AUTH_ROLE_PUBLIC"]}
diff --git a/docs/apache-airflow/extra-packages-ref.rst
b/docs/apache-airflow/extra-packages-ref.rst
index 3060d7c2db..32e6d892bc 100644
--- a/docs/apache-airflow/extra-packages-ref.rst
+++ b/docs/apache-airflow/extra-packages-ref.rst
@@ -70,6 +70,8 @@ python dependencies for the provided package.
+---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+
| s3fs | ``pip install 'apache-airflow[s3fs]'`` |
Support for S3 as Airflow FS |
+---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+
+| saml | ``pip install 'apache-airflow[saml]'`` |
Support for SAML authentication in Airflow |
++---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+
| statsd | ``pip install 'apache-airflow[statsd]'`` |
Needed by StatsD metrics |
+---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+
| virtualenv | ``pip install 'apache-airflow[virtualenv]'`` |
Running python tasks in local virtualenv |
diff --git a/docs/docker-stack/changelog.rst b/docs/docker-stack/changelog.rst
index 0f752cb189..ad71d49221 100644
--- a/docs/docker-stack/changelog.rst
+++ b/docs/docker-stack/changelog.rst
@@ -39,6 +39,9 @@ Airflow 2.8
* 2.8.0
+ * Add ``libxmlsec1`` and ``libxmlsec1-dev`` libraries to dev PROD image and
``libxmlsec1`` library to runtime PROD
+ image as they are required by the ``python3-saml`` library.
+
* The image is based on ``Debian Bookworm`` in 2.8.0 rather than ``Debian
Bullseye``. This might cause some
problems when building custom images. You are advised to make sure your
system level dependencies are
working with ``Debian Bookworm``. While all reference images of Airflow
2.8.0 are built on ``Debian Bookworm``,
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 02380d9884..56198798b4 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1356,6 +1356,7 @@ Sagemaker
sagemaker
salesforce
samesite
+saml
sanitization
sas
Sasl
diff --git a/scripts/docker/install_os_dependencies.sh
b/scripts/docker/install_os_dependencies.sh
index cb6cb692fb..318add1435 100644
--- a/scripts/docker/install_os_dependencies.sh
+++ b/scripts/docker/install_os_dependencies.sh
@@ -39,7 +39,7 @@ function get_dev_apt_deps() {
DEV_APT_DEPS="apt-transport-https apt-utils build-essential
ca-certificates dirmngr \
freetds-bin freetds-dev git gosu graphviz graphviz-dev krb5-user ldap-utils
libffi-dev libgeos-dev \
libkrb5-dev libldap2-dev libleveldb1d libleveldb-dev libsasl2-2 libsasl2-dev
libsasl2-modules \
-libssl-dev locales lsb-release openssh-client pkgconf sasl2-bin \
+libssl-dev libxmlsec1 libxmlsec1-dev locales lsb-release openssh-client
pkgconf sasl2-bin \
software-properties-common sqlite3 sudo unixodbc unixodbc-dev"
export DEV_APT_DEPS
fi
@@ -65,7 +65,7 @@ function get_runtime_apt_deps() {
if [[ "${RUNTIME_APT_DEPS=}" == "" ]]; then
RUNTIME_APT_DEPS="apt-transport-https apt-utils ca-certificates \
curl dumb-init freetds-bin gosu krb5-user libgeos-dev \
-ldap-utils libsasl2-2 libsasl2-modules locales ${debian_version_apt_deps} \
+ldap-utils libsasl2-2 libsasl2-modules libxmlsec1 locales
${debian_version_apt_deps} \
lsb-release openssh-client python3-selinux rsync sasl2-bin sqlite3 sudo
unixodbc"
export RUNTIME_APT_DEPS
fi
diff --git a/setup.py b/setup.py
index 46a7e68c5a..bc624e21fa 100644
--- a/setup.py
+++ b/setup.py
@@ -527,6 +527,11 @@ s3fs = [
"s3fs>=2023.9.2",
]
+saml = [
+ # This is required for support of SAML which might be used by some
providers (e.g. Amazon)
+ "python3-saml>=1.16.0",
+]
+
def get_provider_dependencies(provider_name: str) -> list[str]:
if provider_name not in PROVIDER_DEPENDENCIES:
@@ -554,6 +559,7 @@ devel = get_unique_dependency_list(
pandas,
password,
s3fs,
+ saml,
]
)
@@ -600,6 +606,7 @@ CORE_EXTRAS_DEPENDENCIES: dict[str, list[str]] = {
"password": password,
"rabbitmq": rabbitmq,
"s3fs": s3fs,
+ "saml": saml,
"sentry": sentry,
"statsd": statsd,
"virtualenv": virtualenv,
diff --git a/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py
b/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py
index 5560a4d877..4b8294eb6d 100644
--- a/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py
+++ b/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py
@@ -130,7 +130,7 @@ class TestFABforwarding:
class TestFABRoleForwarding(TestFABforwarding):
@mock.patch("airflow.api_connexion.endpoints.forward_to_fab_endpoint.get_auth_manager")
def test_raises_400_if_manager_is_not_fab(self, mock_get_auth_manager):
- mock_get_auth_manager.return_value = BaseAuthManager(self.app,
self.app.appbuilder)
+ mock_get_auth_manager.return_value =
BaseAuthManager(self.app.appbuilder)
response = self.client.get("api/v1/roles",
environ_overrides={"REMOTE_USER": "test"})
assert response.status_code == 400
assert (
diff --git a/tests/auth/managers/fab/test_fab_auth_manager.py
b/tests/auth/managers/fab/test_fab_auth_manager.py
index 934293606a..74d20ab13b 100644
--- a/tests/auth/managers/fab/test_fab_auth_manager.py
+++ b/tests/auth/managers/fab/test_fab_auth_manager.py
@@ -61,43 +61,18 @@ IS_AUTHORIZED_METHODS_SIMPLE = {
@pytest.fixture
def auth_manager():
- return FabAuthManager(None, None)
+ return FabAuthManager(None)
@pytest.fixture
def auth_manager_with_appbuilder():
flask_app = Flask(__name__)
appbuilder = init_appbuilder(flask_app)
- return FabAuthManager(flask_app, appbuilder)
+ return FabAuthManager(appbuilder)
@pytest.mark.db_test
class TestFabAuthManager:
- @pytest.mark.parametrize(
- "id,first_name,last_name,username,email,expected",
- [
- (1, "First", "Last", None, None, "1"),
- (1, None, None, None, None, "1"),
- (1, "First", "Last", "user", None, "user"),
- (1, "First", "Last", "user", "email", "user"),
- (1, None, None, None, "email", "email"),
- (1, "First", "Last", None, "email", "email"),
- ],
- )
- @mock.patch.object(FabAuthManager, "get_user")
- def test_get_user_name(
- self, mock_get_user, id, first_name, last_name, username, email,
expected, auth_manager
- ):
- user = User()
- user.id = id
- user.first_name = first_name
- user.last_name = last_name
- user.username = username
- user.email = email
- mock_get_user.return_value = user
-
- assert auth_manager.get_user_name() == expected
-
@pytest.mark.parametrize(
"id,first_name,last_name,username,email,expected",
[
@@ -130,15 +105,6 @@ class TestFabAuthManager:
assert auth_manager.get_user() == user
- @mock.patch.object(FabAuthManager, "get_user")
- def test_get_user_id(self, mock_get_user, auth_manager):
- user_id = "test"
- user = Mock()
- user.get_id.return_value = user_id
- mock_get_user.return_value = user
-
- assert auth_manager.get_user_id() == user_id
-
@mock.patch.object(FabAuthManager, "get_user")
def test_is_logged_in(self, mock_get_user, auth_manager):
user = Mock()
diff --git a/tests/auth/managers/test_base_auth_manager.py
b/tests/auth/managers/test_base_auth_manager.py
index 9b1b5659f4..832ae50d2a 100644
--- a/tests/auth/managers/test_base_auth_manager.py
+++ b/tests/auth/managers/test_base_auth_manager.py
@@ -45,18 +45,9 @@ if TYPE_CHECKING:
class EmptyAuthManager(BaseAuthManager):
- def get_user_display_name(self) -> str:
- raise NotImplementedError()
-
- def get_user_name(self) -> str:
- raise NotImplementedError()
-
def get_user(self) -> BaseUser:
raise NotImplementedError()
- def get_user_id(self) -> str:
- raise NotImplementedError()
-
def is_authorized_configuration(
self,
*,
@@ -115,20 +106,17 @@ class EmptyAuthManager(BaseAuthManager):
def get_url_logout(self) -> str:
raise NotImplementedError()
- def get_url_user_profile(self) -> str | None:
- raise NotImplementedError()
-
@pytest.fixture
def auth_manager():
- return EmptyAuthManager(None, None)
+ return EmptyAuthManager(None)
@pytest.fixture
def auth_manager_with_appbuilder():
flask_app = Flask(__name__)
appbuilder = init_appbuilder(flask_app)
- return EmptyAuthManager(flask_app, appbuilder)
+ return EmptyAuthManager(appbuilder)
class TestBaseAuthManager:
@@ -138,7 +126,37 @@ class TestBaseAuthManager:
def test_get_api_endpoints_return_none(self, auth_manager):
assert auth_manager.get_api_endpoints() is None
- def test_is_authorized_custom_view_throws_exception(self, auth_manager):
+ def test_get_user_name(self, auth_manager):
+ user = Mock()
+ user.get_name.return_value = "test_username"
+ auth_manager.get_user = MagicMock(return_value=user)
+ result = auth_manager.get_user_name()
+ assert result == "test_username"
+
+ def test_get_user_name_when_not_logged_in(self, auth_manager):
+ auth_manager.get_user = MagicMock(return_value=None)
+ with pytest.raises(AirflowException):
+ auth_manager.get_user_name()
+
+ def test_get_user_display_name_return_user_name(self, auth_manager):
+ auth_manager.get_user_name = MagicMock(return_value="test_user")
+ assert auth_manager.get_user_display_name() == "test_user"
+
+ def test_get_user_id_return_user_id(self, auth_manager):
+ user = Mock()
+ user.get_id = MagicMock(return_value="test_user")
+ auth_manager.get_user = MagicMock(return_value=user)
+ assert auth_manager.get_user_id() == "test_user"
+
+ def test_get_user_id_raise_exception_when_no_user(self, auth_manager):
+ auth_manager.get_user = MagicMock(return_value=None)
+ with pytest.raises(AirflowException, match="The user must be signed
in."):
+ auth_manager.get_user_id()
+
+ def test_get_url_user_profile_return_none(self, auth_manager):
+ assert auth_manager.get_url_user_profile() is None
+
+ def test_is_authorized_custom_view_raise_exception(self, auth_manager):
with pytest.raises(AirflowException, match="The resource `.*` does not
exist in the environment."):
auth_manager.is_authorized_custom_view(
fab_action_name=permissions.ACTION_CAN_READ,
diff --git a/airflow/auth/managers/models/base_user.py
b/tests/providers/amazon/aws/auth_manager/__init__.py
similarity index 75%
copy from airflow/auth/managers/models/base_user.py
copy to tests/providers/amazon/aws/auth_manager/__init__.py
index 7cb45d8389..13a83393a9 100644
--- a/airflow/auth/managers/models/base_user.py
+++ b/tests/providers/amazon/aws/auth_manager/__init__.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,19 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-from abc import abstractmethod
-
-
-class BaseUser:
- """User model interface."""
-
- @property
- @abstractmethod
- def is_active(self) -> bool:
- ...
-
- @abstractmethod
- def get_id(self) -> str:
- ...
diff --git a/airflow/auth/managers/models/base_user.py
b/tests/providers/amazon/aws/auth_manager/security_manager/__init__.py
similarity index 75%
copy from airflow/auth/managers/models/base_user.py
copy to tests/providers/amazon/aws/auth_manager/security_manager/__init__.py
index 7cb45d8389..13a83393a9 100644
--- a/airflow/auth/managers/models/base_user.py
+++ b/tests/providers/amazon/aws/auth_manager/security_manager/__init__.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,19 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-from abc import abstractmethod
-
-
-class BaseUser:
- """User model interface."""
-
- @property
- @abstractmethod
- def is_active(self) -> bool:
- ...
-
- @abstractmethod
- def get_id(self) -> str:
- ...
diff --git
a/tests/providers/amazon/aws/auth_manager/security_manager/test_aws_security_manager_override.py
b/tests/providers/amazon/aws/auth_manager/security_manager/test_aws_security_manager_override.py
new file mode 100644
index 0000000000..cfceb3c96c
--- /dev/null
+++
b/tests/providers/amazon/aws/auth_manager/security_manager/test_aws_security_manager_override.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import Mock, patch
+
+import pytest
+from flask import Flask
+
+from
airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override
import (
+ AwsSecurityManagerOverride,
+)
+from airflow.www.extensions.init_appbuilder import init_appbuilder
+
+
[email protected]
+def appbuilder():
+ flask_app = Flask(__name__)
+ return init_appbuilder(flask_app)
+
+
[email protected]
+def override(appbuilder):
+ return AwsSecurityManagerOverride(appbuilder)
+
+
[email protected]_test
+class TestAwsSecurityManagerOverride:
+ @patch(
+
"airflow.providers.amazon.aws.auth_manager.views.auth.conf.get_mandatory_value",
return_value="test"
+ )
+ def test_register_views(self, mock_get_mandatory_value, override,
appbuilder):
+ pytest.importorskip("onelogin")
+ from airflow.providers.amazon.aws.auth_manager.views.auth import
AwsAuthManagerAuthenticationViews
+
+ with patch.object(AwsAuthManagerAuthenticationViews, "idp_data"):
+ appbuilder.add_view_no_menu = Mock()
+ override.register_views()
+ appbuilder.add_view_no_menu.assert_called_once()
+ assert isinstance(
+ appbuilder.add_view_no_menu.call_args.args[0],
AwsAuthManagerAuthenticationViews
+ )
diff --git a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
new file mode 100644
index 0000000000..56f23eb36a
--- /dev/null
+++ b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import patch
+
+import pytest
+from flask import Flask, session
+
+from airflow.providers.amazon.aws.auth_manager.aws_auth_manager import
AwsAuthManager
+from
airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override
import (
+ AwsSecurityManagerOverride,
+)
+from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
+from airflow.www.extensions.init_appbuilder import init_appbuilder
+from tests.test_utils.config import conf_vars
+
+
[email protected]
+def auth_manager():
+ with conf_vars(
+ {
+ (
+ "core",
+ "auth_manager",
+ ):
"airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager",
+ ("aws_auth_manager", "enable"): "True",
+ }
+ ):
+ return AwsAuthManager(None)
+
+
[email protected]
+def auth_manager_with_appbuilder():
+ flask_app = Flask(__name__)
+ appbuilder = init_appbuilder(flask_app)
+ with conf_vars(
+ {
+ (
+ "core",
+ "auth_manager",
+ ):
"airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager",
+ ("aws_auth_manager", "enable"): "True",
+ }
+ ):
+ return AwsAuthManager(appbuilder)
+
+
[email protected]
+def test_user():
+ return AwsAuthManagerUser(user_id="test_user_id", groups=[],
username="test_username")
+
+
+class TestAwsAuthManager:
+ @pytest.mark.db_test
+ @patch.object(AwsAuthManager, "is_logged_in")
+ def test_get_user(self, mock_is_logged_in, auth_manager, app, test_user):
+ mock_is_logged_in.return_value = True
+
+ with app.test_request_context():
+ session["aws_user"] = test_user
+ result = auth_manager.get_user()
+
+ assert result == test_user
+
+ @patch.object(AwsAuthManager, "is_logged_in")
+ def test_get_user_return_none_when_not_logged_in(self, mock_is_logged_in,
auth_manager):
+ mock_is_logged_in.return_value = False
+ result = auth_manager.get_user()
+
+ assert result is None
+
+ @pytest.mark.db_test
+ def test_is_logged_in(self, auth_manager, app, test_user):
+ with app.test_request_context():
+ session["aws_user"] = test_user
+ result = auth_manager.is_logged_in()
+
+ assert result
+
+ @pytest.mark.db_test
+ def test_is_logged_in_return_false_when_no_user_in_session(self,
auth_manager, app, test_user):
+ with app.test_request_context():
+ result = auth_manager.is_logged_in()
+
+ assert result is False
+
+
@patch("airflow.providers.amazon.aws.auth_manager.aws_auth_manager.url_for")
+ def test_get_url_login(self, mock_url_for, auth_manager):
+ auth_manager.get_url_login()
+
mock_url_for.assert_called_once_with("AwsAuthManagerAuthenticationViews.login")
+
+
@patch("airflow.providers.amazon.aws.auth_manager.aws_auth_manager.url_for")
+ def test_get_url_logout(self, mock_url_for, auth_manager):
+ auth_manager.get_url_logout()
+
mock_url_for.assert_called_once_with("AwsAuthManagerAuthenticationViews.logout")
+
+ @pytest.mark.db_test
+ def test_security_manager_return_default_security_manager(self,
auth_manager_with_appbuilder):
+ assert isinstance(auth_manager_with_appbuilder.security_manager,
AwsSecurityManagerOverride)
diff --git a/airflow/auth/managers/models/base_user.py
b/tests/providers/amazon/aws/auth_manager/test_constants.py
similarity index 64%
copy from airflow/auth/managers/models/base_user.py
copy to tests/providers/amazon/aws/auth_manager/test_constants.py
index 7cb45d8389..3b273d9672 100644
--- a/airflow/auth/managers/models/base_user.py
+++ b/tests/providers/amazon/aws/auth_manager/test_constants.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -17,17 +16,19 @@
# under the License.
from __future__ import annotations
-from abc import abstractmethod
+from airflow.providers.amazon.aws.auth_manager.constants import (
+ CONF_ENABLE_KEY,
+ CONF_SAML_METADATA_URL_KEY,
+ CONF_SECTION_NAME,
+)
-class BaseUser:
- """User model interface."""
+class TestAwsAuthManagerConstants:
+ def test_conf_section_name(self):
+ assert CONF_SECTION_NAME == "aws_auth_manager"
- @property
- @abstractmethod
- def is_active(self) -> bool:
- ...
+ def test_conf_saml_metadata_url_key(self):
+ assert CONF_SAML_METADATA_URL_KEY == "saml_metadata_url"
- @abstractmethod
- def get_id(self) -> str:
- ...
+ def test_conf_enable_key(self):
+ assert CONF_ENABLE_KEY == "enable"
diff --git a/airflow/auth/managers/models/base_user.py
b/tests/providers/amazon/aws/auth_manager/test_user.py
similarity index 55%
copy from airflow/auth/managers/models/base_user.py
copy to tests/providers/amazon/aws/auth_manager/test_user.py
index 7cb45d8389..b518c2a831 100644
--- a/airflow/auth/managers/models/base_user.py
+++ b/tests/providers/amazon/aws/auth_manager/test_user.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -17,17 +16,28 @@
# under the License.
from __future__ import annotations
-from abc import abstractmethod
+import pytest
+
+from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
+
+
[email protected]()
+def user():
+ return AwsAuthManagerUser(user_id="user_id", groups=[])
+
+class TestAwsAuthManagerUser:
+ def test_get_id(self, user):
+ assert user.get_id() == "user_id"
-class BaseUser:
- """User model interface."""
+ def test_get_name_with_username(self, user):
+ user.username = "username"
+ assert user.get_name() == "username"
- @property
- @abstractmethod
- def is_active(self) -> bool:
- ...
+ def test_get_name_with_email(self, user):
+ user.email = "email"
+ assert user.get_name() == "email"
- @abstractmethod
- def get_id(self) -> str:
- ...
+ def test_get_name_with_user_id(self, user):
+ user.user_id = "user_id"
+ assert user.get_name() == "user_id"
diff --git a/airflow/auth/managers/models/base_user.py
b/tests/providers/amazon/aws/auth_manager/views/__init__.py
similarity index 75%
copy from airflow/auth/managers/models/base_user.py
copy to tests/providers/amazon/aws/auth_manager/views/__init__.py
index 7cb45d8389..13a83393a9 100644
--- a/airflow/auth/managers/models/base_user.py
+++ b/tests/providers/amazon/aws/auth_manager/views/__init__.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,19 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-from abc import abstractmethod
-
-
-class BaseUser:
- """User model interface."""
-
- @property
- @abstractmethod
- def is_active(self) -> bool:
- ...
-
- @abstractmethod
- def get_id(self) -> str:
- ...
diff --git a/tests/providers/amazon/aws/auth_manager/views/test_auth.py
b/tests/providers/amazon/aws/auth_manager/views/test_auth.py
new file mode 100644
index 0000000000..84fdbc373c
--- /dev/null
+++ b/tests/providers/amazon/aws/auth_manager/views/test_auth.py
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import Mock, patch
+
+import pytest
+from flask import session, url_for
+
+from airflow.exceptions import AirflowException
+from airflow.www import app as application
+from tests.test_utils.config import conf_vars
+
+pytest.importorskip("onelogin")
+
+SAML_METADATA_URL = "/saml/metadata"
+SAML_METADATA_PARSED = {
+    "idp": {
+        "entityId": "https://portal.sso.us-east-1.amazonaws.com/saml/assertion/<assertion>",
+        "singleSignOnService": {
+            "url": "https://portal.sso.us-east-1.amazonaws.com/saml/assertion/<assertion>",
+            "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
+        },
+        "singleLogoutService": {
+            "url": "https://portal.sso.us-east-1.amazonaws.com/saml/logout/<assertion>",
+            "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
+        },
+        "x509cert": "<cert>",
+    },
+    "security": {"authnRequestsSigned": False},
+    "sp": {"NameIDFormat": "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"},
+}
+
+
+@pytest.fixture()
+def aws_app():
+    def factory():
+        with conf_vars(
+            {
+                (
+                    "core",
+                    "auth_manager",
+                ): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager",
+                ("aws_auth_manager", "enable"): "True",
+                ("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL,
+            }
+        ):
+            with patch(
+                "airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser"
+            ) as mock_parser:
+                mock_parser.parse_remote.return_value = SAML_METADATA_PARSED
+                return application.create_app(testing=True)
+
+    return factory()
+
+
+@pytest.mark.db_test
+class TestAwsAuthManagerAuthenticationViews:
+    def test_login_redirect_to_identity_center(self, aws_app):
+        with aws_app.test_client() as client:
+            response = client.get("/login")
+            assert response.status_code == 302
+            assert response.location.startswith("https://portal.sso.us-east-1.amazonaws.com/saml/assertion/")
+
+    def test_logout_redirect_to_identity_center(self, aws_app):
+        with aws_app.test_client() as client:
+            response = client.get("/logout")
+            assert response.status_code == 302
+            assert response.location.startswith("https://portal.sso.us-east-1.amazonaws.com/saml/logout/")
+
+    def test_login_metadata_return_xml_file(self, aws_app):
+        with aws_app.test_client() as client:
+            response = client.get("/login_metadata")
+            assert response.status_code == 200
+            assert response.headers["Content-Type"] == "text/xml"
+
+    def test_login_callback_set_user_in_session(self):
+        with conf_vars(
+            {
+                (
+                    "core",
+                    "auth_manager",
+                ): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager",
+                ("aws_auth_manager", "enable"): "True",
+                ("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL,
+            }
+        ):
+            with patch(
+                "airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser"
+            ) as mock_parser, patch(
+                "airflow.providers.amazon.aws.auth_manager.views.auth.AwsAuthManagerAuthenticationViews._init_saml_auth"
+            ) as mock_init_saml_auth:
+                mock_parser.parse_remote.return_value = SAML_METADATA_PARSED
+
+                auth = Mock()
+                auth.is_authenticated.return_value = True
+                auth.get_nameid.return_value = "user_id"
+                auth.get_attributes.return_value = {
+                    "id": ["1"],
+                    "groups": ["group_1", "group_2"],
+                    "email": ["email"],
+                }
+                mock_init_saml_auth.return_value = auth
+                app = application.create_app(testing=True)
+                with app.test_client() as client:
+                    response = client.get("/login_callback")
+                    assert response.status_code == 302
+                    assert response.location == url_for("Airflow.index")
+                    assert session["aws_user"] is not None
+                    assert session["aws_user"].get_id() == "1"
+                    assert session["aws_user"].get_name() == "user_id"
+
+    def test_login_callback_raise_exception_if_errors(self):
+        with conf_vars(
+            {
+                (
+                    "core",
+                    "auth_manager",
+                ): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager",
+                ("aws_auth_manager", "enable"): "True",
+                ("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL,
+            }
+        ):
+            with patch(
+                "airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser"
+            ) as mock_parser, patch(
+                "airflow.providers.amazon.aws.auth_manager.views.auth.AwsAuthManagerAuthenticationViews._init_saml_auth"
+            ) as mock_init_saml_auth:
+                mock_parser.parse_remote.return_value = SAML_METADATA_PARSED
+
+                auth = Mock()
+                auth.is_authenticated.return_value = False
+                mock_init_saml_auth.return_value = auth
+                app = application.create_app(testing=True)
+                with app.test_client() as client:
+                    with pytest.raises(AirflowException):
+                        client.get("/login_callback")
+
+    def test_logout_callback_raise_not_implemented_error(self, aws_app):
+        with aws_app.test_client() as client:
+            with pytest.raises(NotImplementedError):
+                client.get("/logout_callback")