kaxil commented on code in PR #55972:
URL: https://github.com/apache/airflow/pull/55972#discussion_r2370826321


##########
task-sdk/tests/task_sdk/execution_time/test_context_cache.py:
##########
@@ -0,0 +1,332 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock, call, patch
+
+import pytest
+
+from airflow.sdk.definitions.connection import Connection
+from airflow.sdk.execution_time.cache import SecretCache
+from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult
+from airflow.sdk.execution_time.context import (
+    _delete_variable,
+    _get_connection,
+    _get_variable,
+    _set_variable,
+)
+
+from tests_common.test_utils.config import conf_vars
+
+
+class TestConnectionCacheIntegration:
+    """Test the integration of SecretCache with connection access."""
+
+    @staticmethod
+    @conf_vars({("secrets", "use_cache"): "true"})
+    def setup_method():
+        SecretCache.reset()
+        SecretCache.init()
+
+    @staticmethod
+    def teardown_method():
+        SecretCache.reset()
+
+    
@patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded")
+    def test_get_connection_uses_cache_when_available(self, 
mock_ensure_backends):
+        """Test that _get_connection uses cache when connection is cached."""
+        conn_id = "test_conn"
+        uri = "postgres://user:pass@host:5432/db"
+
+        SecretCache.save_connection_uri(conn_id, uri)
+
+        result = _get_connection(conn_id)
+        assert result.conn_id == conn_id
+        assert result.conn_type == "postgres"
+        assert result.host == "host"
+        assert result.login == "user"
+        assert result.password == "pass"
+        assert result.port == 5432
+        assert result.schema == "db"
+
+        mock_ensure_backends.assert_not_called()
+
+    
@patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded")
+    def test_get_connection_from_backend_saves_to_cache(self, 
mock_ensure_backends):
+        """Test that connection from secrets backend is retrieved correctly 
and cached."""
+        conn_id = "test_conn"
+        conn = Connection(conn_id=conn_id, conn_type="mysql", host="host", 
port=3306)
+
+        mock_backend = MagicMock(spec=["get_connection"])
+        mock_backend.get_connection.return_value = conn
+        mock_ensure_backends.return_value = [mock_backend]
+
+        result = _get_connection(conn_id)
+        assert result.conn_id == conn_id
+        assert result.conn_type == "mysql"
+        mock_backend.get_connection.assert_called_once_with(conn_id=conn_id)
+
+        cached_uri = SecretCache.get_connection_uri(conn_id)
+        cached_conn = Connection.from_uri(cached_uri, conn_id=conn_id)
+        assert cached_conn.conn_type == "mysql"
+        assert cached_conn.host == "host"
+
+    
@patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded")
+    def test_get_connection_from_api(self, mock_ensure_backends, 
mock_supervisor_comms):
+        """Test that connection from API server works correctly."""
+        conn_id = "test_conn"
+        conn_result = ConnectionResult(
+            conn_id=conn_id,
+            conn_type="mysql",
+            host="host",
+            port=3306,
+            login="user",
+            password="pass",
+        )
+
+        mock_ensure_backends.return_value = []
+
+        mock_supervisor_comms.send.return_value = conn_result
+
+        result = _get_connection(conn_id)
+
+        assert result.conn_id == conn_id
+        assert result.conn_type == "mysql"
+        # Called for GetConnection (and possibly MaskSecret)
+        assert mock_supervisor_comms.send.call_count >= 1
+
+        cached_uri = SecretCache.get_connection_uri(conn_id)
+        cached_conn = Connection.from_uri(cached_uri, conn_id=conn_id)
+        assert cached_conn.conn_type == "mysql"
+        assert cached_conn.host == "host"
+
+    @patch("airflow.sdk.execution_time.context.mask_secret")
+    def test_get_connection_masks_secrets(self, mock_mask_secret):
+        """Test that connection secrets are masked from logs."""
+        conn_id = "test_conn"
+        conn = Connection(
+            conn_id=conn_id, conn_type="mysql", login="user", 
password="password", extra='{"key": "value"}'
+        )
+
+        mock_backend = MagicMock(spec=["get_connection"])
+        mock_backend.get_connection.return_value = conn
+
+        with patch(
+            
"airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded", 
return_value=[mock_backend]
+        ):
+            result = _get_connection(conn_id)
+
+            assert result.conn_id == conn_id
+            # Check that password and extra were masked
+            mock_mask_secret.assert_has_calls(
+                [
+                    call("password"),
+                    call('{"key": "value"}'),
+                ],
+                any_order=True,
+            )
+
+
+class TestVariableCacheIntegration:
+    """Test the integration of SecretCache with variable access."""
+
+    @staticmethod
+    @conf_vars({("secrets", "use_cache"): "true"})
+    def setup_method():
+        SecretCache.reset()
+        SecretCache.init()
+
+    @staticmethod
+    def teardown_method():
+        SecretCache.reset()
+
+    
@patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded")
+    def test_get_variable_uses_cache_when_available(self, 
mock_ensure_backends):
+        """Test that _get_variable uses cache when variable is cached."""
+        key = "test_key"
+        value = "test_value"
+        SecretCache.save_variable(key, value)
+
+        result = _get_variable(key, deserialize_json=False)
+        assert result == value
+        mock_ensure_backends.assert_not_called()
+
+    
@patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded")
+    def test_get_variable_from_backend_saves_to_cache(self, 
mock_ensure_backends):
+        """Test that variable from secrets backend is saved to cache."""
+        key = "test_key"
+        value = "test_value"
+
+        mock_backend = MagicMock(spec=["get_variable"])
+        mock_backend.get_variable.return_value = value
+        mock_ensure_backends.return_value = [mock_backend]
+
+        result = _get_variable(key, deserialize_json=False)
+        assert result == value
+        mock_backend.get_variable.assert_called_once_with(key=key)
+        cached_value = SecretCache.get_variable(key)
+        assert cached_value == value
+
+    
@patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded")
+    def test_get_variable_from_api_saves_to_cache(self, mock_ensure_backends, 
mock_supervisor_comms):
+        """Test that variable from API server is saved to cache."""
+        key = "test_key"
+        value = "test_value"
+        var_result = VariableResult(key=key, value=value)
+
+        mock_ensure_backends.return_value = []
+        mock_supervisor_comms.send.return_value = var_result
+
+        result = _get_variable(key, deserialize_json=False)
+
+        assert result == value
+        cached_value = SecretCache.get_variable(key)
+        assert cached_value == value
+
+    
@patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded")
+    def test_get_variable_with_json_deserialization(self, 
mock_ensure_backends):
+        """Test that _get_variable handles JSON deserialization correctly with 
cache."""
+        key = "test_key"
+        json_value = '{"key": "value", "number": 42}'
+        SecretCache.save_variable(key, json_value)
+
+        result = _get_variable(key, deserialize_json=True)
+        assert result == {"key": "value", "number": 42}
+        cached_value = SecretCache.get_variable(key)
+        assert cached_value == json_value
+
+    def test_set_variable_invalidates_cache(self, mock_supervisor_comms):
+        """Test that _set_variable invalidates the cache."""
+        key = "test_key"
+        old_value = "old_value"
+        new_value = "new_value"
+        SecretCache.save_variable(key, old_value)
+
+        _set_variable(key, new_value)
+        mock_supervisor_comms.send.assert_called_once()
+        with pytest.raises(SecretCache.NotPresentException):
+            SecretCache.get_variable(key)
+
+    def test_delete_variable_invalidates_cache(self, mock_supervisor_comms):
+        """Test that _delete_variable invalidates the cache."""
+        key = "test_key"
+        value = "test_value"
+        SecretCache.save_variable(key, value)
+
+        from airflow.sdk.execution_time.comms import OKResponse
+
+        mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+        _delete_variable(key)
+        mock_supervisor_comms.send.assert_called_once()
+        with pytest.raises(SecretCache.NotPresentException):
+            SecretCache.get_variable(key)
+
+
+class TestAsyncConnectionCache:
+    """Test the integration of SecretCache with async connection access."""
+
+    @staticmethod
+    @conf_vars({("secrets", "use_cache"): "true"})
+    def setup_method():
+        SecretCache.reset()
+        SecretCache.init()
+
+    @staticmethod
+    def teardown_method():
+        SecretCache.reset()
+
+    @pytest.mark.asyncio
+    async def test_async_get_connection_uses_cache(self):
+        """Test that _async_get_connection uses cache when connection is 
cached."""
+        from airflow.sdk.execution_time.context import _async_get_connection
+
+        conn_id = "test_conn"
+        uri = "postgres://user:pass@host:5432/db"
+
+        SecretCache.save_connection_uri(conn_id, uri)
+
+        result = await _async_get_connection(conn_id)
+        assert result.conn_id == conn_id
+        assert result.conn_type == "postgres"
+        assert result.host == "host"
+        assert result.login == "user"
+        assert result.password == "pass"
+        assert result.port == 5432
+        assert result.schema == "db"
+
+    @pytest.mark.asyncio
+    async def test_async_get_connection_from_api(self, mock_supervisor_comms):
+        """Test that async connection from API server works correctly."""
+        from airflow.sdk.execution_time.context import _async_get_connection
+
+        conn_id = "test_conn"
+        conn_result = ConnectionResult(
+            conn_id=conn_id,
+            conn_type="mysql",
+            host="host",
+            port=3306,
+        )
+
+        # Configure asend to return the conn_result when awaited
+        mock_supervisor_comms.asend = AsyncMock(return_value=conn_result)
+
+        result = await _async_get_connection(conn_id)
+
+        assert result.conn_id == conn_id
+        assert result.conn_type == "mysql"
+        mock_supervisor_comms.asend.assert_called_once()
+
+        cached_uri = SecretCache.get_connection_uri(conn_id)
+        cached_conn = Connection.from_uri(cached_uri, conn_id=conn_id)
+        assert cached_conn.conn_type == "mysql"
+        assert cached_conn.host == "host"
+
+
+class TestCacheDisabled:
+    """Test behavior when cache is disabled."""
+
+    @staticmethod
+    @conf_vars({("secrets", "use_cache"): "false"})
+    def setup_method():
+        SecretCache.reset()
+        SecretCache.init()
+
+    @staticmethod

Review Comment:
   Not needed for teardown since we just clear the cache dict.



##########
task-sdk/tests/task_sdk/execution_time/test_context_cache.py:
##########
@@ -0,0 +1,332 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock, call, patch
+
+import pytest
+
+from airflow.sdk.definitions.connection import Connection
+from airflow.sdk.execution_time.cache import SecretCache
+from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult
+from airflow.sdk.execution_time.context import (
+    _delete_variable,
+    _get_connection,
+    _get_variable,
+    _set_variable,
+)
+
+from tests_common.test_utils.config import conf_vars
+
+
+class TestConnectionCacheIntegration:
+    """Test the integration of SecretCache with connection access."""
+
+    @staticmethod
+    @conf_vars({("secrets", "use_cache"): "true"})
+    def setup_method():
+        SecretCache.reset()
+        SecretCache.init()
+
+    @staticmethod

Review Comment:
   Not needed for teardown since we just clear the cache dict.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to