This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2412792c916 AIP-72: Adding support to Set an XCom from Task SDK (#44605)
2412792c916 is described below

commit 2412792c916b0c953314c3bf6499c1d02032b0f2
Author: Amogh Desai <[email protected]>
AuthorDate: Fri Dec 6 22:35:50 2024 +0530

    AIP-72: Adding support to Set an XCom from Task SDK (#44605)
---
 task_sdk/src/airflow/sdk/api/client.py             |  17 ++++
 task_sdk/src/airflow/sdk/execution_time/comms.py   |  36 ++++++-
 .../src/airflow/sdk/execution_time/supervisor.py   |   4 +
 task_sdk/tests/api/test_client.py                  | 113 ++++++++++++++++++++-
 task_sdk/tests/execution_time/test_supervisor.py   |  72 +++++++++++--
 5 files changed, 230 insertions(+), 12 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py
index df0499dac39..5de5c7a8d90 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -176,9 +176,26 @@ class XComOperations:
 
     def get(self, dag_id: str, run_id: str, task_id: str, key: str, map_index: 
int = -1) -> XComResponse:
         """Get a XCom value from the API server."""
+        # TODO: check if we need to use map_index as params in the uri
+        # ref: 
https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
         resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", 
params={"map_index": map_index})
         return XComResponse.model_validate_json(resp.read())
 
+    def set(
+        self, dag_id: str, run_id: str, task_id: str, key: str, value, 
map_index: int | None = None
+    ) -> dict[str, bool]:
+        """Set a XCom value via the API server."""
+        # TODO: check if we need to use map_index as params in the uri
+        # ref: 
https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
+        params = {}
+        if map_index:
+            params = {"map_index": map_index}
+        self.client.post(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", 
params=params, json=value)
+        # Any error from the server will anyway be propagated down to the 
supervisor,
+        # so we choose to send a generic response to the supervisor over the 
server response to
+        # decouple from the server response string
+        return {"ok": True}
+
 
 class BearerAuth(httpx.Auth):
     def __init__(self, token: str):
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py
index c1725e42a4f..65507210041 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -46,7 +46,8 @@ from __future__ import annotations
 from datetime import datetime
 from typing import Annotated, Literal, Union
 
-from pydantic import BaseModel, ConfigDict, Field
+from fastapi import Body
+from pydantic import BaseModel, ConfigDict, Field, JsonValue
 
 from airflow.sdk.api.datamodels._generated import (
     ConnectionResponse,
@@ -121,6 +122,37 @@ class GetXCom(BaseModel):
     type: Literal["GetXCom"] = "GetXCom"
 
 
+class SetXCom(BaseModel):
+    key: str
+    value: Annotated[
+        # JsonValue can handle non JSON stringified dicts, lists and strings, 
which is better
+        # for the task intuitibe to send to the supervisor
+        JsonValue,
+        Body(
+            description="A JSON-formatted string representing the value to set 
for the XCom.",
+            openapi_examples={
+                "simple_value": {
+                    "summary": "Simple value",
+                    "value": "value1",
+                },
+                "dict_value": {
+                    "summary": "Dictionary value",
+                    "value": {"key2": "value2"},
+                },
+                "list_value": {
+                    "summary": "List value",
+                    "value": ["value1"],
+                },
+            },
+        ),
+    ]
+    dag_id: str
+    run_id: str
+    task_id: str
+    map_index: int | None = None
+    type: Literal["SetXCom"] = "SetXCom"
+
+
 class GetConnection(BaseModel):
     conn_id: str
     type: Literal["GetConnection"] = "GetConnection"
@@ -139,6 +171,6 @@ class PutVariable(BaseModel):
 
 
 ToSupervisor = Annotated[
-    Union[TaskState, GetXCom, GetConnection, GetVariable, DeferTask, 
PutVariable],
+    Union[TaskState, GetXCom, GetConnection, GetVariable, DeferTask, 
PutVariable, SetXCom],
     Field(discriminator="type"),
 ]
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 9e1e36c8057..3bc714e6415 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -55,6 +55,7 @@ from airflow.sdk.execution_time.comms import (
     GetVariable,
     GetXCom,
     PutVariable,
+    SetXCom,
     StartupDetails,
     TaskState,
     ToSupervisor,
@@ -675,6 +676,9 @@ class WatchedSubprocess:
                 self._terminal_state = IntermediateTIState.DEFERRED
                 self.client.task_instances.defer(self.ti_id, msg)
                 resp = None
+            elif isinstance(msg, SetXCom):
+                self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, 
msg.key, msg.value, msg.map_index)
+                resp = None
             elif isinstance(msg, PutVariable):
                 self.client.variables.set(msg.key, msg.value, msg.description)
                 resp = None
diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py
index c3c93c8be41..900d1c90de5 100644
--- a/task_sdk/tests/api/test_client.py
+++ b/task_sdk/tests/api/test_client.py
@@ -17,11 +17,13 @@
 
 from __future__ import annotations
 
+import json
+
 import httpx
 import pytest
 
 from airflow.sdk.api.client import Client, RemoteValidationError, 
ServerResponseError
-from airflow.sdk.api.datamodels._generated import VariableResponse
+from airflow.sdk.api.datamodels._generated import VariableResponse, 
XComResponse
 
 
 class TestClient:
@@ -148,3 +150,112 @@ class TestVariableOperations:
 
         result = client.variables.set(key="test_key", value="test_value", 
description="test_description")
         assert result == {"ok": True}
+
+
+class TestXCOMOperations:
+    """
+    Test that the XComOperations class works as expected. While the operations 
are simple, it
+    still catches the basic functionality of the client for xcoms including 
endpoint and
+    response parsing.
+    """
+
+    def test_xcom_get_success(self):
+        # Simulate a successful response from the server when getting an xcom
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == "/xcoms/dag_id/run_id/task_id/key":
+                return httpx.Response(
+                    status_code=201,
+                    json={"key": "test_key", "value": "test_value"},
+                )
+            return httpx.Response(status_code=400, json={"detail": "Bad 
Request"})
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.xcoms.get(
+            dag_id="dag_id",
+            run_id="run_id",
+            task_id="task_id",
+            key="key",
+        )
+        assert isinstance(result, XComResponse)
+        assert result.key == "test_key"
+        assert result.value == "test_value"
+
+    def test_xcom_get_success_with_map_index(self):
+        # Simulate a successful response from the server when getting an xcom 
with map_index passed
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if (
+                request.url.path == "/xcoms/dag_id/run_id/task_id/key"
+                and request.url.params.get("map_index") == "2"
+            ):
+                return httpx.Response(
+                    status_code=201,
+                    json={"key": "test_key", "value": "test_value"},
+                )
+            return httpx.Response(status_code=400, json={"detail": "Bad 
Request"})
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.xcoms.get(
+            dag_id="dag_id",
+            run_id="run_id",
+            task_id="task_id",
+            key="key",
+            map_index=2,
+        )
+        assert isinstance(result, XComResponse)
+        assert result.key == "test_key"
+        assert result.value == "test_value"
+
+    @pytest.mark.parametrize(
+        "values",
+        [
+            pytest.param("value1", id="string-value"),
+            pytest.param({"key1": "value1"}, id="dict-value"),
+            pytest.param(["value1", "value2"], id="list-value"),
+            pytest.param({"key": "test_key", "value": {"key2": "value2"}}, 
id="nested-dict-value"),
+        ],
+    )
+    def test_xcom_set_success(self, values):
+        # Simulate a successful response from the server when setting an xcom
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == "/xcoms/dag_id/run_id/task_id/key":
+                assert json.loads(request.read()) == values
+                return httpx.Response(
+                    status_code=201,
+                    json={"message": "XCom successfully set"},
+                )
+            return httpx.Response(status_code=400, json={"detail": "Bad 
Request"})
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.xcoms.set(
+            dag_id="dag_id",
+            run_id="run_id",
+            task_id="task_id",
+            key="key",
+            value=values,
+        )
+        assert result == {"ok": True}
+
+    def test_xcom_set_with_map_index(self):
+        # Simulate a successful response from the server when setting an xcom 
with map_index passed
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if (
+                request.url.path == "/xcoms/dag_id/run_id/task_id/key"
+                and request.url.params.get("map_index") == "2"
+            ):
+                assert json.loads(request.read()) == "value1"
+                return httpx.Response(
+                    status_code=201,
+                    json={"message": "XCom successfully set"},
+                )
+            return httpx.Response(status_code=400, json={"detail": "Bad 
Request"})
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.xcoms.set(
+            dag_id="dag_id",
+            run_id="run_id",
+            task_id="task_id",
+            key="key",
+            value="value1",
+            map_index=2,
+        )
+        assert result == {"ok": True}
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 406480b653d..6bd3047de67 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -44,6 +44,7 @@ from airflow.sdk.execution_time.comms import (
     GetVariable,
     GetXCom,
     PutVariable,
+    SetXCom,
     VariableResult,
     XComResult,
 )
@@ -788,6 +789,22 @@ class TestHandleRequest:
                 VariableResult(key="test_key", value="test_value"),
                 id="get_variable",
             ),
+            pytest.param(
+                PutVariable(key="test_key", value="test_value", 
description="test_description"),
+                b"",
+                "variables.set",
+                ("test_key", "test_value", "test_description"),
+                {"ok": True},
+                id="set_variable",
+            ),
+            pytest.param(
+                DeferTask(next_method="execute_callback", 
classpath="my-classpath"),
+                b"",
+                "task_instances.defer",
+                (TI_ID, DeferTask(next_method="execute_callback", 
classpath="my-classpath")),
+                "",
+                id="patch_task_instance_to_deferred",
+            ),
             pytest.param(
                 GetXCom(dag_id="test_dag", run_id="test_run", 
task_id="test_task", key="test_key"),
                 b'{"key":"test_key","value":"test_value"}\n',
@@ -797,20 +814,57 @@ class TestHandleRequest:
                 id="get_xcom",
             ),
             pytest.param(
-                DeferTask(next_method="execute_callback", 
classpath="my-classpath"),
+                GetXCom(
+                    dag_id="test_dag", run_id="test_run", task_id="test_task", 
key="test_key", map_index=2
+                ),
+                b'{"key":"test_key","value":"test_value"}\n',
+                "xcoms.get",
+                ("test_dag", "test_run", "test_task", "test_key", 2),
+                XComResult(key="test_key", value="test_value"),
+                id="get_xcom_map_index",
+            ),
+            pytest.param(
+                SetXCom(
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    key="test_key",
+                    value='{"key": "test_key", "value": {"key2": "value2"}}',
+                ),
                 b"",
-                "task_instances.defer",
-                (TI_ID, DeferTask(next_method="execute_callback", 
classpath="my-classpath")),
-                "",
-                id="patch_task_instance_to_deferred",
+                "xcoms.set",
+                (
+                    "test_dag",
+                    "test_run",
+                    "test_task",
+                    "test_key",
+                    '{"key": "test_key", "value": {"key2": "value2"}}',
+                    None,
+                ),
+                {"ok": True},
+                id="set_xcom",
             ),
             pytest.param(
-                PutVariable(key="test_key", value="test_value", 
description="test_description"),
+                SetXCom(
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    key="test_key",
+                    value='{"key": "test_key", "value": {"key2": "value2"}}',
+                    map_index=2,
+                ),
                 b"",
-                "variables.set",
-                ("test_key", "test_value", "test_description"),
+                "xcoms.set",
+                (
+                    "test_dag",
+                    "test_run",
+                    "test_task",
+                    "test_key",
+                    '{"key": "test_key", "value": {"key2": "value2"}}',
+                    2,
+                ),
                 {"ok": True},
-                id="set_variable",
+                id="set_xcom_with_map_index",
             ),
         ],
     )

Reply via email to