This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch add-taskmapping-to-setxcom-execapi
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 59b89800ea5076548f4b8ee205268ef2dcebaff5
Author: Ash Berlin-Taylor <ash_git...@firemirror.com>
AuthorDate: Fri Jan 31 14:06:54 2025 +0000

    Create TaskMap rows when pushing XCom values from the Task Execution Interface
    
    This is needed so that the scheduler can create the expanded TIs for the
    downstream mapped tasks.
    
    Note that we don't enforce on the server side that this is set when it needs
    to be (nor, conversely, that it isn't set when it doesn't need to be -- which
    has no effect other than creating a small DB row that nothing will ever
    read). The only way for us to know would be to load the serialized DAG and
    walk its structure, which is a relatively expensive operation.
    
    We could improve that by "pre-computing", before serialization, which tasks
    are actually mapped, so we wouldn't have to walk the task groups to find
    out; but that wouldn't remove the need to load the serialized DAG, which is
    the expensive part.
    
    If this turns out to be a problem we can revisit the decision not to enforce
    this later.
---
 airflow/api_fastapi/execution_api/routes/xcoms.py  | 57 ++++++++++++++++----
 .../src/airflow/sdk/definitions/mappedoperator.py  | 17 ++++++
 task_sdk/src/airflow/sdk/exceptions.py             | 22 +++++++-
 task_sdk/src/airflow/sdk/execution_time/comms.py   |  1 +
 .../src/airflow/sdk/execution_time/task_runner.py  | 62 +++++++++++++++-------
 .../api_fastapi/execution_api/routes/test_xcoms.py | 42 +++++++++++++--
 6 files changed, 168 insertions(+), 33 deletions(-)

diff --git a/airflow/api_fastapi/execution_api/routes/xcoms.py 
b/airflow/api_fastapi/execution_api/routes/xcoms.py
index faacd543fca..8ce2d74be18 100644
--- a/airflow/api_fastapi/execution_api/routes/xcoms.py
+++ b/airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -28,6 +28,7 @@ from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.api_fastapi.execution_api import deps
 from airflow.api_fastapi.execution_api.datamodels.token import TIToken
 from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
+from airflow.models.taskmap import TaskMap
 from airflow.models.xcom import BaseXCom
 
 # TODO: Add dependency on JWT token
@@ -55,7 +56,7 @@ def get_xcom(
     map_index: Annotated[int, Query()] = -1,
 ) -> XComResponse:
     """Get an Airflow XCom from database - not other XCom Backends."""
