This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f6e357a5fc Fallback to default value if `[aws] cloudwatch_task_handler_json_serializer` not set (#36851)
f6e357a5fc is described below

commit f6e357a5fcff8c791f9b2e03be968bf63b17e7c5
Author: Andrey Anshin <andrey.ans...@taragol.is>
AuthorDate: Thu Jan 18 12:45:45 2024 +0400

    Fallback to default value if `[aws] cloudwatch_task_handler_json_serializer` not set (#36851)
---
 .../amazon/aws/log/cloudwatch_task_handler.py      |  4 +--
 .../amazon/aws/log/test_cloudwatch_task_handler.py | 33 ++++++++++++----------
 2 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
index b210715cd6..3a801093ca 100644
--- a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
+++ b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -98,13 +98,13 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
 
     def set_context(self, ti: TaskInstance, *, identifier: str | None = None):
         super().set_context(ti)
-        _json_serialize = conf.getimport("aws", 
"cloudwatch_task_handler_json_serializer")
+        _json_serialize = conf.getimport("aws", 
"cloudwatch_task_handler_json_serializer", fallback=None)
         self.handler = watchtower.CloudWatchLogHandler(
             log_group_name=self.log_group,
             log_stream_name=self._render_filename(ti, ti.try_number),
             use_queues=not getattr(ti, "is_trigger_log_context", False),
             boto3_client=self.hook.get_conn(),
-            json_serialize_default=_json_serialize,
+            json_serialize_default=_json_serialize or json_serialize_legacy,
         )
 
     def close(self):
diff --git a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
index 4a43aec429..d9cea48579 100644
--- a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
+++ b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import contextlib
 import logging
 import time
 from datetime import datetime as dt, timedelta
@@ -177,10 +176,18 @@ class TestCloudwatchTaskHandler:
     @pytest.mark.parametrize(
         "conf_json_serialize, expected_serialized_output",
         [
-            (None, '{"datetime": "2023-01-01T00:00:00+00:00", "customObject": 
null}'),
-            (
+            pytest.param(
+                
"airflow.providers.amazon.aws.log.cloudwatch_task_handler.json_serialize_legacy",
+                '{"datetime": "2023-01-01T00:00:00+00:00", "customObject": 
null}',
+                id="json-serialize-legacy",
+            ),
+            pytest.param(
                 
"airflow.providers.amazon.aws.log.cloudwatch_task_handler.json_serialize",
                 '{"datetime": "2023-01-01T00:00:00+00:00", "customObject": 
"SomeCustomSerialization(...)"}',
+                id="json-serialize",
+            ),
+            pytest.param(
+                None, '{"datetime": "2023-01-01T00:00:00+00:00", 
"customObject": null}', id="not-set"
             ),
         ],
     )
@@ -193,12 +200,7 @@ class TestCloudwatchTaskHandler:
             def __repr__(self):
                 return "SomeCustomSerialization(...)"
 
-        with contextlib.ExitStack() as stack:
-            if conf_json_serialize:
-                stack.enter_context(
-                    conf_vars({("aws", 
"cloudwatch_task_handler_json_serializer"): conf_json_serialize})
-                )
-
+        with conf_vars({("aws", "cloudwatch_task_handler_json_serializer"): 
conf_json_serialize}):
             handler = self.cloudwatch_task_handler
             handler.set_context(self.ti)
             message = logging.LogRecord(
@@ -213,12 +215,13 @@ class TestCloudwatchTaskHandler:
                     "customObject": ToSerialize(),
                 },
             )
-            stack.enter_context(mock.patch("watchtower.threading.Thread"))
-            mock_queue = Mock()
-            stack.enter_context(mock.patch("watchtower.queue.Queue", 
return_value=mock_queue))
-            handler.handle(message)
-
-            mock_queue.put.assert_called_once_with({"message": 
expected_serialized_output, "timestamp": ANY})
+            with mock.patch("watchtower.threading.Thread"), 
mock.patch("watchtower.queue.Queue") as mq:
+                mock_queue = Mock()
+                mq.return_value = mock_queue
+                handler.handle(message)
+                mock_queue.put.assert_called_once_with(
+                    {"message": expected_serialized_output, "timestamp": ANY}
+                )
 
     def test_close_prevents_duplicate_calls(self):
         with mock.patch("watchtower.CloudWatchLogHandler.close") as 
mock_log_handler_close:

Reply via email to