HyunWooZZ commented on code in PR #60689:
URL: https://github.com/apache/airflow/pull/60689#discussion_r2700980684


##########
providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py:
##########
@@ -1471,3 +1472,64 @@ def test_cancel_queries(self, mock_cancel_execution):
 
         assert mock_cancel_execution.call_count == 3
         mock_cancel_execution.assert_has_calls([call("query-1"), 
call("query-2"), call("query-3")])
+
+    def test_make_api_call_passes_timeout_to_requests(self, mock_requests):
+        hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn", 
http_timeout_seconds=12)
+
+        resp = mock.MagicMock()
+        resp.status_code = 200
+        resp.raise_for_status.return_value = None
+        resp.json.return_value = {"ok": True}
+        mock_requests.request.return_value = resp
+
+        hook._make_api_call_with_retries("GET", API_URL, HEADERS)
+
+        mock_requests.request.assert_called_once_with(
+            method="get",
+            url=API_URL,
+            headers=HEADERS,
+            params=None,
+            json=None,
+            timeout=12.0,
+        )
+
+
+    @pytest.mark.asyncio
+    async def 
test_make_api_call_with_retries_async_passes_timeout_to_clientsession(self):
+        hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn", 
http_timeout_seconds=7)
+
+        with mock.patch(f"{MODULE_PATH}.aiohttp.ClientSession") as 
client_session_cls:
+            session_cm = mock.MagicMock()
+            client_session_cls.return_value.__aenter__ = 
AsyncMock(return_value=session_cm)
+
+            req_cm = mock.MagicMock()
+            session_cm.request.return_value = req_cm
+
+            resp = mock.MagicMock()
+            resp.status = 200
+            resp.raise_for_status.return_value = None
+            resp.json = AsyncMock(return_value=GET_RESPONSE)
+            req_cm.__aenter__ = AsyncMock(return_value=resp)
+
+            await hook._make_api_call_with_retries_async("GET", API_URL, 
HEADERS)
+
+            _, kwargs = client_session_cls.call_args
+            timeout_obj = kwargs["timeout"]
+            assert isinstance(timeout_obj, aiohttp.ClientTimeout)
+            assert timeout_obj.total == 7.0
+
+
+    @pytest.mark.asyncio
+    async def 
test_make_api_call_with_retries_async_retries_on_timeout_error(self, 
mock_async_request):
+        hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn", 
http_timeout_seconds=7)
+
+        mock_async_request.__aenter__.side_effect = [
+            asyncio.TimeoutError(),
+            create_async_request_client_response_success(json=GET_RESPONSE, 
status_code=200),
+        ]
+
+        status, data = await hook._make_api_call_with_retries_async("GET", 
API_URL, HEADERS)
+
+        assert status == 200
+        assert data == GET_RESPONSE
+        assert mock_async_request.__aenter__.call_count == 2

Review Comment:
   I'll add the missing trailing newline at the end of the file.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to