-    if not has_xcom_access(key, token):
+    if not has_xcom_access(dag_id, run_id, task_id, key, token):
         raise HTTPException(
             status_code=status.HTTP_403_FORBIDDEN,
             detail={
@@ -104,6 +105,8 @@ def get_xcom(
     return XComResponse(key=key, value=xcom_value)
 
 
+# TODO: once we have JWT tokens, remove the dag/run/task ids from the URL and just use the info in the
+# token
 @router.post(
     "/{dag_id}/{run_id}/{task_id}/{key}",
     status_code=status.HTTP_201_CREATED,
@@ -139,8 +142,23 @@ def set_xcom(
     token: deps.TokenDep,
     session: SessionDep,
     map_index: Annotated[int, Query()] = -1,
+    mapped_length: Annotated[
+        int | None, Query(description="Number of mapped tasks this value 
expands into")
+    ] = None,
 ):
     """Set an Airflow XCom."""
+    from airflow.configuration import conf
+
+    if not has_xcom_access(dag_id, run_id, task_id, key, token, write=True):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail={
+                "reason": "access_denied",
+                "message": f"Task does not have access to set XCom key 
'{key}'",
+            },
+        )
+
+    # TODO: This is inefficient. We json.loads it here for BaseXCom.set to then json.dump it!
     try:
         json.loads(value)
     except json.JSONDecodeError:
@@ -152,14 +170,30 @@ def set_xcom(
             },
         )
 
-    if not has_xcom_access(key, token):
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail={
-                "reason": "access_denied",
-                "message": f"Task does not have access to set XCom key 
'{key}'",
-            },
+    if mapped_length is not None:
+        task_map = TaskMap(
+            dag_id=dag_id,
+            task_id=task_id,
+            run_id=run_id,
+            map_index=map_index,
+            length=mapped_length,
+            keys=None,
         )
+        max_map_length = conf.getint("core", "max_map_length", fallback=1024)
+        if task_map.length > max_map_length:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "reason": "unmappable_return_value_length",
+                    "message": "pushed value is too large to map as a 
downstream's dependency",
+                },
+            )
+        # merge() rather than add(): the same TI pushing again (e.g. on retry) must overwrite its existing row.
+        session.merge(task_map)
+
+    # TODO: Can/should we check if a client _hasn't_ provided this for an 
upstream of a mapped task? That
+    # means loading the serialized dag and that seems like a relatively costly 
operation for minimal benefit
+    # (the mapped task would fail in a moment as it can't be expanded anyway.)
 
     # We use `BaseXCom.set` to set XComs directly to the database, bypassing 
the XCom Backend.
     try:
@@ -184,13 +218,16 @@ def set_xcom(
     return {"message": "XCom successfully set"}
 
 
-def has_xcom_access(xcom_key: str, token: TIToken) -> bool:
+def has_xcom_access(
+    dag_id: str, run_id: str, task_id: str, xcom_key: str, token: TIToken, 
write: bool = False
+) -> bool:
     """Check if the task has access to the XCom."""
     # TODO: Placeholder for actual implementation
 
     ti_key = token.ti_key
     log.debug(
-        "Checking access for task instance with key '%s' to XCom '%s'",
+        "Checking %s XCom access for xcom from TaskInstance with key '%s' to 
XCom '%s'",
+        "write" if write else "read",
         ti_key,
         xcom_key,
     )
diff --git a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py 
b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py
index 13640053424..00bd2ab8ab2 100644
--- a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -78,6 +78,7 @@ if TYPE_CHECKING:
     from airflow.sdk.definitions.param import ParamsDict
     from airflow.sdk.types import Operator
     from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
+    from airflow.typing_compat import TypeGuard
     from airflow.utils.context import Context
     from airflow.utils.operator_resources import Resources
     from airflow.utils.task_group import TaskGroup
@@ -136,6 +137,22 @@ def ensure_xcomarg_return_value(arg: Any) -> None:
             ensure_xcomarg_return_value(v)
 
 
+def is_mappable_value(value: Any) -> TypeGuard[Collection]:
+    """
+    Whether a value can be used for task mapping.
+
+    We only allow collections with guaranteed ordering, but exclude character
+    sequences since that's usually not what users would expect to be mappable.
+
+    :meta private:
+    """
+    if not isinstance(value, (Sequence, dict)):
+        return False
+    if isinstance(value, (bytearray, bytes, str)):
+        return False
+    return True
+
+
 @attrs.define(kw_only=True, repr=False)
 class OperatorPartial:
     """
diff --git a/task_sdk/src/airflow/sdk/exceptions.py 
b/task_sdk/src/airflow/sdk/exceptions.py
index c713f38eef8..4dd4ff5910a 100644
--- a/task_sdk/src/airflow/sdk/exceptions.py
+++ b/task_sdk/src/airflow/sdk/exceptions.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import enum
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 if TYPE_CHECKING:
     from airflow.sdk.execution_time.comms import ErrorResponse
@@ -35,3 +35,23 @@ class ErrorType(enum.Enum):
     VARIABLE_NOT_FOUND = "VARIABLE_NOT_FOUND"
     XCOM_NOT_FOUND = "XCOM_NOT_FOUND"
     GENERIC_ERROR = "GENERIC_ERROR"
+
+
+class XComForMappingNotPushed(TypeError):
+    """Raise when a mapped downstream's dependency fails to push XCom for task 
mapping."""
+
+    def __str__(self) -> str:
+        return "did not push XCom for task mapping"
+
+
+class UnmappableXComTypePushed(TypeError):
+    """Raise when an unmappable type is pushed as a mapped downstream's 
dependency."""
+
+    def __init__(self, value: Any, *values: Any) -> None:
+        super().__init__(value, *values)
+
+    def __str__(self) -> str:
+        typename = type(self.args[0]).__qualname__
+        for arg in self.args[1:]:
+            typename = f"{typename}[{type(arg).__qualname__}]"
+        return f"unmappable return type {typename!r}"
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py 
b/task_sdk/src/airflow/sdk/execution_time/comms.py
index 93ea133f489..1f398cb8b60 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -268,6 +268,7 @@ class SetXCom(BaseModel):
     run_id: str
     task_id: str
     map_index: int | None = None
+    mapped_length: int | None = None
     type: Literal["SetXCom"] = "SetXCom"
 
 
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 8b21fcfaf48..a4beac9f4d1 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -295,25 +295,9 @@ class RuntimeTaskInstance(TaskInstance):
         Make an XCom available for tasks to pull.
 
         :param key: Key to store the value under.
-        :param value: Value to store. Only be JSON-serializable may be used 
otherwise.
+        :param value: Value to store. Only JSON-serializable values may be used.
         """
-        from airflow.models.xcom import XCom
-
-        # TODO: Move XCom serialization & deserialization to Task SDK
-        #   https://github.com/apache/airflow/issues/45231
-        value = XCom.serialize_value(value)
-
-        log = structlog.get_logger(logger_name="task")
-        SUPERVISOR_COMMS.send_request(
-            log=log,
-            msg=SetXCom(
-                key=key,
-                value=value,
-                dag_id=self.dag_id,
-                task_id=self.task_id,
-                run_id=self.run_id,
-            ),
-        )
+        _xcom_push(self, key, value)
 
     def get_relevant_upstream_map_indexes(
         self, upstream: BaseOperator, ti_count: int | None, session: Any
@@ -322,6 +306,30 @@ class RuntimeTaskInstance(TaskInstance):
         return None
 
 
+def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: 
int | None = None) -> None:
+    # Private function, as we don't want to expose the ability to manually set 
`mapped_length` to SDK
+    # consumers
+    from airflow.models.xcom import XCom
+
+    # TODO: Move XCom serialization & deserialization to Task SDK
+    #   https://github.com/apache/airflow/issues/45231
+    value = XCom.serialize_value(value)
+
+    log = structlog.get_logger(logger_name="task")
+    SUPERVISOR_COMMS.send_request(
+        log=log,
+        msg=SetXCom(
+            key=key,
+            value=value,
+            dag_id=ti.dag_id,
+            task_id=ti.task_id,
+            run_id=ti.run_id,
+            map_index=ti.map_index,
+            mapped_length=mapped_length,
+        ),
+    )
+
+
 def parse(what: StartupDetails) -> RuntimeTaskInstance:
     # TODO: Task-SDK:
     # Using DagBag here is about 98% wrong, but it'll do for now
@@ -645,10 +653,25 @@ def _push_xcom_if_needed(result: Any, ti: 
RuntimeTaskInstance):
     else:
         xcom_value = None
 
-    # If the task returns a result, push an XCom containing it.
+    # Only an unmapped task records a TaskMap for its mapped downstreams; a mapped TI never does.
+    has_mapped_dep = next(ti.task.iter_mapped_dependants(), None) is not None
     if xcom_value is None:
+        if not ti.task.is_mapped and has_mapped_dep:
+            # Uhoh, a downstream mapped task depends on us to push something to map over
+            from airflow.sdk.exceptions import XComForMappingNotPushed
+
+            raise XComForMappingNotPushed()
         return
 
+    mapped_length: int | None = None
+    if not ti.task.is_mapped and has_mapped_dep:
+        from airflow.sdk.definitions.mappedoperator import is_mappable_value
+        from airflow.sdk.exceptions import UnmappableXComTypePushed
+
+        if not is_mappable_value(xcom_value):
+            raise UnmappableXComTypePushed(xcom_value)
+        mapped_length = len(xcom_value)
+
     # If the task has multiple outputs, push each output as a separate XCom.
     if ti.task.multiple_outputs:
         if not isinstance(xcom_value, Mapping):
@@ -666,6 +689,6 @@ def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance):
 
     # TODO: Use constant for XCom return key & use serialize_value from Task SDK
-    ti.xcom_push("return_value", result)
+    _xcom_push(ti, "return_value", result, mapped_length=mapped_length)
 
 
 def finalize(log: Logger): ...
diff --git a/tests/api_fastapi/execution_api/routes/test_xcoms.py 
b/tests/api_fastapi/execution_api/routes/test_xcoms.py
index 6347db9b6db..3232622b3ac 100644
--- a/tests/api_fastapi/execution_api/routes/test_xcoms.py
+++ b/tests/api_fastapi/execution_api/routes/test_xcoms.py
@@ -17,12 +17,15 @@
 
 from __future__ import annotations
 
+import contextlib
 from unittest import mock
 
+import httpx
 import pytest
 
 from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
 from airflow.models.dagrun import DagRun
+from airflow.models.taskmap import TaskMap
 from airflow.models.xcom import XCom
 from airflow.utils.session import create_session
 
@@ -114,12 +117,45 @@ class TestXComsSetEndpoint:
 
         xcom = session.query(XCom).filter_by(task_id=ti.task_id, 
dag_id=ti.dag_id, key="xcom_1").first()
         assert xcom.value == expected_value
+        task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, 
dag_id=ti.dag_id).one_or_none()
+        assert task_map is None, "Should not be mapped"
 
     @pytest.mark.parametrize(
-        "value",
-        ["value1", {"key2": "value2"}, ["value1"]],
+        ("length", "err_context"),
+        [
+            pytest.param(
+                20,
+                contextlib.nullcontext(),
+                id="20-success",
+            ),
+            pytest.param(
+                2000,
+                pytest.raises(httpx.HTTPStatusError),
+                id="2000-too-long",
+            ),
+        ],
     )
-    def test_xcom_set_invalid_json(self, client, create_task_instance, value):
+    def test_xcom_set_downstream_of_mapped(self, client, create_task_instance, 
session, length, err_context):
+        """
+        Test that passing ``mapped_length`` when setting an XCom records a TaskMap row with that
+        length (so the scheduler can expand downstream mapped tasks), and that a length above
+        ``[core] max_map_length`` is rejected with an HTTP 400 error.
+        """
+        ti = create_task_instance()
+        session.commit()
+
+        with err_context:
+            response = client.post(
+                
f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/xcom_1",
+                json='"valid json"',
+                params={"mapped_length": length},
+            )
+            response.raise_for_status()
+
+            task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, 
dag_id=ti.dag_id).one_or_none()
+            assert task_map.length == length
+
+    def test_xcom_set_invalid_json(self, client):
         response = client.post(
             "/execution/xcoms/dag/runid/task/xcom_1",
             json="invalid_json",

Reply via email to