This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch resume-deferred-operator-tasksdk in repository https://gitbox.apache.org/repos/asf/airflow.git
commit b890aa32bbd1ea584bb2e0b57ef3ed3ee2d54c8b Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Mon Feb 24 16:50:20 2025 +0000 Correctly support resuming tasks after triggers There were a number of issues here that prevented us from correctly resuming a task after its trigger fired. First off, we only ever ran the execute method and we didn't respect next_method. So we added that info to the TI context we send from the server in response to the `.../run` endpoint. Once we'd fixed that, we then ran into the next problem that we were incorrectly setting the same value for trigger_kwargs (which are the kwargs we pass to the trigger constructor) and to the next_kwargs (which are the kwargs to use when resuming the task) -- they needed to be different. This involved adding the new field (`kwargs`) onto the TIDeferredStatePayload. The next complication after that was the "ExtendedJSON" type on the next_kwargs column of TI: this is a type decorator that automatically applies the BaseSerialization encode/decode step (__var and __type etc). The problem with that is that we need to do the serialization on the client in order to send a JSON HTTP request, so we don't want to encode _again_ on the server ideally. I was able to do that easily on the write/update side but not so easily on the read side -- there I left a comment and for now we will have SQLA decode it for us, and then we have to encode it again. Not the best, but not a disaster either. The other change I did here was to have the DeferTask automatically apply the serde encoding when serializing, just so that there are fewer places in the code that need to be aware of that detail (so the Task subprocess will encode it before making the request to the Supervisor, and in the Supervisor it will be kept encoded and passed along as is to its HTTP request). This means you can once again pass datetime objects to a trigger "natively", not only strings. 
For consistency with the user facing code I renamed `next_method` on `DeferTask` message to `method_name`. I'm not sure this really makes sense in the API request to the API server though. --- .../execution_api/datamodels/taskinstance.py | 31 +++++++--- .../execution_api/routes/task_instances.py | 53 ++++++++++++++-- airflow/models/baseoperator.py | 45 +------------- .../src/airflow/sdk/api/datamodels/_generated.py | 7 ++- .../src/airflow/sdk/definitions/baseoperator.py | 71 +++++++++++++++++++++- task_sdk/src/airflow/sdk/execution_time/comms.py | 17 +++++- .../src/airflow/sdk/execution_time/supervisor.py | 8 ++- .../src/airflow/sdk/execution_time/task_runner.py | 22 +++++-- task_sdk/tests/api/test_client.py | 23 ++++--- task_sdk/tests/dags/super_basic_deferred_run.py | 2 +- task_sdk/tests/execution_time/test_supervisor.py | 10 ++- task_sdk/tests/execution_time/test_task_runner.py | 41 +++++++++++-- .../execution_api/routes/test_task_instances.py | 64 ++++++++++++++++--- 13 files changed, 299 insertions(+), 95 deletions(-) diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index a176a2f9d5b..0bc01f49c5c 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -28,7 +28,6 @@ from pydantic import ( Tag, TypeAdapter, WithJsonSchema, - field_validator, ) from airflow.api_fastapi.common.types import UtcDateTime @@ -122,15 +121,22 @@ class TIDeferredStatePayload(StrictBaseModel): ), ] classpath: str - trigger_kwargs: Annotated[dict[str, Any], Field(default_factory=dict)] - next_method: str + trigger_kwargs: Annotated[dict[str, Any] | str, Field(default_factory=dict)] + """ + Kwargs to pass to the trigger constructor, either a plain dict or an encrypted string. + + Both forms will be passed along to the trigger, the server will not handle either. 
+ """ + trigger_timeout: timedelta | None = None + next_method: str + """The name of the method on the operator to call in the worker after the trigger has fired.""" + next_kwargs: Annotated[dict[str, Any] | str, Field(default_factory=dict)] + """ + Kwargs to pass to the above method, either a plain dict or an encrypted string. - @field_validator("trigger_kwargs") - def validate_moment(cls, v): - if "moment" in v: - v["moment"] = AwareDatetimeAdapter.validate_strings(v["moment"]) - return v + Both forms will be passed along to the TaskSDK upon resume, the server will not handle either. + """ class TIRescheduleStatePayload(StrictBaseModel): @@ -252,6 +258,15 @@ class TIRunContext(BaseModel): upstream_map_indexes: dict[str, int] | None = None + next_method: str | None = None + """Method to call. Set when task resumes from a trigger.""" + next_kwargs: dict[str, Any] | str | None = None + """ + Args to pass to ``next_method``. + + Can either be a "decorated" dict, or a string encrypted with the shared Fernet key. 
+ """ + class PrevSuccessfulDagRunResponse(BaseModel): """Schema for response with previous successful DagRun information for Task Template Context.""" diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index 87cc547a9f0..6a403c1957a 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -17,6 +17,7 @@ from __future__ import annotations +import json import logging from typing import Annotated from uuid import UUID @@ -78,6 +79,9 @@ def ti_run( # We only use UUID above for validation purposes ti_id_str = str(task_instance_id) + from sqlalchemy.sql import column + from sqlalchemy.types import JSON + old = ( select( TI.state, @@ -88,14 +92,28 @@ def ti_run( TI.next_method, TI.try_number, TI.max_tries, + TI.next_method, + # This selects the raw JSON value, by-passing the deserialization -- we want that to happen on the + # client + column("next_kwargs", JSON), ) + .select_from(TI) .where(TI.id == ti_id_str) .with_for_update() ) try: - (previous_state, dag_id, run_id, task_id, map_index, next_method, try_number, max_tries) = ( - session.execute(old).one() - ) + ( + previous_state, + dag_id, + run_id, + task_id, + map_index, + next_method, + try_number, + max_tries, + next_method, + next_kwargs, + ) = session.execute(old).one() except NoResultFound: log.error("Task Instance %s not found", ti_id_str) raise HTTPException( @@ -195,7 +213,7 @@ def ti_run( or 0 ) - return TIRunContext( + context = TIRunContext( dag_run=dr, task_reschedule_count=task_reschedule_count, max_tries=max_tries, @@ -203,6 +221,13 @@ def ti_run( variables=[], connections=[], ) + + # Only set if they are non-null + if next_method: + context.next_method = next_method + context.next_kwargs = next_kwargs + + return context except SQLAlchemyError as e: log.error("Error marking Task Instance state as running: %s", e) raise HTTPException( @@ 
-290,10 +315,17 @@ def ti_update_state( if ti_patch_payload.trigger_timeout is not None: timeout = timezone.utcnow() + ti_patch_payload.trigger_timeout + trigger_kwargs = ti_patch_payload.trigger_kwargs + if not isinstance(trigger_kwargs, str): + # If it's passed as a string, assume the client encrypted it, otherwise assume it doesn't need to + # be. Just JSON serialize it + trigger_kwargs = json.dumps(trigger_kwargs) + trigger_row = Trigger( classpath=ti_patch_payload.classpath, - kwargs=ti_patch_payload.trigger_kwargs, + kwargs={}, ) + trigger_row.encrypted_kwargs = trigger_kwargs session.add(trigger_row) session.flush() @@ -301,11 +333,20 @@ def ti_update_state( # either get it from the serialised DAG or get it from the API query = update(TI).where(TI.id == ti_id_str) + + # This is slightly inefficient as we deserialize it to then right again serialize it in the sqla + # TypeAdapter. + next_kwargs = None + if ti_patch_payload.next_kwargs: + from airflow.serialization.serialized_objects import BaseSerialization + + next_kwargs = BaseSerialization.deserialize(ti_patch_payload.next_kwargs) + query = query.values( state=State.DEFERRED, trigger_id=trigger_row.id, next_method=ti_patch_payload.next_method, - next_kwargs=ti_patch_payload.trigger_kwargs, + next_kwargs=next_kwargs, trigger_timeout=timeout, ) updated_state = State.DEFERRED diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index ea6a7c34b11..d1ed6d889ac 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -35,7 +35,6 @@ from typing import ( TYPE_CHECKING, Any, Callable, - NoReturn, TypeVar, ) @@ -47,9 +46,6 @@ from sqlalchemy.orm.exc import NoResultFound from airflow.configuration import conf from airflow.exceptions import ( AirflowException, - TaskDeferralError, - TaskDeferralTimeout, - TaskDeferred, ) from airflow.lineage import apply_lineage, prepare_lineage @@ -62,7 +58,6 @@ from airflow.models.abstractoperator import ( from airflow.models.base 
import _sentinel from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.models.taskmixin import DependencyMixin -from airflow.models.trigger import TRIGGER_FAIL_REPR, TriggerFailureReason from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator as TaskSDKAbstractOperator from airflow.sdk.definitions.baseoperator import ( BaseOperatorMeta as TaskSDKBaseOperatorMeta, @@ -98,7 +93,7 @@ if TYPE_CHECKING: from airflow.models.operator import Operator from airflow.sdk.definitions.node import DAGNode from airflow.ti_deps.deps.base_ti_dep import BaseTIDep - from airflow.triggers.base import BaseTrigger, StartTriggerArgs + from airflow.triggers.base import StartTriggerArgs TaskPreExecuteHook = Callable[[Context], None] TaskPostExecuteHook = Callable[[Context, Any], None] @@ -741,44 +736,6 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperator """Serialize; required by DAGNode.""" return DagAttributeTypes.OP, self.task_id - def defer( - self, - *, - trigger: BaseTrigger, - method_name: str, - kwargs: dict[str, Any] | None = None, - timeout: timedelta | int | float | None = None, - ) -> NoReturn: - """ - Mark this Operator "deferred", suspending its execution until the provided trigger fires an event. - - This is achieved by raising a special exception (TaskDeferred) - which is caught in the main _execute_task wrapper. Triggers can send execution back to task or end - the task instance directly. If the trigger will end the task instance itself, ``method_name`` should - be None; otherwise, provide the name of the method that should be used when resuming execution in - the task. 
- """ - raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) - - def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context): - """Call this method when a deferred task is resumed.""" - # __fail__ is a special signal value for next_method that indicates - # this task was scheduled specifically to fail. - if next_method == TRIGGER_FAIL_REPR: - next_kwargs = next_kwargs or {} - traceback = next_kwargs.get("traceback") - if traceback is not None: - self.log.error("Trigger failed:\n%s", "\n".join(traceback)) - if (error := next_kwargs.get("error", "Unknown")) == TriggerFailureReason.TRIGGER_TIMEOUT: - raise TaskDeferralTimeout(error) - else: - raise TaskDeferralError(error) - # Grab the callable off the Operator/Task and add in any kwargs - execute_callable = getattr(self, next_method) - if next_kwargs: - execute_callable = functools.partial(execute_callable, **next_kwargs) - return execute_callable(context) - def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator: """ Get the "normal" operator from the current operator. 
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index 71f110b64f4..4db100e8b42 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -117,9 +117,10 @@ class TIDeferredStatePayload(BaseModel): ) state: Annotated[Literal["deferred"] | None, Field(title="State")] = "deferred" classpath: Annotated[str, Field(title="Classpath")] - trigger_kwargs: Annotated[dict[str, Any] | None, Field(title="Trigger Kwargs")] = None - next_method: Annotated[str, Field(title="Next Method")] + trigger_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Trigger Kwargs")] = None trigger_timeout: Annotated[timedelta | None, Field(title="Trigger Timeout")] = None + next_method: Annotated[str, Field(title="Next Method")] + next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next Kwargs")] = None class TIEnterRunningPayload(BaseModel): @@ -316,6 +317,8 @@ class TIRunContext(BaseModel): variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None upstream_map_indexes: Annotated[dict[str, int] | None, Field(title="Upstream Map Indexes")] = None + next_method: Annotated[str | None, Field(title="Next Method")] = None + next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next Kwargs")] = None class TITerminalStatePayload(BaseModel): diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py index 14d67656008..28e88d0c995 100644 --- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py @@ -27,9 +27,10 @@ import warnings from collections.abc import Callable, Collection, Iterable, Sequence from dataclasses import dataclass, field from datetime import datetime, timedelta +from enum import 
Enum from functools import total_ordering, wraps from types import FunctionType -from typing import TYPE_CHECKING, Any, ClassVar, Final, TypeVar, cast +from typing import TYPE_CHECKING, Any, ClassVar, Final, NoReturn, TypeVar, cast import attrs @@ -77,13 +78,40 @@ if TYPE_CHECKING: from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.serialization.enums import DagAttributeTypes from airflow.task.priority_strategy import PriorityWeightStrategy + from airflow.triggers.base import BaseTrigger from airflow.typing_compat import Self from airflow.utils.operator_resources import Resources +__all__ = [ + "BaseOperator", +] + # TODO: Task-SDK AirflowException = RuntimeError +class TriggerFailureReason(str, Enum): + """ + Reasons for trigger failures. + + Internal use only. + + :meta private: + """ + + TRIGGER_TIMEOUT = "Trigger timeout" + TRIGGER_FAILURE = "Trigger failure" + + +TRIGGER_FAIL_REPR = "__fail__" +"""String value to represent trigger failure. + +Internal use only. + +:meta private: +""" + + def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]: if not dag: return {}, ParamsDict() @@ -1434,3 +1462,44 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): if not jinja_env: jinja_env = self.get_template_env() self._do_render_template_fields(self, self.template_fields, context, jinja_env, set()) + + def defer( + self, + *, + trigger: BaseTrigger, + method_name: str, + kwargs: dict[str, Any] | None = None, + timeout: timedelta | int | float | None = None, + ) -> NoReturn: + """ + Mark this Operator "deferred", suspending its execution until the provided trigger fires an event. + + This is achieved by raising a special exception (TaskDeferred) + which is caught in the main _execute_task wrapper. Triggers can send execution back to task or end + the task instance directly. 
If the trigger will end the task instance itself, ``method_name`` should + be None; otherwise, provide the name of the method that should be used when resuming execution in + the task. + """ + from airflow.exceptions import TaskDeferred + + raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) + + def resume_execution(self, next_method: str, next_kwargs: dict[str, Any], context: Context): + """Entrypoint method called by the Task Runner (instead of execute) when this task is resumed.""" + from airflow.exceptions import TaskDeferralError, TaskDeferralTimeout + + # __fail__ is a special signal value for next_method that indicates + # this task was scheduled specifically to fail. + + if next_method == TRIGGER_FAIL_REPR: + next_kwargs = next_kwargs or {} + traceback = next_kwargs.get("traceback") + if traceback is not None: + self.log.error("Trigger failed:\n%s", "\n".join(traceback)) + if (error := next_kwargs.get("error", "Unknown")) == TriggerFailureReason.TRIGGER_TIMEOUT: + raise TaskDeferralTimeout(error) + else: + raise TaskDeferralError(error) + # Grab the callable off the Operator/Task and add in any kwargs + execute_callable = getattr(self, next_method) + return execute_callable(context, **next_kwargs) diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index b0c6dd62abc..f0476f9323a 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -44,11 +44,11 @@ Execution API server is because: from __future__ import annotations from datetime import datetime -from typing import Annotated, Literal, Union +from typing import Annotated, Any, Literal, Union from uuid import UUID from fastapi import Body -from pydantic import BaseModel, ConfigDict, Field, JsonValue +from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_serializer from airflow.sdk.api.datamodels._generated import ( AssetResponse, 
@@ -229,6 +229,19 @@ class DeferTask(TIDeferredStatePayload): type: Literal["DeferTask"] = "DeferTask" + @field_serializer("trigger_kwargs", "next_kwargs", check_fields=True) + def _serde_kwarg_fields(self, val: str | dict[str, Any] | None, _info): + from airflow.serialization.serialized_objects import BaseSerialization + + if not isinstance(val, dict): + # None, or an encrypted string + return val + + if val.keys() == {"__type", "__var"}: + # Already encoded. + return val + return BaseSerialization.serialize(val or {}) + class RescheduleTask(TIRescheduleStatePayload): """Update a task instance state to reschedule/up_for_reschedule.""" diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index ca95a7003f4..7201357c4d9 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -962,16 +962,18 @@ def supervise( # TODO: Use logging providers to handle the chunked upload for us etc. logger: FilteringBoundLogger | None = None if log_path: - # If we are told to write logs to a file, redirect the task logger to it. + # If we are told to write logs to a file, redirect the task logger to it. Make sure we append to the + # file though, otherwise when we resume we would lose the logs from the start->deferral segment if it
from airflow.sdk.log import init_log_file, logging_processors log_file = init_log_file(log_path) pretty_logs = False if pretty_logs: - underlying_logger: WrappedLogger = structlog.WriteLogger(log_file.open("w", buffering=1)) + underlying_logger: WrappedLogger = structlog.WriteLogger(log_file.open("a", buffering=1)) else: - underlying_logger = structlog.BytesLogger(log_file.open("wb")) + underlying_logger = structlog.BytesLogger(log_file.open("ab")) processors = logging_processors(enable_pretty_log=pretty_logs)[0] logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind() diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 610c989e01f..97ee45cded5 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -19,6 +19,7 @@ from __future__ import annotations +import functools import os import sys from collections.abc import Iterable, Mapping @@ -616,13 +617,13 @@ def run( # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id? log.info("Pausing task as DEFERRED. 
", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id) classpath, trigger_kwargs = defer.trigger.serialize() - next_method = defer.method_name - defer_timeout = defer.timeout + msg = DeferTask( classpath=classpath, trigger_kwargs=trigger_kwargs, - next_method=next_method, - trigger_timeout=defer_timeout, + trigger_timeout=defer.timeout, + next_method=defer.method_name, + next_kwargs=defer.kwargs or {}, ) state = IntermediateTIState.DEFERRED except AirflowSkipException as e: @@ -697,6 +698,15 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance): from airflow.exceptions import AirflowTaskTimeout task = ti.task + execute = task.execute # type: ignore[attr-defined] + + if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method): + from airflow.serialization.serialized_objects import BaseSerialization + + kwargs = BaseSerialization.deserialize(ti._ti_context_from_server.next_kwargs or {}) + + execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs) + if task.execution_timeout: # TODO: handle timeout in case of deferral from airflow.utils.timeout import timeout @@ -708,12 +718,12 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance): raise AirflowTaskTimeout() # Run task in timeout wrapper with timeout(timeout_seconds): - result = task.execute(context) # type: ignore[attr-defined] + result = execute(context=context) except AirflowTaskTimeout: # TODO: handle on kill callback here raise else: - result = task.execute(context) # type: ignore[attr-defined] + result = execute(context=context) return result diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index 43e35dfec9e..d2309c3e4ab 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -268,14 +268,24 @@ class TestTaskInstanceOperations: # Simulate a successful response from the server that defers a task ti_id = uuid6.uuid7() + msg = DeferTask( + 
classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger", + next_method="execute_complete", + trigger_kwargs={ + "__type": "dict", + "__var": { + "moment": {"__type": "datetime", "__var": 1730982899.0}, + "end_from_trigger": False, + }, + }, + next_kwargs={"__type": "dict", "__var": {}}, + ) + def handle_request(request: httpx.Request) -> httpx.Response: if request.url.path == f"/task-instances/{ti_id}/state": actual_body = json.loads(request.read()) assert actual_body["state"] == "deferred" - assert actual_body["trigger_kwargs"] == { - "moment": "2024-11-07T12:34:59Z", - "end_from_trigger": False, - } + assert actual_body["trigger_kwargs"] == msg.trigger_kwargs assert ( actual_body["classpath"] == "airflow.providers.standard.triggers.temporal.DateTimeTrigger" ) @@ -286,11 +296,6 @@ class TestTaskInstanceOperations: return httpx.Response(status_code=400, json={"detail": "Bad Request"}) client = make_client(transport=httpx.MockTransport(handle_request)) - msg = DeferTask( - classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger", - trigger_kwargs={"moment": "2024-11-07T12:34:59Z", "end_from_trigger": False}, - next_method="execute_complete", - ) client.task_instances.defer(ti_id, msg) def test_task_instance_reschedule(self): diff --git a/task_sdk/tests/dags/super_basic_deferred_run.py b/task_sdk/tests/dags/super_basic_deferred_run.py index 453d9e5f6c7..934930f4bff 100644 --- a/task_sdk/tests/dags/super_basic_deferred_run.py +++ b/task_sdk/tests/dags/super_basic_deferred_run.py @@ -28,7 +28,7 @@ from airflow.utils import timezone def super_basic_deferred_run(): DateTimeSensorAsync( task_id="async", - target_time=str(timezone.utcnow() + datetime.timedelta(seconds=3)), + target_time=timezone.utcnow() + datetime.timedelta(seconds=3), poke_interval=60, timeout=600, ) diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index ececcbe31d8..fde63fc7091 100644 --- 
a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -347,10 +347,18 @@ class TestWatchedSubprocess: mock_client.task_instances.heartbeat.assert_called_once_with(ti.id, pid=mocker.ANY) mock_client.task_instances.defer.assert_called_once_with( ti.id, + # Since the message as serialized in the client upon sending, we expect it to be already encoded DeferTask( classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger", - trigger_kwargs={"moment": "2024-11-07T12:34:59Z", "end_from_trigger": False}, next_method="execute_complete", + trigger_kwargs={ + "__type": "dict", + "__var": { + "moment": {"__type": "datetime", "__var": 1730982899.0}, + "end_from_trigger": False, + }, + }, + next_kwargs={"__type": "dict", "__var": {}}, ), ) diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 66a585d7b01..92f64fae030 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -24,6 +24,7 @@ import uuid from datetime import datetime, timedelta from pathlib import Path from socket import socketpair +from typing import TYPE_CHECKING from unittest import mock from unittest.mock import patch @@ -86,6 +87,9 @@ from airflow.utils.state import TaskInstanceState from tests_common.test_utils.mock_operators import AirflowLink +if TYPE_CHECKING: + from kgb import SpyAgency + FAKE_BUNDLE = BundleInfo(name="anything", version="any") @@ -183,15 +187,13 @@ def test_parse(test_dags_dir: Path, make_ti_context): def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_comms): """Test that a task can transition to a deferred state.""" - import datetime - from airflow.providers.standard.sensors.date_time import DateTimeSensorAsync # Use the time machine to set the current time instant = timezone.datetime(2024, 11, 22) task = DateTimeSensorAsync( task_id="async", - 
target_time=str(instant + datetime.timedelta(seconds=3)), + target_time=str(instant + timedelta(seconds=3)), poke_interval=60, timeout=600, ) @@ -201,22 +203,51 @@ def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_com expected_defer_task = DeferTask( state="deferred", classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger", + # Since we are in the task process here, we expect this to have not been encoded by serde yet trigger_kwargs={ "end_from_trigger": False, "moment": instant + timedelta(seconds=3), }, - next_method="execute_complete", trigger_timeout=None, + next_method="execute_complete", + next_kwargs={}, ) # Run the task ti = create_runtime_ti(dag_id="basic_deferred_run", task=task) - run(ti, log=mock.MagicMock()) + state, msg, err = run(ti, log=mock.MagicMock()) # send_request will only be called when the TaskDeferred exception is raised mock_supervisor_comms.send_request.assert_any_call(msg=expected_defer_task, log=mock.ANY) +def test_resume_from_deferred(time_machine, create_runtime_ti, mock_supervisor_comms, spy_agency: SpyAgency): + from airflow.providers.standard.sensors.date_time import DateTimeSensorAsync + + instant_str = "2024-09-30T12:00:00Z" + instant = timezone.parse(instant_str) + task = DateTimeSensorAsync( + task_id="async", + target_time=instant + timedelta(seconds=3), + poke_interval=60, + timeout=600, + ) + + ti = create_runtime_ti(dag_id="basic_deferred_run", task=task) + ti._ti_context_from_server.next_method = "execute_complete" + ti._ti_context_from_server.next_kwargs = { + "__type": "dict", + "__var": {"event": {"__type": "datetime", "__var": 1727697600.0}}, + } + + spy = spy_agency.spy_on(task.execute_complete) + state, msg, err = run(ti, log=mock.MagicMock()) + assert err is None + assert state == TaskInstanceState.SUCCESS + + spy_agency.assert_spy_called_with(spy, mock.ANY, event=instant) + + def test_run_basic_skipped(time_machine, create_runtime_ti, mock_supervisor_comms): """Test 
running a basic task that marks itself skipped.""" diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index 5542adb9844..5477869a6cb 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -77,7 +77,6 @@ class TestTIRunState: session=session, start_date=instant, ) - session.commit() response = client.patch( @@ -119,6 +118,49 @@ class TestTIRunState: assert ti.unixname == "random-unixname" assert ti.pid == 100 + def test_next_kwargs_still_encoded(self, client, session, create_task_instance, time_machine): + instant_str = "2024-09-30T12:00:00Z" + instant = timezone.parse(instant_str) + time_machine.move_to(instant, tick=False) + + ti = create_task_instance( + task_id="test_ti_run_state_to_running", + state=State.QUEUED, + session=session, + start_date=instant, + ) + + ti.next_method = "execute_complete" + # ti.next_kwargs under the hood applies the serde encoding for us + ti.next_kwargs = {"moment": instant} + + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": instant_str, + }, + ) + + assert response.status_code == 200 + assert response.json() == { + "dag_run": mock.ANY, + "task_reschedule_count": 0, + "max_tries": 0, + "variables": [], + "connections": [], + "next_method": "execute_complete", + "next_kwargs": { + "__type": "dict", + "__var": {"moment": {"__type": "datetime", "__var": 1727697600.0}}, + }, + } + @pytest.mark.parametrize("initial_ti_state", [s for s in TaskInstanceState if s != State.QUEUED]) def test_ti_run_state_conflict_if_not_queued( self, client, session, create_task_instance, initial_ti_state @@ -209,9 +251,9 @@ class TestTIRunState: payload = { "state": "deferred", "trigger_kwargs": {"key": "value", 
"moment": "2024-12-18T00:00:00Z"}, + "trigger_timeout": "P1D", # 1 day "classpath": "my-classpath", "next_method": "execute_callback", - "trigger_timeout": "P1D", # 1 day } response = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) @@ -452,10 +494,18 @@ class TestTIUpdateState: payload = { "state": "deferred", - "trigger_kwargs": {"key": "value", "moment": "2024-12-18T00:00:00Z"}, + # Raw payload is already "encoded", but not encrypted + "trigger_kwargs": { + "__type": "dict", + "__var": {"key": "value", "moment": {"__type": "datetime", "__var": 1734480001.0}}, + }, + "trigger_timeout": "P1D", # 1 day "classpath": "my-classpath", "next_method": "execute_callback", - "trigger_timeout": "P1D", # 1 day + "next_kwargs": { + "__type": "dict", + "__var": {"foo": {"__type": "datetime", "__var": 1734480000.0}, "bar": "abc"}, + }, } response = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) @@ -471,8 +521,8 @@ class TestTIUpdateState: assert tis[0].state == TaskInstanceState.DEFERRED assert tis[0].next_method == "execute_callback" assert tis[0].next_kwargs == { - "key": "value", - "moment": datetime(2024, 12, 18, 00, 00, 00, tzinfo=timezone.utc), + "bar": "abc", + "foo": datetime(2024, 12, 18, 00, 00, 00, tzinfo=timezone.utc), } assert tis[0].trigger_timeout == timezone.make_aware(datetime(2024, 11, 23), timezone=timezone.utc) @@ -482,7 +532,7 @@ class TestTIUpdateState: assert t[0].classpath == "my-classpath" assert t[0].kwargs == { "key": "value", - "moment": datetime(2024, 12, 18, 00, 00, 00, tzinfo=timezone.utc), + "moment": datetime(2024, 12, 18, 00, 00, 1, tzinfo=timezone.utc), } def test_ti_update_state_to_reschedule(self, client, session, create_task_instance, time_machine):
