This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 2b993bea5bea5aec0a0547b258ce2d5bf724282b
Author: Amogh Desai <[email protected]>
AuthorDate: Tue Jul 29 15:04:37 2025 +0530

    Fix custom xcom backend serialize when BaseXCom.get_all is used (#53814)
    
    (cherry picked from commit a8c4ba35351afeef8f15c9627e04241a29054fe4)
---
 task-sdk/src/airflow/sdk/bases/xcom.py             | 11 ++--
 .../task_sdk/execution_time/test_task_runner.py    | 75 ++++++++++++++++++++++
 2 files changed, 82 insertions(+), 4 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py
index 82df8d151ab..423fde12012 100644
--- a/task-sdk/src/airflow/sdk/bases/xcom.py
+++ b/task-sdk/src/airflow/sdk/bases/xcom.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+import collections
 from typing import Any, Protocol
 
 import structlog
@@ -30,6 +31,9 @@ from airflow.sdk.execution_time.comms import (
     XComSequenceSliceResult,
 )
 
+# Lightweight wrapper for XCom values
+_XComValueWrapper = collections.namedtuple("_XComValueWrapper", "value")
+
 log = structlog.get_logger(logger_name="task")
 
 
@@ -290,7 +294,6 @@ class BaseXCom:
         :return: List of all XCom values if found.
         """
         from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
-        from airflow.serialization.serde import deserialize
 
         msg = SUPERVISOR_COMMS.send(
             msg=GetXComSequenceSlice(
@@ -307,10 +310,10 @@ class BaseXCom:
         if not isinstance(msg, XComSequenceSliceResult):
             raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}")
 
-        result = deserialize(msg.root)
-        if not result:
+        if not msg.root:
             return None
-        return result
+
+        return [cls.deserialize_value(_XComValueWrapper(value)) for value in msg.root]
 
     @staticmethod
     def serialize_value(
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 362b89e3ad0..4b445404d8a 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -2014,6 +2014,81 @@ class TestXComAfterTaskExecution:
             for x in mock_supervisor_comms.send.call_args_list
         )
 
+    def test_get_all_uses_custom_deserialize_value(self, mock_supervisor_comms):
+        """
+        Tests that XCom.get_all() calls the custom deserialize_value method.
+        """
+
+        class CustomXCom(BaseXCom):
+            @classmethod
+            def deserialize_value(cls, result):
+                """Custom deserialization that adds a prefix to show it was called."""
+                original_value = super().deserialize_value(result)
+                return f"from custom xcom deserialize:{original_value}"
+
+        serialized_values = ["value1", "value2", "value3"]
+        mock_supervisor_comms.send.return_value = XComSequenceSliceResult(root=serialized_values)
+
+        result = CustomXCom.get_all(key="test_key", dag_id="test_dag", task_id="test_task", run_id="test_run")
+
+        expected = [
+            "from custom xcom deserialize:value1",
+            "from custom xcom deserialize:value2",
+            "from custom xcom deserialize:value3",
+        ]
+        assert result == expected
+
+    @pytest.mark.parametrize(
+        ("include_prior_dates", "expected_value"),
+        [
+            pytest.param(True, True, id="include_prior_dates_true"),
+            pytest.param(False, False, id="include_prior_dates_false"),
+            pytest.param(None, False, id="include_prior_dates_default"),
+        ],
+    )
+    def test_xcom_pull_with_include_prior_dates(
+        self,
+        create_runtime_ti,
+        mock_supervisor_comms,
+        include_prior_dates,
+        expected_value,
+    ):
+        """Test that xcom_pull with include_prior_dates parameter correctly behaves as we expect."""
+        task = BaseOperator(task_id="pull_task")
+        runtime_ti = create_runtime_ti(task=task)
+
+        value = {"previous_run_data": "test_value"}
+        ser_value = BaseXCom.serialize_value(value)
+
+        def mock_send_side_effect(*args, **kwargs):
+            msg = kwargs.get("msg") or args[0]
+            if isinstance(msg, GetXComSequenceSlice):
+                assert msg.include_prior_dates is expected_value, (
+                    f"include_prior_dates should be {expected_value} in GetXComSequenceSlice"
+                )
+                return XComSequenceSliceResult(root=[ser_value])
+            return XComResult(key="test_key", value=None)
+
+        mock_supervisor_comms.send.side_effect = mock_send_side_effect
+        kwargs = {"key": "test_key", "task_ids": "previous_task"}
+        if include_prior_dates is not None:
+            kwargs["include_prior_dates"] = include_prior_dates
+        result = runtime_ti.xcom_pull(**kwargs)
+        assert result == value
+
+        mock_supervisor_comms.send.assert_called_once_with(
+            msg=GetXComSequenceSlice(
+                key="test_key",
+                dag_id=runtime_ti.dag_id,
+                run_id=runtime_ti.run_id,
+                task_id="previous_task",
+                start=None,
+                stop=None,
+                step=None,
+                include_prior_dates=expected_value,
+            ),
+        )
+
 
 class TestDagParamRuntime:
     DEFAULT_ARGS = {

Reply via email to