This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2412792c916 AIP-72: Adding support to Set an XCom from Task SDK
(#44605)
2412792c916 is described below
commit 2412792c916b0c953314c3bf6499c1d02032b0f2
Author: Amogh Desai <[email protected]>
AuthorDate: Fri Dec 6 22:35:50 2024 +0530
AIP-72: Adding support to Set an XCom from Task SDK (#44605)
---
task_sdk/src/airflow/sdk/api/client.py | 17 ++++
task_sdk/src/airflow/sdk/execution_time/comms.py | 36 ++++++-
.../src/airflow/sdk/execution_time/supervisor.py | 4 +
task_sdk/tests/api/test_client.py | 113 ++++++++++++++++++++-
task_sdk/tests/execution_time/test_supervisor.py | 72 +++++++++++--
5 files changed, 230 insertions(+), 12 deletions(-)
diff --git a/task_sdk/src/airflow/sdk/api/client.py
b/task_sdk/src/airflow/sdk/api/client.py
index df0499dac39..5de5c7a8d90 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -176,9 +176,26 @@ class XComOperations:
def get(self, dag_id: str, run_id: str, task_id: str, key: str, map_index:
int = -1) -> XComResponse:
"""Get a XCom value from the API server."""
+ # TODO: check if we need to use map_index as params in the uri
+ # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}",
params={"map_index": map_index})
return XComResponse.model_validate_json(resp.read())
+    def set(
+        self, dag_id: str, run_id: str, task_id: str, key: str, value, map_index: int | None = None
+    ) -> dict[str, bool]:
+        """
+        Set a XCom value via the API server.
+
+        POSTs ``value`` as the JSON body to ``xcoms/{dag_id}/{run_id}/{task_id}/{key}``.
+
+        :param dag_id: DAG id segment of the endpoint path.
+        :param run_id: DAG run id segment of the endpoint path.
+        :param task_id: Task id segment of the endpoint path.
+        :param key: XCom key segment of the endpoint path.
+        :param value: JSON-serialisable payload sent as the request body.
+        :param map_index: Sent as the ``map_index`` query parameter whenever it is not ``None``
+            (including ``0``); omitted from the request when ``None``.
+        :return: Always ``{"ok": True}``; the server's response body is not inspected here.
+        """
+        # TODO: check if we need to use map_index as params in the uri
+        # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
+        # Compare against None instead of relying on truthiness: map_index=0 is falsy
+        # but still has to be forwarded to the server.
+        params = {} if map_index is None else {"map_index": map_index}
+        self.client.post(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params, json=value)
+        # Any error from the server will anyway be propagated down to the supervisor,
+        # so we choose to send a generic response to the supervisor over the server response to
+        # decouple from the server response string
+        return {"ok": True}
+
class BearerAuth(httpx.Auth):
def __init__(self, token: str):
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py
b/task_sdk/src/airflow/sdk/execution_time/comms.py
index c1725e42a4f..65507210041 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -46,7 +46,8 @@ from __future__ import annotations
from datetime import datetime
from typing import Annotated, Literal, Union
-from pydantic import BaseModel, ConfigDict, Field
+from fastapi import Body
+from pydantic import BaseModel, ConfigDict, Field, JsonValue
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
@@ -121,6 +122,37 @@ class GetXCom(BaseModel):
type: Literal["GetXCom"] = "GetXCom"
+class SetXCom(BaseModel):
+ # Message asking the supervisor to store an XCom value. The literal `type`
+ # field below is the tag the ToSupervisor union discriminates on.
+ key: str
+ value: Annotated[
+ # JsonValue accepts plain (non-stringified) dicts, lists and strings, which is
+ # more intuitive for the task to send to the supervisor
+ JsonValue,
+ Body(
+ description="A JSON-formatted string representing the value to set
for the XCom.",
+ openapi_examples={
+ "simple_value": {
+ "summary": "Simple value",
+ "value": "value1",
+ },
+ "dict_value": {
+ "summary": "Dictionary value",
+ "value": {"key2": "value2"},
+ },
+ "list_value": {
+ "summary": "List value",
+ "value": ["value1"],
+ },
+ },
+ ),
+ ]
+ dag_id: str
+ run_id: str
+ task_id: str
+ # Optional; the supervisor passes it through unchanged to client.xcoms.set
+ map_index: int | None = None
+ type: Literal["SetXCom"] = "SetXCom"
+
+
class GetConnection(BaseModel):
conn_id: str
type: Literal["GetConnection"] = "GetConnection"
@@ -139,6 +171,6 @@ class PutVariable(BaseModel):
ToSupervisor = Annotated[
- Union[TaskState, GetXCom, GetConnection, GetVariable, DeferTask,
PutVariable],
+ Union[TaskState, GetXCom, GetConnection, GetVariable, DeferTask,
PutVariable, SetXCom],
Field(discriminator="type"),
]
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 9e1e36c8057..3bc714e6415 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -55,6 +55,7 @@ from airflow.sdk.execution_time.comms import (
GetVariable,
GetXCom,
PutVariable,
+ SetXCom,
StartupDetails,
TaskState,
ToSupervisor,
@@ -675,6 +676,9 @@ class WatchedSubprocess:
self._terminal_state = IntermediateTIState.DEFERRED
self.client.task_instances.defer(self.ti_id, msg)
resp = None
+ elif isinstance(msg, SetXCom):
+ self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id,
msg.key, msg.value, msg.map_index)
+ resp = None
elif isinstance(msg, PutVariable):
self.client.variables.set(msg.key, msg.value, msg.description)
resp = None
diff --git a/task_sdk/tests/api/test_client.py
b/task_sdk/tests/api/test_client.py
index c3c93c8be41..900d1c90de5 100644
--- a/task_sdk/tests/api/test_client.py
+++ b/task_sdk/tests/api/test_client.py
@@ -17,11 +17,13 @@
from __future__ import annotations
+import json
+
import httpx
import pytest
from airflow.sdk.api.client import Client, RemoteValidationError,
ServerResponseError
-from airflow.sdk.api.datamodels._generated import VariableResponse
+from airflow.sdk.api.datamodels._generated import VariableResponse,
XComResponse
class TestClient:
@@ -148,3 +150,112 @@ class TestVariableOperations:
result = client.variables.set(key="test_key", value="test_value",
description="test_description")
assert result == {"ok": True}
+
+
+class TestXCOMOperations:
+ """
+ Test that the XComOperations class works as expected. While the operations are simple, these
+ tests still cover the basic client behaviour for XComs, including the endpoint path and
+ response parsing.
+ """
+
+ def test_xcom_get_success(self):
+ # Simulate a successful response from the server when getting an xcom
+ # NOTE(review): the mocked GET replies with 201 rather than 200 - confirm this is intentional
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == "/xcoms/dag_id/run_id/task_id/key":
+ return httpx.Response(
+ status_code=201,
+ json={"key": "test_key", "value": "test_value"},
+ )
+ return httpx.Response(status_code=400, json={"detail": "Bad
Request"})
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.xcoms.get(
+ dag_id="dag_id",
+ run_id="run_id",
+ task_id="task_id",
+ key="key",
+ )
+ assert isinstance(result, XComResponse)
+ assert result.key == "test_key"
+ assert result.value == "test_value"
+
+ def test_xcom_get_success_with_map_index(self):
+ # Simulate a successful response from the server when getting an xcom
+ # with map_index passed; the handler only succeeds when the query string
+ # carries map_index=2
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if (
+ request.url.path == "/xcoms/dag_id/run_id/task_id/key"
+ and request.url.params.get("map_index") == "2"
+ ):
+ return httpx.Response(
+ status_code=201,
+ json={"key": "test_key", "value": "test_value"},
+ )
+ return httpx.Response(status_code=400, json={"detail": "Bad
Request"})
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.xcoms.get(
+ dag_id="dag_id",
+ run_id="run_id",
+ task_id="task_id",
+ key="key",
+ map_index=2,
+ )
+ assert isinstance(result, XComResponse)
+ assert result.key == "test_key"
+ assert result.value == "test_value"
+
+ @pytest.mark.parametrize(
+ "values",
+ [
+ pytest.param("value1", id="string-value"),
+ pytest.param({"key1": "value1"}, id="dict-value"),
+ pytest.param(["value1", "value2"], id="list-value"),
+ pytest.param({"key": "test_key", "value": {"key2": "value2"}},
id="nested-dict-value"),
+ ],
+ )
+ def test_xcom_set_success(self, values):
+ # Simulate a successful response from the server when setting an xcom;
+ # the handler also asserts that the request body decodes back to the value passed in
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == "/xcoms/dag_id/run_id/task_id/key":
+ assert json.loads(request.read()) == values
+ return httpx.Response(
+ status_code=201,
+ json={"message": "XCom successfully set"},
+ )
+ return httpx.Response(status_code=400, json={"detail": "Bad
Request"})
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.xcoms.set(
+ dag_id="dag_id",
+ run_id="run_id",
+ task_id="task_id",
+ key="key",
+ value=values,
+ )
+ assert result == {"ok": True}
+
+ def test_xcom_set_with_map_index(self):
+ # Simulate a successful response from the server when setting an xcom
+ # with map_index passed
+ # NOTE(review): only map_index=2 is exercised here; map_index=0 is not covered
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if (
+ request.url.path == "/xcoms/dag_id/run_id/task_id/key"
+ and request.url.params.get("map_index") == "2"
+ ):
+ assert json.loads(request.read()) == "value1"
+ return httpx.Response(
+ status_code=201,
+ json={"message": "XCom successfully set"},
+ )
+ return httpx.Response(status_code=400, json={"detail": "Bad
Request"})
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.xcoms.set(
+ dag_id="dag_id",
+ run_id="run_id",
+ task_id="task_id",
+ key="key",
+ value="value1",
+ map_index=2,
+ )
+ assert result == {"ok": True}
diff --git a/task_sdk/tests/execution_time/test_supervisor.py
b/task_sdk/tests/execution_time/test_supervisor.py
index 406480b653d..6bd3047de67 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -44,6 +44,7 @@ from airflow.sdk.execution_time.comms import (
GetVariable,
GetXCom,
PutVariable,
+ SetXCom,
VariableResult,
XComResult,
)
@@ -788,6 +789,22 @@ class TestHandleRequest:
VariableResult(key="test_key", value="test_value"),
id="get_variable",
),
+ pytest.param(
+ PutVariable(key="test_key", value="test_value",
description="test_description"),
+ b"",
+ "variables.set",
+ ("test_key", "test_value", "test_description"),
+ {"ok": True},
+ id="set_variable",
+ ),
+ pytest.param(
+ DeferTask(next_method="execute_callback",
classpath="my-classpath"),
+ b"",
+ "task_instances.defer",
+ (TI_ID, DeferTask(next_method="execute_callback",
classpath="my-classpath")),
+ "",
+ id="patch_task_instance_to_deferred",
+ ),
pytest.param(
GetXCom(dag_id="test_dag", run_id="test_run",
task_id="test_task", key="test_key"),
b'{"key":"test_key","value":"test_value"}\n',
@@ -797,20 +814,57 @@ class TestHandleRequest:
id="get_xcom",
),
pytest.param(
- DeferTask(next_method="execute_callback",
classpath="my-classpath"),
+ GetXCom(
+ dag_id="test_dag", run_id="test_run", task_id="test_task",
key="test_key", map_index=2
+ ),
+ b'{"key":"test_key","value":"test_value"}\n',
+ "xcoms.get",
+ ("test_dag", "test_run", "test_task", "test_key", 2),
+ XComResult(key="test_key", value="test_value"),
+ id="get_xcom_map_index",
+ ),
+ pytest.param(
+ SetXCom(
+ dag_id="test_dag",
+ run_id="test_run",
+ task_id="test_task",
+ key="test_key",
+ value='{"key": "test_key", "value": {"key2": "value2"}}',
+ ),
b"",
- "task_instances.defer",
- (TI_ID, DeferTask(next_method="execute_callback",
classpath="my-classpath")),
- "",
- id="patch_task_instance_to_deferred",
+ "xcoms.set",
+ (
+ "test_dag",
+ "test_run",
+ "test_task",
+ "test_key",
+ '{"key": "test_key", "value": {"key2": "value2"}}',
+ None,
+ ),
+ {"ok": True},
+ id="set_xcom",
),
pytest.param(
- PutVariable(key="test_key", value="test_value",
description="test_description"),
+ SetXCom(
+ dag_id="test_dag",
+ run_id="test_run",
+ task_id="test_task",
+ key="test_key",
+ value='{"key": "test_key", "value": {"key2": "value2"}}',
+ map_index=2,
+ ),
b"",
- "variables.set",
- ("test_key", "test_value", "test_description"),
+ "xcoms.set",
+ (
+ "test_dag",
+ "test_run",
+ "test_task",
+ "test_key",
+ '{"key": "test_key", "value": {"key2": "value2"}}',
+ 2,
+ ),
{"ok": True},
- id="set_variable",
+ id="set_xcom_with_map_index",
),
],
)