This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a6064f3a3c7 Handle the next URL after logging in in the simple auth manager (#44856)
a6064f3a3c7 is described below

commit a6064f3a3c74cd1c1de7a044672d60b84da4252b
Author: Vincent <[email protected]>
AuthorDate: Thu Dec 12 09:59:57 2024 -0500

    Handle the next URL after logging in in the simple auth manager (#44856)
---
 .../auth/managers/simple/simple_auth_manager.py    |  3 +-
 airflow/auth/managers/simple/views/auth.py         | 39 ++++++++++++++++++++--
 airflow/www/templates/airflow/login.html           |  2 +-
 .../managers/simple/test_simple_auth_manager.py    |  2 +-
 tests/auth/managers/simple/views/test_auth.py      | 21 +++++++++---
 5 files changed, 56 insertions(+), 11 deletions(-)

diff --git a/airflow/auth/managers/simple/simple_auth_manager.py b/airflow/auth/managers/simple/simple_auth_manager.py
index 48baa02e7c7..d63aa480c9d 100644
--- a/airflow/auth/managers/simple/simple_auth_manager.py
+++ b/airflow/auth/managers/simple/simple_auth_manager.py
@@ -122,7 +122,8 @@ class SimpleAuthManager(BaseAuthManager[SimpleAuthManagerUser]):
         )
 
     def get_url_login(self, **kwargs) -> str:
-        return url_for("SimpleAuthManagerAuthenticationViews.login")
+        """Return the login page url."""
+        return url_for("SimpleAuthManagerAuthenticationViews.login", next=kwargs.get("next_url"))
 
     def get_url_logout(self) -> str:
         return url_for("SimpleAuthManagerAuthenticationViews.logout")
diff --git a/airflow/auth/managers/simple/views/auth.py b/airflow/auth/managers/simple/views/auth.py
index 6e4cf0c3994..bd06661d833 100644
--- a/airflow/auth/managers/simple/views/auth.py
+++ b/airflow/auth/managers/simple/views/auth.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import logging
+from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
 
 from flask import redirect, request, session, url_for
 from flask_appbuilder import expose
@@ -54,7 +55,9 @@ class SimpleAuthManagerAuthenticationViews(AirflowBaseView):
         return self.render_template(
             "airflow/login.html",
             disable_nav_bar=True,
-            login_submit_url=url_for("SimpleAuthManagerAuthenticationViews.login_submit"),
+            login_submit_url=url_for(
+                "SimpleAuthManagerAuthenticationViews.login_submit", next=request.args.get("next")
+            ),
             auto_refresh_interval=conf.getint("webserver", "auto_refresh_interval"),
             state_color_mapping=state_color_mapping,
             standalone_dag_processor=standalone_dag_processor,
@@ -72,6 +75,7 @@ class SimpleAuthManagerAuthenticationViews(AirflowBaseView):
         """Redirect the user to this callback after login attempt."""
         username = request.form.get("username")
         password = request.form.get("password")
+        next_url = request.args.get("next")
 
         found_users = [
             user
@@ -80,7 +84,7 @@ class SimpleAuthManagerAuthenticationViews(AirflowBaseView):
         ]
 
         if not username or not password or len(found_users) == 0:
-            return redirect(url_for("SimpleAuthManagerAuthenticationViews.login", error=["1"]))
+            return redirect(url_for("SimpleAuthManagerAuthenticationViews.login", error=["1"], next=next_url))
 
         user = SimpleAuthManagerUser(
             username=username,
@@ -96,4 +100,33 @@ class SimpleAuthManagerAuthenticationViews(AirflowBaseView):
         )
         token = signer.generate_signed_token(get_auth_manager().serialize_user(user))
 
-        return redirect(url_for("Airflow.index", token=token))
+        if next_url:
+            return redirect(self._get_redirect_url(next_url, token))
+        else:
+            return redirect(url_for("Airflow.index", token=token))
+
+    def _get_redirect_url(self, next_url: str, token: str) -> str:
+        if self._is_same_domain(next_url, request.url):
+            return self._add_query_params(next_url, {"token": token})
+        else:
+            return url_for("Airflow.index", token=token)
+
+    @staticmethod
+    def _is_same_domain(next_url: str, current_url: str) -> bool:
+        next_url_infos = urlsplit(next_url)
+        current_url_infos = urlsplit(current_url)
+        return (
+            current_url_infos.netloc.startswith("localhost:")
+            or (not next_url_infos.scheme or next_url_infos.scheme == current_url_infos.scheme)
+            and (not next_url_infos.netloc or next_url_infos.netloc == current_url_infos.netloc)
+        )
+
+    @staticmethod
+    def _add_query_params(url: str, params: dict) -> str:
+        url_infos = urlsplit(url)
+        existing_query = dict(parse_qsl(url_infos.query))
+        existing_query.update(params)
+        updated_query = urlencode(existing_query, doseq=True)
+        return urlunsplit(
+            (url_infos.scheme, url_infos.netloc, url_infos.path, updated_query, url_infos.fragment)
+        )
diff --git a/airflow/www/templates/airflow/login.html b/airflow/www/templates/airflow/login.html
index 5a25fb3b5f2..afeac1104a7 100644
--- a/airflow/www/templates/airflow/login.html
+++ b/airflow/www/templates/airflow/login.html
@@ -21,7 +21,7 @@
 
 {% block head_meta %}
   {{ super() }}
-  <meta name="login_submit_url" content="{{ url_for('SimpleAuthManagerAuthenticationViews.login_submit') }}">
+  <meta name="login_submit_url" content="{{ login_submit_url }}">
 {% endblock %}
 
 {% block messages %}
diff --git a/tests/auth/managers/simple/test_simple_auth_manager.py b/tests/auth/managers/simple/test_simple_auth_manager.py
index 07289f6f002..0cc553ad420 100644
--- a/tests/auth/managers/simple/test_simple_auth_manager.py
+++ b/tests/auth/managers/simple/test_simple_auth_manager.py
@@ -95,7 +95,7 @@ class TestSimpleAuthManager:
     @patch("airflow.auth.managers.simple.simple_auth_manager.url_for")
     def test_get_url_login(self, mock_url_for, auth_manager):
         auth_manager.get_url_login()
-        mock_url_for.assert_called_once_with("SimpleAuthManagerAuthenticationViews.login")
+        mock_url_for.assert_called_once_with("SimpleAuthManagerAuthenticationViews.login", next=None)
 
     @patch("airflow.auth.managers.simple.simple_auth_manager.url_for")
     def test_get_url_logout(self, mock_url_for, auth_manager):
diff --git a/tests/auth/managers/simple/views/test_auth.py b/tests/auth/managers/simple/views/test_auth.py
index e3e7b29e2cc..86a8be4f444 100644
--- a/tests/auth/managers/simple/views/test_auth.py
+++ b/tests/auth/managers/simple/views/test_auth.py
@@ -65,18 +65,29 @@ class TestSimpleAuthManagerAuthenticationViews:
             assert session.get("user") is None
 
     @pytest.mark.parametrize(
-        "username, password, is_successful",
-        [("test", "test", True), ("test", "test2", False), ("", "", False)],
+        "username, password, is_successful, query_params, expected_redirect",
+        [
+            ("test", "test", True, {}, None),
+            ("test", "test2", False, {}, None),
+            ("", "", False, {}, None),
+            ("test", "test", True, {"next": "next_url"}, "next_url?token=token"),
+        ],
     )
     @patch("airflow.auth.managers.simple.views.auth.JWTSigner")
-    def test_login_submit(self, mock_jwt_signer, simple_app, username, password, is_successful):
+    def test_login_submit(
+        self, mock_jwt_signer, simple_app, username, password, is_successful, query_params, expected_redirect
+    ):
         signer = Mock()
         signer.generate_signed_token.return_value = "token"
         mock_jwt_signer.return_value = signer
         with simple_app.test_client() as client:
-            response = client.post("/login_submit", data={"username": username, "password": password})
+            response = client.post(
+                "/login_submit", query_string=query_params, data={"username": username, "password": password}
+            )
             assert response.status_code == 302
             if is_successful:
-                assert response.location == url_for("Airflow.index", token="token")
+                if not expected_redirect:
+                    expected_redirect = url_for("Airflow.index", token="token")
+                assert response.location == expected_redirect
             else:
                 assert response.location == url_for("SimpleAuthManagerAuthenticationViews.login", error=["1"])

Reply via email to