This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 565be341afb Cache user object fetched per request in FAB auth manager for improved performance. (#60274)
565be341afb is described below

commit 565be341afb01585a3f27069727fd0cf703c956e
Author: Karthikeyan Singaravelan <[email protected]>
AuthorDate: Thu Jan 15 21:06:40 2026 +0530

    Cache user object fetched per request in FAB auth manager for improved performance. (#60274)
---
 .../api_fastapi/auth/middlewares/refresh_token.py     | 13 +++++++------
 devel-common/pyproject.toml                           |  1 +
 providers/fab/docs/index.rst                          |  1 +
 providers/fab/provider.yaml                           |  8 ++++++++
 providers/fab/pyproject.toml                          |  1 +
 .../providers/fab/auth_manager/fab_auth_manager.py    | 13 ++++++++++---
 .../src/airflow/providers/fab/get_provider_info.py    |  7 +++++++
 .../unit/fab/auth_manager/test_fab_auth_manager.py    | 19 ++++++++++++++++++-
 8 files changed, 53 insertions(+), 10 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/auth/middlewares/refresh_token.py 
b/airflow-core/src/airflow/api_fastapi/auth/middlewares/refresh_token.py
index 5705d14ba99..a8386f40138 100644
--- a/airflow-core/src/airflow/api_fastapi/auth/middlewares/refresh_token.py
+++ b/airflow-core/src/airflow/api_fastapi/auth/middlewares/refresh_token.py
@@ -44,9 +44,9 @@ class JWTRefreshMiddleware(BaseHTTPMiddleware):
         current_token = request.cookies.get(COOKIE_NAME_JWT_TOKEN)
         try:
             if current_token:
-                new_user = await self._refresh_user(current_token)
-                if new_user:
-                    request.state.user = new_user
+                new_user, current_user = await 
self._refresh_user(current_token)
+                if user := (new_user or current_user):
+                    request.state.user = user
 
             response = await call_next(request)
 
@@ -67,9 +67,10 @@ class JWTRefreshMiddleware(BaseHTTPMiddleware):
         return response
 
     @staticmethod
-    async def _refresh_user(current_token: str) -> BaseUser | None:
+    async def _refresh_user(current_token: str) -> tuple[BaseUser | None, 
BaseUser | None]:
         try:
             user = await resolve_user_from_token(current_token)
         except HTTPException:
-            return None
-        return get_auth_manager().refresh_user(user=user)
+            return None, None
+
+        return get_auth_manager().refresh_user(user=user), user
diff --git a/devel-common/pyproject.toml b/devel-common/pyproject.toml
index 4c390a557b1..2a5595d31d0 100644
--- a/devel-common/pyproject.toml
+++ b/devel-common/pyproject.toml
@@ -127,6 +127,7 @@ dependencies = [
     "types-setuptools>=80.0.0.20250429",
     "types-tabulate>=0.9.0.20240106",
     "types-toml>=0.10.8.20240310",
+    "types-cachetools>=6.2.0.20251022",
 ]
 "pytest" = [
     # General pytest devel tools
diff --git a/providers/fab/docs/index.rst b/providers/fab/docs/index.rst
index f394c174ec2..db53b744df2 100644
--- a/providers/fab/docs/index.rst
+++ b/providers/fab/docs/index.rst
@@ -121,6 +121,7 @@ PIP package                                 Version required
 ``jmespath``                                ``>=0.7.0; python_version < 
"3.13"``
 ``werkzeug``                                ``>=2.2,<4; python_version < 
"3.13"``
 ``wtforms``                                 ``>=3.0,<4; python_version < 
"3.13"``
+``cachetools``                              ``>=6.0; python_version < "3.13"``
 ``flask_limiter``                           ``>3,!=3.13,<4``
 ==========================================  
==========================================
 
diff --git a/providers/fab/provider.yaml b/providers/fab/provider.yaml
index ef4067a7a0c..1ed53d1ec8e 100644
--- a/providers/fab/provider.yaml
+++ b/providers/fab/provider.yaml
@@ -250,6 +250,14 @@ config:
         type: integer
         example: ~
         default: "1"
+      cache_ttl:
+        description: |
+          Number of seconds after which the user cache will expire to refetch 
updated user and
+          permissions.
+        version_added: 3.2.0
+        type: integer
+        example: ~
+        default: "30"
 
 auth-managers:
   - airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager
diff --git a/providers/fab/pyproject.toml b/providers/fab/pyproject.toml
index be1fc663189..2f322c76dab 100644
--- a/providers/fab/pyproject.toml
+++ b/providers/fab/pyproject.toml
@@ -81,6 +81,7 @@ dependencies = [
     "jmespath>=0.7.0; python_version < '3.13'",
     "werkzeug>=2.2,<4; python_version < '3.13'",
     "wtforms>=3.0,<4; python_version < '3.13'",
+    "cachetools>=6.0; python_version < '3.13'",
 
     # 
https://github.com/dpgaspar/Flask-AppBuilder/blob/release/4.6.3/setup.py#L54C8-L54C26
     # with an exclusion to account for 
https://github.com/alisaifee/flask-limiter/issues/479
diff --git 
a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py 
b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
index 491ca75bbbf..32a1b88f38b 100644
--- a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
+++ b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
@@ -22,6 +22,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any
 from urllib.parse import urljoin
 
+from cachetools import TTLCache, cachedmethod
 from connexion import FlaskApi
 from fastapi import FastAPI
 from fastapi.middleware.wsgi import WSGIMiddleware
@@ -94,7 +95,7 @@ from airflow.providers.fab.www.utils import (
     get_fab_action_from_method_map,
     get_method_from_fab_action_map,
 )
-from airflow.utils.session import NEW_SESSION, create_session, provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.yaml import safe_load
 
 if TYPE_CHECKING:
@@ -161,6 +162,7 @@ _MAP_MENU_ITEM_TO_FAB_RESOURCE_TYPE = {
     MenuItem.XCOMS: RESOURCE_XCOM,
 }
 
+CACHE_TTL = conf.getint("fab", "cache_ttl", fallback=30)
 
 if AIRFLOW_V_3_1_PLUS:
     from airflow.providers.fab.www.security.permissions import 
RESOURCE_HITL_DETAIL
@@ -176,6 +178,7 @@ class FabAuthManager(BaseAuthManager[User]):
     This auth manager is responsible for providing a backward compatible user 
management experience to users.
     """
 
+    cache: TTLCache = TTLCache(maxsize=1024, ttl=CACHE_TTL)
     appbuilder: AirflowAppBuilder | None = None
 
     def init_flask_resources(self) -> None:
@@ -255,9 +258,13 @@ class FabAuthManager(BaseAuthManager[User]):
 
         return current_user
 
+    @property
+    def session(self):
+        return self.appbuilder.session
+
+    @cachedmethod(lambda self: self.cache, key=lambda _, token: 
int(token["sub"]))
     def deserialize_user(self, token: dict[str, Any]) -> User:
-        with create_session() as session:
-            return session.scalars(select(User).where(User.id == 
int(token["sub"]))).one()
+        return self.session.scalars(select(User).where(User.id == 
int(token["sub"]))).one()
 
     def serialize_user(self, user: User) -> dict[str, Any]:
         return {"sub": str(user.id)}
diff --git a/providers/fab/src/airflow/providers/fab/get_provider_info.py 
b/providers/fab/src/airflow/providers/fab/get_provider_info.py
index 3ca0fa9a662..a7ab9325274 100644
--- a/providers/fab/src/airflow/providers/fab/get_provider_info.py
+++ b/providers/fab/src/airflow/providers/fab/get_provider_info.py
@@ -177,6 +177,13 @@ def get_provider_info():
                         "example": None,
                         "default": "1",
                     },
+                    "cache_ttl": {
+                        "description": "Number of seconds after which the user 
cache will expire to refetch updated user and\npermissions.\n",
+                        "version_added": "3.2.0",
+                        "type": "integer",
+                        "example": None,
+                        "default": "30",
+                    },
                 },
             }
         },
diff --git a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py 
b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py
index ee1c3a2291d..0170d14923f 100644
--- a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py
+++ b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import time
 from contextlib import contextmanager, suppress
 from itertools import chain
 from typing import TYPE_CHECKING
@@ -34,6 +35,7 @@ from airflow.providers.fab.www.utils import 
get_fab_auth_manager
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.utils.db import resetdb
 
+from tests_common.test_utils.asserts import assert_queries_count
 from tests_common.test_utils.config import conf_vars
 from unit.fab.auth_manager.api_endpoints.api_connexion_utils import 
create_user, delete_user
 
@@ -197,9 +199,24 @@ class TestFabAuthManager:
             with user_set(minimal_app_for_auth_api, flask_g_user):
                 assert auth_manager.get_user() == flask_g_user
 
+    @conf_vars({("fab", "cache_ttl"): "1"})
     def test_deserialize_user(self, flask_app, auth_manager_with_appbuilder):
+        """Test user objects are cached and that the cache expires after 
configured TTL."""
         user = create_user(flask_app, "test")
-        result = auth_manager_with_appbuilder.deserialize_user({"sub": 
str(user.id)})
+        with assert_queries_count(2):
+            result = auth_manager_with_appbuilder.deserialize_user({"sub": 
str(user.id)})
+
+        assert user.get_id() == result.get_id()
+
+        with assert_queries_count(0):
+            result = auth_manager_with_appbuilder.deserialize_user({"sub": 
str(user.id)})
+
+        assert user.get_id() == result.get_id()
+
+        time.sleep(1)
+        with assert_queries_count(2):
+            result = auth_manager_with_appbuilder.deserialize_user({"sub": 
str(user.id)})
+
         assert user.get_id() == result.get_id()
 
     def test_serialize_user(self, flask_app, auth_manager_with_appbuilder):

Reply via email to