This is an automated email from the ASF dual-hosted git repository. eladkal pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new f6e357a5fc Fallback to default value if `[aws] cloudwatch_task_handler_json_serializer` not set (#36851) f6e357a5fc is described below commit f6e357a5fcff8c791f9b2e03be968bf63b17e7c5 Author: Andrey Anshin <andrey.ans...@taragol.is> AuthorDate: Thu Jan 18 12:45:45 2024 +0400 Fallback to default value if `[aws] cloudwatch_task_handler_json_serializer` not set (#36851) --- .../amazon/aws/log/cloudwatch_task_handler.py | 4 +-- .../amazon/aws/log/test_cloudwatch_task_handler.py | 33 ++++++++++++---------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py index b210715cd6..3a801093ca 100644 --- a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +++ b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py @@ -98,13 +98,13 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin): def set_context(self, ti: TaskInstance, *, identifier: str | None = None): super().set_context(ti) - _json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer") + _json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer", fallback=None) self.handler = watchtower.CloudWatchLogHandler( log_group_name=self.log_group, log_stream_name=self._render_filename(ti, ti.try_number), use_queues=not getattr(ti, "is_trigger_log_context", False), boto3_client=self.hook.get_conn(), - json_serialize_default=_json_serialize, + json_serialize_default=_json_serialize or json_serialize_legacy, ) def close(self): diff --git a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py index 4a43aec429..d9cea48579 100644 --- a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py +++ b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import contextlib import logging import time from datetime import datetime as dt, timedelta @@ -177,10 +176,18 @@ class TestCloudwatchTaskHandler: @pytest.mark.parametrize( "conf_json_serialize, expected_serialized_output", [ - (None, '{"datetime": "2023-01-01T00:00:00+00:00", "customObject": null}'), - ( + pytest.param( + "airflow.providers.amazon.aws.log.cloudwatch_task_handler.json_serialize_legacy", + '{"datetime": "2023-01-01T00:00:00+00:00", "customObject": null}', + id="json-serialize-legacy", + ), + pytest.param( "airflow.providers.amazon.aws.log.cloudwatch_task_handler.json_serialize", '{"datetime": "2023-01-01T00:00:00+00:00", "customObject": "SomeCustomSerialization(...)"}', + id="json-serialize", + ), + pytest.param( + None, '{"datetime": "2023-01-01T00:00:00+00:00", "customObject": null}', id="not-set" ), ], ) @@ -193,12 +200,7 @@ class TestCloudwatchTaskHandler: def __repr__(self): return "SomeCustomSerialization(...)" - with contextlib.ExitStack() as stack: - if conf_json_serialize: - stack.enter_context( - conf_vars({("aws", "cloudwatch_task_handler_json_serializer"): conf_json_serialize}) - ) - + with conf_vars({("aws", "cloudwatch_task_handler_json_serializer"): conf_json_serialize}): handler = self.cloudwatch_task_handler handler.set_context(self.ti) message = logging.LogRecord( @@ -213,12 +215,13 @@ class TestCloudwatchTaskHandler: "customObject": ToSerialize(), }, ) - stack.enter_context(mock.patch("watchtower.threading.Thread")) - mock_queue = Mock() - stack.enter_context(mock.patch("watchtower.queue.Queue", return_value=mock_queue)) - handler.handle(message) - - mock_queue.put.assert_called_once_with({"message": expected_serialized_output, "timestamp": ANY}) + with mock.patch("watchtower.threading.Thread"), mock.patch("watchtower.queue.Queue") as mq: + mock_queue = Mock() + mq.return_value = mock_queue + handler.handle(message) + mock_queue.put.assert_called_once_with( + {"message": expected_serialized_output, "timestamp": ANY} + ) def test_close_prevents_duplicate_calls(self): with mock.patch("watchtower.CloudWatchLogHandler.close") as mock_log_handler_close: