This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 59137f980fd Fix bug in XComModel deserialization for value serialized in Task SDK (#47961)
59137f980fd is described below

commit 59137f980fd9baad6a2acc50c0f0eea28f042a8c
Author: Kaxil Naik <[email protected]>
AuthorDate: Wed Mar 19 23:07:44 2025 +0530

    Fix bug in XComModel deserialization for value serialized in Task SDK (#47961)
    
    Currently, the Task SDK uses the `airflow.serialization` module to serialize XCom values during execution. This is
    because we have agreed on a contract under which the execution client (Python, Golang, and other SDK clients) is
    responsible for serialization and deserialization.
    
    To support this, as part of https://github.com/apache/airflow/issues/45481, we skip XCom serialization in
    `XComModel` and instead store the value directly in the DB.
    
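    This works well when deserialization happens on the execution side itself. However, in some cases the value is
    deserialized only on the Scheduler or API server, for example:
    - Extra Operator Links
    - Branching and skipping logic in `NotPreviouslySkippedDep`, where we pull XComs.
    
    To handle these cases, `XComModel.deserialize_value` now returns the stored value as-is when it cannot be decoded
    with `json.loads`, e.g. because it was stored directly via the Task Execution API rather than through
    `XComModel.serialize_value`.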
    
    closes https://github.com/apache/airflow/issues/47907
    closes https://github.com/apache/airflow/pull/47918
---
 airflow/models/xcom.py                             | 37 +++++++++++-
 .../api_fastapi/execution_api/routes/test_xcoms.py | 65 +++++++++++++++++++++-
 tests/models/test_xcom.py                          | 51 +++++++++++++++++
 3 files changed, 150 insertions(+), 3 deletions(-)

diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index a8f1d3ee310..e2fdf49ad6f 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -331,11 +331,44 @@ class XComModel(TaskInstanceDependencies):
 
     @staticmethod
     def deserialize_value(result) -> Any:
-        """Deserialize XCom value from str objects."""
+        """
+        Deserialize XCom value from a database result.
+
+        If deserialization fails, the raw value is returned, which must still 
be a valid Python JSON-compatible
+        type (e.g., ``dict``, ``list``, ``str``, ``int``, ``float``, or 
``bool``).
+
+        XCom values are stored as JSON in the database, and SQLAlchemy 
automatically handles
+        serialization (``json.dumps``) and deserialization (``json.loads``). 
However, we
+        use a custom encoder for serialization (``serialize_value``) and 
deserialization to handle special
+        cases, such as encoding tuples via the Airflow Serialization module. 
These must be decoded
+        using ``XComDecoder`` to restore original types.
+
+        Some XCom values, such as those set via the Task Execution API, bypass 
``serialize_value``
+        and are stored directly in JSON format. Since these values are already 
deserialized
+        by SQLAlchemy, they are returned as-is.
+
+        **Example: Handling a tuple**:
+
+        .. code-block:: python
+
+            original_value = (1, 2, 3)
+            serialized_value = XComModel.serialize_value(original_value)
+            print(serialized_value)
+            # '{"__classname__": "builtins.tuple", "__version__": 1, 
"__data__": [1, 2, 3]}'
+
+        This serialized value is stored in the database. When deserialized, 
the value is restored to the original tuple.
+
+        :param result: The XCom database row or object containing a ``value`` 
attribute.
+        :return: The deserialized Python object.
+        """
         if result.value is None:
             return None
 
-        return json.loads(result.value, cls=XComDecoder)
+        try:
+            return json.loads(result.value, cls=XComDecoder)
+        except (ValueError, TypeError):
+            # Already deserialized (e.g., set via Task Execution API)
+            return result.value
 
 
 class LazyXComSelectSequence(LazySelectSequence[Any]):
diff --git a/tests/api_fastapi/execution_api/routes/test_xcoms.py 
b/tests/api_fastapi/execution_api/routes/test_xcoms.py
index ae21fefdbf4..c2b49841b3a 100644
--- a/tests/api_fastapi/execution_api/routes/test_xcoms.py
+++ b/tests/api_fastapi/execution_api/routes/test_xcoms.py
@@ -29,7 +29,7 @@ from airflow.api_fastapi.execution_api.datamodels.xcom import 
XComResponse
 from airflow.models.dagrun import DagRun
 from airflow.models.taskmap import TaskMap
 from airflow.models.xcom import XComModel
-from airflow.serialization.serde import serialize
+from airflow.serialization.serde import deserialize, serialize
 from airflow.utils.session import create_session
 
 pytestmark = pytest.mark.db_test
@@ -163,6 +163,69 @@ class TestXComsSetEndpoint:
         task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, 
dag_id=ti.dag_id).one_or_none()
         assert task_map is None, "Should not be mapped"
 
+    @pytest.mark.parametrize(
+        "orig_value, ser_value, deser_value",
+        [
+            pytest.param(1, 1, 1, id="int"),
+            pytest.param(1.0, 1.0, 1.0, id="float"),
+            pytest.param("string", "string", "string", id="str"),
+            pytest.param(True, True, True, id="bool"),
+            pytest.param({"key": "value"}, {"key": "value"}, {"key": "value"}, 
id="dict"),
+            pytest.param([1, 2], [1, 2], [1, 2], id="list"),
+            pytest.param(
+                (1, 2),
+                # Client serializes tuple as encoded list, send the encoded 
list to the API
+                {"__classname__": "builtins.tuple", "__data__": [1, 2], 
"__version__": 1},
+                # The API will send the encoded list to the DB and sends the 
same encoded list back
+                # during the response to the client as it is the clients 
responsibility to
+                # serialize it into a JSON object & deserialize value into a 
native object.
+                {"__classname__": "builtins.tuple", "__data__": [1, 2], 
"__version__": 1},
+                id="tuple",
+            ),
+        ],
+    )
+    def test_xcom_round_trip(self, client, create_task_instance, session, 
orig_value, ser_value, deser_value):
+        """
+        Test that deserialization works when XCom values are stored directly 
in the DB with API Server.
+
+        This tests the case where the XCom value is stored from the Task API 
where the value is serialized
+        via Client SDK into JSON object and passed via the API Server to the 
DB. It by-passes
+        the XComModel.serialize_value and stores valid Python JSON compatible 
objects to DB.
+
+        This test is to ensure that the deserialization works correctly in 
this case as well as
+        checks that the value is stored correctly before it hits the API.
+        """
+
+        ti = create_task_instance()
+        session.commit()
+
+        # Serialize the value to simulate the client SDK
+        value = serialize(orig_value)
+
+        # Test that the value is serialized correctly
+        assert value == ser_value
+
+        response = client.post(
+            f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/xcom_1",
+            json=value,
+        )
+
+        assert response.status_code == 201
+
+        stored_value = XComModel.get_many(
+            key="xcom_1",
+            dag_ids=ti.dag_id,
+            task_ids=ti.task_id,
+            run_id=ti.run_id,
+            session=session,
+        ).first()
+        deserialized_value = XComModel.deserialize_value(stored_value)
+
+        assert deserialized_value == deser_value
+
+        # Ensure that the deserialized value on the client side is the same as 
the original value
+        assert deserialize(deserialized_value) == orig_value
+
     def test_xcom_set_mapped(self, client, create_task_instance, session):
         ti = create_task_instance()
         session.commit()
diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py
index c01c6a88c05..ce407d1fc69 100644
--- a/tests/models/test_xcom.py
+++ b/tests/models/test_xcom.py
@@ -370,3 +370,54 @@ class TestXComClear:
             session=session,
         )
         assert session.query(XComModel).count() == 1
+
+
+class TestXComRoundTrip:
+    @pytest.mark.parametrize(
+        "value, expected_value",
+        [
+            pytest.param(1, 1, id="int"),
+            pytest.param(1.0, 1.0, id="float"),
+            pytest.param("string", "string", id="str"),
+            pytest.param(True, True, id="bool"),
+            pytest.param({"key": "value"}, {"key": "value"}, id="dict"),
+            pytest.param([1, 2, 3], [1, 2, 3], id="list"),
+            pytest.param((1, 2, 3), (1, 2, 3), id="tuple"),  # tuple is 
preserved
+            pytest.param(None, None, id="none"),
+        ],
+    )
+    def test_xcom_round_trip(self, value, expected_value, 
push_simple_json_xcom, task_instance, session):
+        """Test that XComModel serialization and deserialization work as 
expected."""
+        push_simple_json_xcom(ti=task_instance, key="xcom_1", value=value)
+
+        stored_value = XComModel.get_many(
+            key="xcom_1",
+            dag_ids=task_instance.dag_id,
+            task_ids=task_instance.task_id,
+            run_id=task_instance.run_id,
+            session=session,
+        ).first()
+        deserialized_value = XComModel.deserialize_value(stored_value)
+
+        assert deserialized_value == expected_value
+
+    @pytest.mark.parametrize(
+        "value, expected_value",
+        [
+            pytest.param(1, 1, id="int"),
+            pytest.param(1.0, 1.0, id="float"),
+            pytest.param("string", "string", id="str"),
+            pytest.param(True, True, id="bool"),
+            pytest.param({"key": "value"}, {"key": "value"}, id="dict"),
+            pytest.param([1, 2, 3], [1, 2, 3], id="list"),
+            pytest.param((1, 2, 3), (1, 2, 3), id="tuple"),  # tuple is 
preserved
+            pytest.param(None, None, id="none"),
+        ],
+    )
+    def test_xcom_deser_fallback(self, value, expected_value):
+        """Test fallback in deserialization."""
+
+        mock_xcom = MagicMock(value=value)
+        deserialized_value = XComModel.deserialize_value(mock_xcom)
+
+        assert deserialized_value == expected_value

Reply via email to