This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v3-0-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 2b993bea5bea5aec0a0547b258ce2d5bf724282b Author: Amogh Desai <[email protected]> AuthorDate: Tue Jul 29 15:04:37 2025 +0530 Fix custom xcom backend serialize when BaseXCom.get_all is used (#53814) (cherry picked from commit a8c4ba35351afeef8f15c9627e04241a29054fe4) --- task-sdk/src/airflow/sdk/bases/xcom.py | 11 ++-- .../task_sdk/execution_time/test_task_runner.py | 75 ++++++++++++++++++++++ 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index 82df8d151ab..423fde12012 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -17,6 +17,7 @@ from __future__ import annotations +import collections from typing import Any, Protocol import structlog @@ -30,6 +31,9 @@ from airflow.sdk.execution_time.comms import ( XComSequenceSliceResult, ) +# Lightweight wrapper for XCom values +_XComValueWrapper = collections.namedtuple("_XComValueWrapper", "value") + log = structlog.get_logger(logger_name="task") @@ -290,7 +294,6 @@ class BaseXCom: :return: List of all XCom values if found. """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - from airflow.serialization.serde import deserialize msg = SUPERVISOR_COMMS.send( msg=GetXComSequenceSlice( @@ -307,10 +310,10 @@ class BaseXCom: if not isinstance(msg, XComSequenceSliceResult): raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}") - result = deserialize(msg.root) - if not result: + if not msg.root: return None - return result + + return [cls.deserialize_value(_XComValueWrapper(value)) for value in msg.root] @staticmethod def serialize_value( diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 362b89e3ad0..4b445404d8a 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2014,6 +2014,81 @@ class TestXComAfterTaskExecution: for x in mock_supervisor_comms.send.call_args_list ) + def test_get_all_uses_custom_deserialize_value(self, mock_supervisor_comms): + """ + Tests that XCom.get_all() calls the custom deserialize_value method. + """ + + class CustomXCom(BaseXCom): + @classmethod + def deserialize_value(cls, result): + """Custom deserialization that adds a prefix to show it was called.""" + original_value = super().deserialize_value(result) + return f"from custom xcom deserialize:{original_value}" + + serialized_values = ["value1", "value2", "value3"] + mock_supervisor_comms.send.return_value = XComSequenceSliceResult(root=serialized_values) + + result = CustomXCom.get_all(key="test_key", dag_id="test_dag", task_id="test_task", run_id="test_run") + + expected = [ + "from custom xcom deserialize:value1", + "from custom xcom deserialize:value2", + "from custom xcom deserialize:value3", + ] + assert result == expected + + @pytest.mark.parametrize( + ("include_prior_dates", "expected_value"), + [ + pytest.param(True, True, id="include_prior_dates_true"), + pytest.param(False, False, id="include_prior_dates_false"), + pytest.param(None, False, id="include_prior_dates_default"), + ], + ) + def test_xcom_pull_with_include_prior_dates( + self, + create_runtime_ti, + mock_supervisor_comms, + include_prior_dates, + expected_value, + ): + """Test that xcom_pull with include_prior_dates parameter correctly behaves as we expect.""" + task = BaseOperator(task_id="pull_task") + runtime_ti = create_runtime_ti(task=task) + + value = {"previous_run_data": "test_value"} + ser_value = BaseXCom.serialize_value(value) + + def mock_send_side_effect(*args, **kwargs): + msg = kwargs.get("msg") or args[0] + if isinstance(msg, GetXComSequenceSlice): + assert msg.include_prior_dates is expected_value, ( + f"include_prior_dates should be {expected_value} in GetXComSequenceSlice" + ) + return XComSequenceSliceResult(root=[ser_value]) + return XComResult(key="test_key", value=None) + + mock_supervisor_comms.send.side_effect = mock_send_side_effect + kwargs = {"key": "test_key", "task_ids": "previous_task"} + if include_prior_dates is not None: + kwargs["include_prior_dates"] = include_prior_dates + result = runtime_ti.xcom_pull(**kwargs) + assert result == value + + mock_supervisor_comms.send.assert_called_once_with( + msg=GetXComSequenceSlice( + key="test_key", + dag_id=runtime_ti.dag_id, + run_id=runtime_ti.run_id, + task_id="previous_task", + start=None, + stop=None, + step=None, + include_prior_dates=expected_value, + ), + ) + class TestDagParamRuntime: DEFAULT_ARGS = {
