This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch add-taskmapping-to-setxcom-execapi in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 59b89800ea5076548f4b8ee205268ef2dcebaff5 Author: Ash Berlin-Taylor <ash_git...@firemirror.com> AuthorDate: Fri Jan 31 14:06:54 2025 +0000 Create TaskMap rows when pushing XCom values from the Task Execution Interface This is needed so that we can have the scheduler create the expanded TIs for the downstream tasks. Note that we don't enforce on the server side that this is set when it needs to be (nor, conversely, that it is _not_ set when it doesn't need to be, which has no effect other than creating a small DB row that nothing will ever read), as the only way for us to know whether it is needed would involve loading the serialized DAG to walk the structure, which is a relatively expensive operation. We could improve that by "pre-computing" some of the info on which tasks are actually mapped or not before serialization, so we wouldn't have to walk the task groups to find out, but that wouldn't do anything about the need to load the serialized DAG, which is the expensive part. If this turns out to be a problem, we can revisit the decision not to enforce this later. 
--- airflow/api_fastapi/execution_api/routes/xcoms.py | 57 ++++++++++++++++---- .../src/airflow/sdk/definitions/mappedoperator.py | 17 ++++++ task_sdk/src/airflow/sdk/exceptions.py | 22 +++++++- task_sdk/src/airflow/sdk/execution_time/comms.py | 1 + .../src/airflow/sdk/execution_time/task_runner.py | 62 +++++++++++++++------- .../api_fastapi/execution_api/routes/test_xcoms.py | 42 +++++++++++++-- 6 files changed, 168 insertions(+), 33 deletions(-) diff --git a/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow/api_fastapi/execution_api/routes/xcoms.py index faacd543fca..8ce2d74be18 100644 --- a/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -28,6 +28,7 @@ from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.execution_api import deps from airflow.api_fastapi.execution_api.datamodels.token import TIToken from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse +from airflow.models.taskmap import TaskMap from airflow.models.xcom import BaseXCom # TODO: Add dependency on JWT token @@ -55,7 +56,7 @@ def get_xcom( map_index: Annotated[int, Query()] = -1, ) -> XComResponse: """Get an Airflow XCom from database - not other XCom Backends.""" - if not has_xcom_access(key, token): + if not has_xcom_access(dag_id, run_id, task_id, key, token): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail={ @@ -104,6 +105,8 @@ def get_xcom( return XComResponse(key=key, value=xcom_value) +# TODO:once we have JWT tokens, then remove the dag/run/task ids from the URL and just use the info in the +# token @router.post( "/{dag_id}/{run_id}/{task_id}/{key}", status_code=status.HTTP_201_CREATED, @@ -139,8 +142,23 @@ def set_xcom( token: deps.TokenDep, session: SessionDep, map_index: Annotated[int, Query()] = -1, + mapped_length: Annotated[ + int | None, Query(description="Number of mapped tasks this value expands into") + ] = None, ): """Set an Airflow XCom.""" 
+ from airflow.configuration import conf + + if not has_xcom_access(dag_id, run_id, task_id, key, token, write=True): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "reason": "access_denied", + "message": f"Task does not have access to set XCom key '{key}'", + }, + ) + + # TODO: This is in-efficient. We json.loads it here for BaseXCom.set to then json.dump it! try: json.loads(value) except json.JSONDecodeError: @@ -152,14 +170,30 @@ def set_xcom( }, ) - if not has_xcom_access(key, token): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail={ - "reason": "access_denied", - "message": f"Task does not have access to set XCom key '{key}'", - }, + if mapped_length is not None: + task_map = TaskMap( + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + length=mapped_length, + keys=None, ) + max_map_length = conf.getint("core", "max_map_length", fallback=1024) + if task_map.length > max_map_length: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "reason": "unmappable_return_value_length", + "message": "pushed value is too large to map as a downstream's dependency", + }, + ) + session.add(task_map) + + # else: + # TODO: Can/should we check if a client _hasn't_ provided this for an upstream of a mapped task? That + # means loading the serialized dag and that seems like a relatively costly operation for minimal benefit + # (the mapped task would fail in a moment as it can't be expanded anyway.) # We use `BaseXCom.set` to set XComs directly to the database, bypassing the XCom Backend. 
try: @@ -184,13 +218,16 @@ def set_xcom( return {"message": "XCom successfully set"} -def has_xcom_access(xcom_key: str, token: TIToken) -> bool: +def has_xcom_access( + dag_id: str, run_id: str, task_id: str, xcom_key: str, token: TIToken, write: bool = False +) -> bool: """Check if the task has access to the XCom.""" # TODO: Placeholder for actual implementation ti_key = token.ti_key log.debug( - "Checking access for task instance with key '%s' to XCom '%s'", + "Checking %s XCom access for xcom from TaskInstance with key '%s' to XCom '%s'", + "write" if write else "read", ti_key, xcom_key, ) diff --git a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py index 13640053424..00bd2ab8ab2 100644 --- a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -78,6 +78,7 @@ if TYPE_CHECKING: from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.types import Operator from airflow.ti_deps.deps.base_ti_dep import BaseTIDep + from airflow.typing_compat import TypeGuard from airflow.utils.context import Context from airflow.utils.operator_resources import Resources from airflow.utils.task_group import TaskGroup @@ -136,6 +137,22 @@ def ensure_xcomarg_return_value(arg: Any) -> None: ensure_xcomarg_return_value(v) +def is_mappable_value(value: Any) -> TypeGuard[Collection]: + """ + Whether a value can be used for task mapping. + + We only allow collections with guaranteed ordering, but exclude character + sequences since that's usually not what users would expect to be mappable. 
+ + :meta private: + """ + if not isinstance(value, (Sequence, dict)): + return False + if isinstance(value, (bytearray, bytes, str)): + return False + return True + + @attrs.define(kw_only=True, repr=False) class OperatorPartial: """ diff --git a/task_sdk/src/airflow/sdk/exceptions.py b/task_sdk/src/airflow/sdk/exceptions.py index c713f38eef8..4dd4ff5910a 100644 --- a/task_sdk/src/airflow/sdk/exceptions.py +++ b/task_sdk/src/airflow/sdk/exceptions.py @@ -18,7 +18,7 @@ from __future__ import annotations import enum -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from airflow.sdk.execution_time.comms import ErrorResponse @@ -35,3 +35,23 @@ class ErrorType(enum.Enum): VARIABLE_NOT_FOUND = "VARIABLE_NOT_FOUND" XCOM_NOT_FOUND = "XCOM_NOT_FOUND" GENERIC_ERROR = "GENERIC_ERROR" + + +class XComForMappingNotPushed(TypeError): + """Raise when a mapped downstream's dependency fails to push XCom for task mapping.""" + + def __str__(self) -> str: + return "did not push XCom for task mapping" + + +class UnmappableXComTypePushed(TypeError): + """Raise when an unmappable type is pushed as a mapped downstream's dependency.""" + + def __init__(self, value: Any, *values: Any) -> None: + super().__init__(value, *values) + + def __str__(self) -> str: + typename = type(self.args[0]).__qualname__ + for arg in self.args[1:]: + typename = f"{typename}[{type(arg).__qualname__}]" + return f"unmappable return type {typename!r}" diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index 93ea133f489..1f398cb8b60 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -268,6 +268,7 @@ class SetXCom(BaseModel): run_id: str task_id: str map_index: int | None = None + mapped_length: int | None = None type: Literal["SetXCom"] = "SetXCom" diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 8b21fcfaf48..a4beac9f4d1 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -295,25 +295,9 @@ class RuntimeTaskInstance(TaskInstance): Make an XCom available for tasks to pull. :param key: Key to store the value under. - :param value: Value to store. Only be JSON-serializable may be used otherwise. + :param value: Value to store. Only be JSON-serializable values may be used. """ - from airflow.models.xcom import XCom - - # TODO: Move XCom serialization & deserialization to Task SDK - # https://github.com/apache/airflow/issues/45231 - value = XCom.serialize_value(value) - - log = structlog.get_logger(logger_name="task") - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetXCom( - key=key, - value=value, - dag_id=self.dag_id, - task_id=self.task_id, - run_id=self.run_id, - ), - ) + _xcom_push(self, key, value) def get_relevant_upstream_map_indexes( self, upstream: BaseOperator, ti_count: int | None, session: Any @@ -322,6 +306,30 @@ class RuntimeTaskInstance(TaskInstance): return None +def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None: + # Private function, as we don't want to expose the ability to manually set `mapped_length` to SDK + # consumers + from airflow.models.xcom import XCom + + # TODO: Move XCom serialization & deserialization to Task SDK + # https://github.com/apache/airflow/issues/45231 + value = XCom.serialize_value(value) + + log = structlog.get_logger(logger_name="task") + SUPERVISOR_COMMS.send_request( + log=log, + msg=SetXCom( + key=key, + value=value, + dag_id=ti.dag_id, + task_id=ti.task_id, + run_id=ti.run_id, + map_index=ti.map_index, + mapped_length=mapped_length, + ), + ) + + def parse(what: StartupDetails) -> RuntimeTaskInstance: # TODO: Task-SDK: # Using DagBag here is about 98% wrong, but it'll do for now @@ -645,10 +653,25 @@ def 
_push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance): else: xcom_value = None - # If the task returns a result, push an XCom containing it. + is_mapped = next(ti.task.iter_mapped_dependants(), None) is not None or ti.task.is_mapped + if xcom_value is None: + if is_mapped: + # Uhoh, a downstream mapped task depends on us to push something to map over + from airflow.sdk.exceptions import XComForMappingNotPushed + + raise XComForMappingNotPushed() return + mapped_length: int | None = None + if is_mapped: + from airflow.sdk.definitions.mappedoperator import is_mappable_value + from airflow.sdk.exceptions import UnmappableXComTypePushed + + if not is_mappable_value(xcom_value): + raise UnmappableXComTypePushed(xcom_value) + mapped_length = len(xcom_value) + # If the task has multiple outputs, push each output as a separate XCom. if ti.task.multiple_outputs: if not isinstance(xcom_value, Mapping): @@ -666,6 +689,7 @@ def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance): # TODO: Use constant for XCom return key & use serialize_value from Task SDK ti.xcom_push("return_value", result) + _xcom_push(ti, "return_value", result, mapped_length=mapped_length) def finalize(log: Logger): ... 
diff --git a/tests/api_fastapi/execution_api/routes/test_xcoms.py b/tests/api_fastapi/execution_api/routes/test_xcoms.py index 6347db9b6db..3232622b3ac 100644 --- a/tests/api_fastapi/execution_api/routes/test_xcoms.py +++ b/tests/api_fastapi/execution_api/routes/test_xcoms.py @@ -17,12 +17,15 @@ from __future__ import annotations +import contextlib from unittest import mock +import httpx import pytest from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse from airflow.models.dagrun import DagRun +from airflow.models.taskmap import TaskMap from airflow.models.xcom import XCom from airflow.utils.session import create_session @@ -114,12 +117,45 @@ class TestXComsSetEndpoint: xcom = session.query(XCom).filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1").first() assert xcom.value == expected_value + task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, dag_id=ti.dag_id).one_or_none() + assert task_map is None, "Should not be mapped" @pytest.mark.parametrize( - "value", - ["value1", {"key2": "value2"}, ["value1"]], + ("length", "err_context"), + [ + pytest.param( + 20, + contextlib.nullcontext(), + id="20-success", + ), + pytest.param( + 2000, + pytest.raises(httpx.HTTPStatusError), + id="2000-too-long", + ), + ], ) - def test_xcom_set_invalid_json(self, client, create_task_instance, value): + def test_xcom_set_downstream_of_mapped(self, client, create_task_instance, session, length, err_context): + """ + Test that XCom value is set correctly. The value is passed as a JSON string in the request body. + XCom.set then uses json.dumps to serialize it and store the value in the database. + This is done so that Task SDK in multiple languages can use the same API to set XCom values. 
+ """ + ti = create_task_instance() + session.commit() + + with err_context: + response = client.post( + f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/xcom_1", + json='"valid json"', + params={"mapped_length": length}, + ) + response.raise_for_status() + + task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, dag_id=ti.dag_id).one_or_none() + assert task_map.length == length + + def test_xcom_set_invalid_json(self, client): response = client.post( "/execution/xcoms/dag/runid/task/xcom_1", json="invalid_json",