This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 59137f980fd Fix bug in XComModel deserialization for value serialized
in Task SDK (#47961)
59137f980fd is described below
commit 59137f980fd9baad6a2acc50c0f0eea28f042a8c
Author: Kaxil Naik <[email protected]>
AuthorDate: Wed Mar 19 23:07:44 2025 +0530
Fix bug in XComModel deserialization for value serialized in Task SDK
(#47961)
Currently, the Task SDK uses the `airflow.serialization` module to serialize
XCom values during execution. This is because
we have agreed on a contract in which the execution client (Python, Golang
and other SDK clients) is responsible for serialization and deserialization.
To support this, as part of https://github.com/apache/airflow/issues/45481, we
skip XCom serialization in XComModel and instead store the value directly in the DB.
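This works well when deserialization is done on the execution side itself.
However, in some cases the value is deserialized only on the Scheduler or API server,
for example:
- Extra Operator Links
- Branching & skipping logic in `NotPreviouslySkippedDep`, where we pull
XComs.
This change makes `XComModel.deserialize_value` fall back to returning the raw
stored value when it cannot be decoded with `XComDecoder`.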
closes https://github.com/apache/airflow/issues/47907
closes https://github.com/apache/airflow/pull/47918
---
airflow/models/xcom.py | 37 +++++++++++-
.../api_fastapi/execution_api/routes/test_xcoms.py | 65 +++++++++++++++++++++-
tests/models/test_xcom.py | 51 +++++++++++++++++
3 files changed, 150 insertions(+), 3 deletions(-)
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index a8f1d3ee310..e2fdf49ad6f 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -331,11 +331,44 @@ class XComModel(TaskInstanceDependencies):
@staticmethod
def deserialize_value(result) -> Any:
- """Deserialize XCom value from str objects."""
+ """
+ Deserialize XCom value from a database result.
+
+ If deserialization fails, the raw value is returned, which must still
be a valid Python JSON-compatible
+ type (e.g., ``dict``, ``list``, ``str``, ``int``, ``float``, or
``bool``).
+
+ XCom values are stored as JSON in the database, and SQLAlchemy
automatically handles
+ serialization (``json.dumps``) and deserialization (``json.loads``).
However, we
+ use a custom encoder for serialization (``serialize_value``) and
deserialization to handle special
+ cases, such as encoding tuples via the Airflow Serialization module.
These must be decoded
+ using ``XComDecoder`` to restore original types.
+
+ Some XCom values, such as those set via the Task Execution API, bypass
``serialize_value``
+ and are stored directly in JSON format. Since these values are already
deserialized
+ by SQLAlchemy, they are returned as-is.
+
+ **Example: Handling a tuple**:
+
+ .. code-block:: python
+
+ original_value = (1, 2, 3)
+ serialized_value = XComModel.serialize_value(original_value)
+ print(serialized_value)
+ # '{"__classname__": "builtins.tuple", "__version__": 1,
"__data__": [1, 2, 3]}'
+
+ This serialized value is stored in the database. When deserialized,
the value is restored to the original tuple.
+
+ :param result: The XCom database row or object containing a ``value``
attribute.
+ :return: The deserialized Python object.
+ """
if result.value is None:
return None
- return json.loads(result.value, cls=XComDecoder)
+ try:
+ return json.loads(result.value, cls=XComDecoder)
+ except (ValueError, TypeError):
+ # Already deserialized (e.g., set via Task Execution API)
+ return result.value
class LazyXComSelectSequence(LazySelectSequence[Any]):
diff --git a/tests/api_fastapi/execution_api/routes/test_xcoms.py
b/tests/api_fastapi/execution_api/routes/test_xcoms.py
index ae21fefdbf4..c2b49841b3a 100644
--- a/tests/api_fastapi/execution_api/routes/test_xcoms.py
+++ b/tests/api_fastapi/execution_api/routes/test_xcoms.py
@@ -29,7 +29,7 @@ from airflow.api_fastapi.execution_api.datamodels.xcom import
XComResponse
from airflow.models.dagrun import DagRun
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XComModel
-from airflow.serialization.serde import serialize
+from airflow.serialization.serde import deserialize, serialize
from airflow.utils.session import create_session
pytestmark = pytest.mark.db_test
@@ -163,6 +163,69 @@ class TestXComsSetEndpoint:
task_map = session.query(TaskMap).filter_by(task_id=ti.task_id,
dag_id=ti.dag_id).one_or_none()
assert task_map is None, "Should not be mapped"
+ @pytest.mark.parametrize(
+ "orig_value, ser_value, deser_value",
+ [
+ pytest.param(1, 1, 1, id="int"),
+ pytest.param(1.0, 1.0, 1.0, id="float"),
+ pytest.param("string", "string", "string", id="str"),
+ pytest.param(True, True, True, id="bool"),
+ pytest.param({"key": "value"}, {"key": "value"}, {"key": "value"},
id="dict"),
+ pytest.param([1, 2], [1, 2], [1, 2], id="list"),
+ pytest.param(
+ (1, 2),
+ # Client serializes tuple as encoded list, send the encoded
list to the API
+ {"__classname__": "builtins.tuple", "__data__": [1, 2],
"__version__": 1},
+ # The API will send the encoded list to the DB and sends the
same encoded list back
+ # during the response to the client as it is the clients
responsibility to
+ # serialize it into a JSON object & deserialize value into a
native object.
+ {"__classname__": "builtins.tuple", "__data__": [1, 2],
"__version__": 1},
+ id="tuple",
+ ),
+ ],
+ )
+ def test_xcom_round_trip(self, client, create_task_instance, session,
orig_value, ser_value, deser_value):
+ """
+ Test that deserialization works when XCom values are stored directly
in the DB with API Server.
+
+ This tests the case where the XCom value is stored from the Task API
where the value is serialized
+ via Client SDK into JSON object and passed via the API Server to the
DB. It by-passes
+ the XComModel.serialize_value and stores valid Python JSON compatible
objects to DB.
+
+ This test is to ensure that the deserialization works correctly in
this case as well as
+ checks that the value is stored correctly before it hits the API.
+ """
+
+ ti = create_task_instance()
+ session.commit()
+
+ # Serialize the value to simulate the client SDK
+ value = serialize(orig_value)
+
+ # Test that the value is serialized correctly
+ assert value == ser_value
+
+ response = client.post(
+ f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/xcom_1",
+ json=value,
+ )
+
+ assert response.status_code == 201
+
+ stored_value = XComModel.get_many(
+ key="xcom_1",
+ dag_ids=ti.dag_id,
+ task_ids=ti.task_id,
+ run_id=ti.run_id,
+ session=session,
+ ).first()
+ deserialized_value = XComModel.deserialize_value(stored_value)
+
+ assert deserialized_value == deser_value
+
+ # Ensure that the deserialized value on the client side is the same as
the original value
+ assert deserialize(deserialized_value) == orig_value
+
def test_xcom_set_mapped(self, client, create_task_instance, session):
ti = create_task_instance()
session.commit()
diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py
index c01c6a88c05..ce407d1fc69 100644
--- a/tests/models/test_xcom.py
+++ b/tests/models/test_xcom.py
@@ -370,3 +370,54 @@ class TestXComClear:
session=session,
)
assert session.query(XComModel).count() == 1
+
+
+class TestXComRoundTrip:
+ @pytest.mark.parametrize(
+ "value, expected_value",
+ [
+ pytest.param(1, 1, id="int"),
+ pytest.param(1.0, 1.0, id="float"),
+ pytest.param("string", "string", id="str"),
+ pytest.param(True, True, id="bool"),
+ pytest.param({"key": "value"}, {"key": "value"}, id="dict"),
+ pytest.param([1, 2, 3], [1, 2, 3], id="list"),
+ pytest.param((1, 2, 3), (1, 2, 3), id="tuple"), # tuple is
preserved
+ pytest.param(None, None, id="none"),
+ ],
+ )
+ def test_xcom_round_trip(self, value, expected_value,
push_simple_json_xcom, task_instance, session):
+ """Test that XComModel serialization and deserialization work as
expected."""
+ push_simple_json_xcom(ti=task_instance, key="xcom_1", value=value)
+
+ stored_value = XComModel.get_many(
+ key="xcom_1",
+ dag_ids=task_instance.dag_id,
+ task_ids=task_instance.task_id,
+ run_id=task_instance.run_id,
+ session=session,
+ ).first()
+ deserialized_value = XComModel.deserialize_value(stored_value)
+
+ assert deserialized_value == expected_value
+
+ @pytest.mark.parametrize(
+ "value, expected_value",
+ [
+ pytest.param(1, 1, id="int"),
+ pytest.param(1.0, 1.0, id="float"),
+ pytest.param("string", "string", id="str"),
+ pytest.param(True, True, id="bool"),
+ pytest.param({"key": "value"}, {"key": "value"}, id="dict"),
+ pytest.param([1, 2, 3], [1, 2, 3], id="list"),
+ pytest.param((1, 2, 3), (1, 2, 3), id="tuple"), # tuple is
preserved
+ pytest.param(None, None, id="none"),
+ ],
+ )
+ def test_xcom_deser_fallback(self, value, expected_value):
+ """Test fallback in deserialization."""
+
+ mock_xcom = MagicMock(value=value)
+ deserialized_value = XComModel.deserialize_value(mock_xcom)
+
+ assert deserialized_value == expected_value