This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new fbbe59a2a92 Correctly support resuming tasks after triggers (#47061)
fbbe59a2a92 is described below
commit fbbe59a2a927281042f5dbad76143f444e2a82ec
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Tue Feb 25 22:11:08 2025 +0000
Correctly support resuming tasks after triggers (#47061)
There were a number of issues here that prevented us from correctly resuming
a
task after its trigger fired.
First off, we only ever ran the execute method and didn't respect
next_method.
So we added that info to the TI context we send from the server in response
to the `.../run` endpoint.
Once we'd fixed that, we then ran into the next problem that we were
incorrectly setting the same value for trigger_kwargs (which are the kwargs
we
pass to the trigger constructor) and to the next_kwargs (which are the
kwargs
to use when resuming the task) -- they needed to be different. This involved
adding the new field (`kwargs`) onto the TIDeferredStatePayload.
The next complication after that was the "ExtendedJSON" type on the
next_kwargs column of TI: this is a type decorator that automatically
applies
the BaseSerialization encode/decode step (__var and __type etc). The problem
with that is that we need to do the serialization on the client in order to
send a JSON HTTP request, so we don't want to encode _again_ on the server
ideally. I was able to do that easily on the write/update side but not so
easily on the read side -- there I left a comment and for now we will have
SQLA
decode it for us, and then we have to encode it again. Not the best, but
not a
disaster either.
The other change I did here was to have the DeferTask automatically apply
the
serde encoding when serializing, just so that there are fewer places in the
code that need to be aware of that detail (So the Task subprocess will
encode
it before making the request to the Supervisor, and in the Supervisor it
will be
kept encoded and passed along as-is to its HTTP request). This means you
can
once again pass datetime objects to a trigger "natively", not only strings.
For consistency with the user facing code I renamed `next_method` on
`DeferTask` message to `method_name`. I'm not sure this really makes sense
in the API request to the API server though.
Co-authored-by: Amogh Desai <[email protected]>
Co-authored-by: Jed Cunningham
<[email protected]>
---
.../execution_api/datamodels/taskinstance.py | 31 ++++++---
.../execution_api/routes/task_instances.py | 53 ++++++++++++++--
airflow/models/baseoperator.py | 45 +------------
.../src/airflow/sdk/api/datamodels/_generated.py | 7 ++-
.../src/airflow/sdk/definitions/baseoperator.py | 73 +++++++++++++++++++++-
task_sdk/src/airflow/sdk/execution_time/comms.py | 17 ++++-
.../src/airflow/sdk/execution_time/supervisor.py | 8 ++-
.../src/airflow/sdk/execution_time/task_runner.py | 22 +++++--
task_sdk/tests/api/test_client.py | 23 ++++---
task_sdk/tests/dags/super_basic_deferred_run.py | 2 +-
task_sdk/tests/execution_time/test_supervisor.py | 10 ++-
task_sdk/tests/execution_time/test_task_runner.py | 41 ++++++++++--
.../execution_api/routes/test_task_instances.py | 64 ++++++++++++++++---
13 files changed, 301 insertions(+), 95 deletions(-)
diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index a176a2f9d5b..f6e586347d3 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -28,7 +28,6 @@ from pydantic import (
Tag,
TypeAdapter,
WithJsonSchema,
- field_validator,
)
from airflow.api_fastapi.common.types import UtcDateTime
@@ -122,15 +121,22 @@ class TIDeferredStatePayload(StrictBaseModel):
),
]
classpath: str
- trigger_kwargs: Annotated[dict[str, Any], Field(default_factory=dict)]
- next_method: str
+ trigger_kwargs: Annotated[dict[str, Any] | str,
Field(default_factory=dict)]
+ """
+ Kwargs to pass to the trigger constructor, either a plain dict or an
encrypted string.
+
+ Both forms will be passed along to the trigger, the server will not handle
either.
+ """
+
trigger_timeout: timedelta | None = None
+ next_method: str
+ """The name of the method on the operator to call in the worker after the
trigger has fired."""
+ next_kwargs: Annotated[dict[str, Any] | str, Field(default_factory=dict)]
+ """
+ Kwargs to pass to the above method, either a plain dict or an encrypted
string.
- @field_validator("trigger_kwargs")
- def validate_moment(cls, v):
- if "moment" in v:
- v["moment"] = AwareDatetimeAdapter.validate_strings(v["moment"])
- return v
+ Both forms will be passed along to the TaskSDK upon resume, the server
will not handle either.
+ """
class TIRescheduleStatePayload(StrictBaseModel):
@@ -252,6 +258,15 @@ class TIRunContext(BaseModel):
upstream_map_indexes: dict[str, int] | None = None
+ next_method: str | None = None
+ """Method to call. Set when task resumes from a trigger."""
+ next_kwargs: dict[str, Any] | str | None = None
+ """
+ Args to pass to ``next_method``.
+
+ Can either be a "decorated" dict, or a string encrypted with the shared
Fernet key.
+ """
+
class PrevSuccessfulDagRunResponse(BaseModel):
"""Schema for response with previous successful DagRun information for
Task Template Context."""
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 87cc547a9f0..6a403c1957a 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -17,6 +17,7 @@
from __future__ import annotations
+import json
import logging
from typing import Annotated
from uuid import UUID
@@ -78,6 +79,9 @@ def ti_run(
# We only use UUID above for validation purposes
ti_id_str = str(task_instance_id)
+ from sqlalchemy.sql import column
+ from sqlalchemy.types import JSON
+
old = (
select(
TI.state,
@@ -88,14 +92,28 @@ def ti_run(
TI.next_method,
TI.try_number,
TI.max_tries,
+ TI.next_method,
+ # This selects the raw JSON value, by-passing the deserialization
-- we want that to happen on the
+ # client
+ column("next_kwargs", JSON),
)
+ .select_from(TI)
.where(TI.id == ti_id_str)
.with_for_update()
)
try:
- (previous_state, dag_id, run_id, task_id, map_index, next_method,
try_number, max_tries) = (
- session.execute(old).one()
- )
+ (
+ previous_state,
+ dag_id,
+ run_id,
+ task_id,
+ map_index,
+ next_method,
+ try_number,
+ max_tries,
+ next_method,
+ next_kwargs,
+ ) = session.execute(old).one()
except NoResultFound:
log.error("Task Instance %s not found", ti_id_str)
raise HTTPException(
@@ -195,7 +213,7 @@ def ti_run(
or 0
)
- return TIRunContext(
+ context = TIRunContext(
dag_run=dr,
task_reschedule_count=task_reschedule_count,
max_tries=max_tries,
@@ -203,6 +221,13 @@ def ti_run(
variables=[],
connections=[],
)
+
+ # Only set if they are non-null
+ if next_method:
+ context.next_method = next_method
+ context.next_kwargs = next_kwargs
+
+ return context
except SQLAlchemyError as e:
log.error("Error marking Task Instance state as running: %s", e)
raise HTTPException(
@@ -290,10 +315,17 @@ def ti_update_state(
if ti_patch_payload.trigger_timeout is not None:
timeout = timezone.utcnow() + ti_patch_payload.trigger_timeout
+ trigger_kwargs = ti_patch_payload.trigger_kwargs
+ if not isinstance(trigger_kwargs, str):
+ # If it's passed as a string, assume the client encrypted it,
otherwise assume it doesn't need to
+ # be. Just JSON serialize it
+ trigger_kwargs = json.dumps(trigger_kwargs)
+
trigger_row = Trigger(
classpath=ti_patch_payload.classpath,
- kwargs=ti_patch_payload.trigger_kwargs,
+ kwargs={},
)
+ trigger_row.encrypted_kwargs = trigger_kwargs
session.add(trigger_row)
session.flush()
@@ -301,11 +333,20 @@ def ti_update_state(
# either get it from the serialised DAG or get it from the API
query = update(TI).where(TI.id == ti_id_str)
+
+ # This is slightly inefficient as we deserialize it to then right
again serialize it in the sqla
+ # TypeAdapter.
+ next_kwargs = None
+ if ti_patch_payload.next_kwargs:
+ from airflow.serialization.serialized_objects import
BaseSerialization
+
+ next_kwargs =
BaseSerialization.deserialize(ti_patch_payload.next_kwargs)
+
query = query.values(
state=State.DEFERRED,
trigger_id=trigger_row.id,
next_method=ti_patch_payload.next_method,
- next_kwargs=ti_patch_payload.trigger_kwargs,
+ next_kwargs=next_kwargs,
trigger_timeout=timeout,
)
updated_state = State.DEFERRED
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index ea6a7c34b11..d1ed6d889ac 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -35,7 +35,6 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
- NoReturn,
TypeVar,
)
@@ -47,9 +46,6 @@ from sqlalchemy.orm.exc import NoResultFound
from airflow.configuration import conf
from airflow.exceptions import (
AirflowException,
- TaskDeferralError,
- TaskDeferralTimeout,
- TaskDeferred,
)
from airflow.lineage import apply_lineage, prepare_lineage
@@ -62,7 +58,6 @@ from airflow.models.abstractoperator import (
from airflow.models.base import _sentinel
from airflow.models.taskinstance import TaskInstance, clear_task_instances
from airflow.models.taskmixin import DependencyMixin
-from airflow.models.trigger import TRIGGER_FAIL_REPR, TriggerFailureReason
from airflow.sdk.definitions._internal.abstractoperator import
AbstractOperator as TaskSDKAbstractOperator
from airflow.sdk.definitions.baseoperator import (
BaseOperatorMeta as TaskSDKBaseOperatorMeta,
@@ -98,7 +93,7 @@ if TYPE_CHECKING:
from airflow.models.operator import Operator
from airflow.sdk.definitions.node import DAGNode
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
- from airflow.triggers.base import BaseTrigger, StartTriggerArgs
+ from airflow.triggers.base import StartTriggerArgs
TaskPreExecuteHook = Callable[[Context], None]
TaskPostExecuteHook = Callable[[Context, Any], None]
@@ -741,44 +736,6 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator,
metaclass=BaseOperator
"""Serialize; required by DAGNode."""
return DagAttributeTypes.OP, self.task_id
- def defer(
- self,
- *,
- trigger: BaseTrigger,
- method_name: str,
- kwargs: dict[str, Any] | None = None,
- timeout: timedelta | int | float | None = None,
- ) -> NoReturn:
- """
- Mark this Operator "deferred", suspending its execution until the
provided trigger fires an event.
-
- This is achieved by raising a special exception (TaskDeferred)
- which is caught in the main _execute_task wrapper. Triggers can send
execution back to task or end
- the task instance directly. If the trigger will end the task instance
itself, ``method_name`` should
- be None; otherwise, provide the name of the method that should be used
when resuming execution in
- the task.
- """
- raise TaskDeferred(trigger=trigger, method_name=method_name,
kwargs=kwargs, timeout=timeout)
-
- def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] |
None, context: Context):
- """Call this method when a deferred task is resumed."""
- # __fail__ is a special signal value for next_method that indicates
- # this task was scheduled specifically to fail.
- if next_method == TRIGGER_FAIL_REPR:
- next_kwargs = next_kwargs or {}
- traceback = next_kwargs.get("traceback")
- if traceback is not None:
- self.log.error("Trigger failed:\n%s", "\n".join(traceback))
- if (error := next_kwargs.get("error", "Unknown")) ==
TriggerFailureReason.TRIGGER_TIMEOUT:
- raise TaskDeferralTimeout(error)
- else:
- raise TaskDeferralError(error)
- # Grab the callable off the Operator/Task and add in any kwargs
- execute_callable = getattr(self, next_method)
- if next_kwargs:
- execute_callable = functools.partial(execute_callable,
**next_kwargs)
- return execute_callable(context)
-
def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session])
-> BaseOperator:
"""
Get the "normal" operator from the current operator.
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index 71f110b64f4..4db100e8b42 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -117,9 +117,10 @@ class TIDeferredStatePayload(BaseModel):
)
state: Annotated[Literal["deferred"] | None, Field(title="State")] =
"deferred"
classpath: Annotated[str, Field(title="Classpath")]
- trigger_kwargs: Annotated[dict[str, Any] | None, Field(title="Trigger
Kwargs")] = None
- next_method: Annotated[str, Field(title="Next Method")]
+ trigger_kwargs: Annotated[dict[str, Any] | str | None,
Field(title="Trigger Kwargs")] = None
trigger_timeout: Annotated[timedelta | None, Field(title="Trigger
Timeout")] = None
+ next_method: Annotated[str, Field(title="Next Method")]
+ next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next
Kwargs")] = None
class TIEnterRunningPayload(BaseModel):
@@ -316,6 +317,8 @@ class TIRunContext(BaseModel):
variables: Annotated[list[VariableResponse] | None,
Field(title="Variables")] = None
connections: Annotated[list[ConnectionResponse] | None,
Field(title="Connections")] = None
upstream_map_indexes: Annotated[dict[str, int] | None,
Field(title="Upstream Map Indexes")] = None
+ next_method: Annotated[str | None, Field(title="Next Method")] = None
+ next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next
Kwargs")] = None
class TITerminalStatePayload(BaseModel):
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index 14d67656008..3efbd1ad0b6 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -27,9 +27,10 @@ import warnings
from collections.abc import Callable, Collection, Iterable, Sequence
from dataclasses import dataclass, field
from datetime import datetime, timedelta
+from enum import Enum
from functools import total_ordering, wraps
from types import FunctionType
-from typing import TYPE_CHECKING, Any, ClassVar, Final, TypeVar, cast
+from typing import TYPE_CHECKING, Any, ClassVar, Final, NoReturn, TypeVar, cast
import attrs
@@ -77,13 +78,40 @@ if TYPE_CHECKING:
from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.serialization.enums import DagAttributeTypes
from airflow.task.priority_strategy import PriorityWeightStrategy
+ from airflow.triggers.base import BaseTrigger
from airflow.typing_compat import Self
from airflow.utils.operator_resources import Resources
+__all__ = [
+ "BaseOperator",
+]
+
# TODO: Task-SDK
AirflowException = RuntimeError
+class TriggerFailureReason(str, Enum):
+ """
+ Reasons for trigger failures.
+
+ Internal use only.
+
+ :meta private:
+ """
+
+ TRIGGER_TIMEOUT = "Trigger timeout"
+ TRIGGER_FAILURE = "Trigger failure"
+
+
+TRIGGER_FAIL_REPR = "__fail__"
+"""String value to represent trigger failure.
+
+Internal use only.
+
+:meta private:
+"""
+
+
def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) ->
tuple[dict, ParamsDict]:
if not dag:
return {}, ParamsDict()
@@ -1434,3 +1462,46 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
if not jinja_env:
jinja_env = self.get_template_env()
self._do_render_template_fields(self, self.template_fields, context,
jinja_env, set())
+
+ def defer(
+ self,
+ *,
+ trigger: BaseTrigger,
+ method_name: str,
+ kwargs: dict[str, Any] | None = None,
+ timeout: timedelta | int | float | None = None,
+ ) -> NoReturn:
+ """
+ Mark this Operator "deferred", suspending its execution until the
provided trigger fires an event.
+
+ This is achieved by raising a special exception (TaskDeferred)
+ which is caught in the main _execute_task wrapper. Triggers can send
execution back to task or end
+ the task instance directly. If the trigger will end the task instance
itself, ``method_name`` should
+ be None; otherwise, provide the name of the method that should be used
when resuming execution in
+ the task.
+ """
+ from airflow.exceptions import TaskDeferred
+
+ raise TaskDeferred(trigger=trigger, method_name=method_name,
kwargs=kwargs, timeout=timeout)
+
+ def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] |
None, context: Context):
+ """Entrypoint method called by the Task Runner (instead of execute)
when this task is resumed."""
+ from airflow.exceptions import TaskDeferralError, TaskDeferralTimeout
+
+ if next_kwargs is None:
+ next_kwargs = {}
+ # __fail__ is a special signal value for next_method that indicates
+ # this task was scheduled specifically to fail.
+
+ if next_method == TRIGGER_FAIL_REPR:
+ next_kwargs = next_kwargs or {}
+ traceback = next_kwargs.get("traceback")
+ if traceback is not None:
+ self.log.error("Trigger failed:\n%s", "\n".join(traceback))
+ if (error := next_kwargs.get("error", "Unknown")) ==
TriggerFailureReason.TRIGGER_TIMEOUT:
+ raise TaskDeferralTimeout(error)
+ else:
+ raise TaskDeferralError(error)
+ # Grab the callable off the Operator/Task and add in any kwargs
+ execute_callable = getattr(self, next_method)
+ return execute_callable(context, **next_kwargs)
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py
b/task_sdk/src/airflow/sdk/execution_time/comms.py
index b0c6dd62abc..f0476f9323a 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -44,11 +44,11 @@ Execution API server is because:
from __future__ import annotations
from datetime import datetime
-from typing import Annotated, Literal, Union
+from typing import Annotated, Any, Literal, Union
from uuid import UUID
from fastapi import Body
-from pydantic import BaseModel, ConfigDict, Field, JsonValue
+from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_serializer
from airflow.sdk.api.datamodels._generated import (
AssetResponse,
@@ -229,6 +229,19 @@ class DeferTask(TIDeferredStatePayload):
type: Literal["DeferTask"] = "DeferTask"
+ @field_serializer("trigger_kwargs", "next_kwargs", check_fields=True)
+ def _serde_kwarg_fields(self, val: str | dict[str, Any] | None, _info):
+ from airflow.serialization.serialized_objects import BaseSerialization
+
+ if not isinstance(val, dict):
+ # None, or an encrypted string
+ return val
+
+ if val.keys() == {"__type", "__var"}:
+ # Already encoded.
+ return val
+ return BaseSerialization.serialize(val or {})
+
class RescheduleTask(TIRescheduleStatePayload):
"""Update a task instance state to reschedule/up_for_reschedule."""
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index ca95a7003f4..fada6105bce 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -962,16 +962,18 @@ def supervise(
# TODO: Use logging providers to handle the chunked upload for us etc.
logger: FilteringBoundLogger | None = None
if log_path:
- # If we are told to write logs to a file, redirect the task logger to
it.
+ # If we are told to write logs to a file, redirect the task logger to
it. Make sure we append to the
+ # file though, otherwise when we resume we would lose the logs from
the start->deferral segment if it
+ # lands on the same node as before.
from airflow.sdk.log import init_log_file, logging_processors
log_file = init_log_file(log_path)
pretty_logs = False
if pretty_logs:
- underlying_logger: WrappedLogger =
structlog.WriteLogger(log_file.open("w", buffering=1))
+ underlying_logger: WrappedLogger =
structlog.WriteLogger(log_file.open("a", buffering=1))
else:
- underlying_logger = structlog.BytesLogger(log_file.open("wb"))
+ underlying_logger = structlog.BytesLogger(log_file.open("ab"))
processors = logging_processors(enable_pretty_log=pretty_logs)[0]
logger = structlog.wrap_logger(underlying_logger,
processors=processors, logger_name="task").bind()
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 610c989e01f..97ee45cded5 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -19,6 +19,7 @@
from __future__ import annotations
+import functools
import os
import sys
from collections.abc import Iterable, Mapping
@@ -616,13 +617,13 @@ def run(
# TODO: Should we use structlog.bind_contextvars here for dag_id,
task_id & run_id?
log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id,
task_id=ti.task_id, run_id=ti.run_id)
classpath, trigger_kwargs = defer.trigger.serialize()
- next_method = defer.method_name
- defer_timeout = defer.timeout
+
msg = DeferTask(
classpath=classpath,
trigger_kwargs=trigger_kwargs,
- next_method=next_method,
- trigger_timeout=defer_timeout,
+ trigger_timeout=defer.timeout,
+ next_method=defer.method_name,
+ next_kwargs=defer.kwargs or {},
)
state = IntermediateTIState.DEFERRED
except AirflowSkipException as e:
@@ -697,6 +698,15 @@ def _execute_task(context: Context, ti:
RuntimeTaskInstance):
from airflow.exceptions import AirflowTaskTimeout
task = ti.task
+ execute = task.execute # type: ignore[attr-defined]
+
+ if ti._ti_context_from_server and (next_method :=
ti._ti_context_from_server.next_method):
+ from airflow.serialization.serialized_objects import BaseSerialization
+
+ kwargs =
BaseSerialization.deserialize(ti._ti_context_from_server.next_kwargs or {})
+
+ execute = functools.partial(task.resume_execution,
next_method=next_method, next_kwargs=kwargs)
+
if task.execution_timeout:
# TODO: handle timeout in case of deferral
from airflow.utils.timeout import timeout
@@ -708,12 +718,12 @@ def _execute_task(context: Context, ti:
RuntimeTaskInstance):
raise AirflowTaskTimeout()
# Run task in timeout wrapper
with timeout(timeout_seconds):
- result = task.execute(context) # type: ignore[attr-defined]
+ result = execute(context=context)
except AirflowTaskTimeout:
# TODO: handle on kill callback here
raise
else:
- result = task.execute(context) # type: ignore[attr-defined]
+ result = execute(context=context)
return result
diff --git a/task_sdk/tests/api/test_client.py
b/task_sdk/tests/api/test_client.py
index 43e35dfec9e..d2309c3e4ab 100644
--- a/task_sdk/tests/api/test_client.py
+++ b/task_sdk/tests/api/test_client.py
@@ -268,14 +268,24 @@ class TestTaskInstanceOperations:
# Simulate a successful response from the server that defers a task
ti_id = uuid6.uuid7()
+ msg = DeferTask(
+
classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
+ next_method="execute_complete",
+ trigger_kwargs={
+ "__type": "dict",
+ "__var": {
+ "moment": {"__type": "datetime", "__var": 1730982899.0},
+ "end_from_trigger": False,
+ },
+ },
+ next_kwargs={"__type": "dict", "__var": {}},
+ )
+
def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == f"/task-instances/{ti_id}/state":
actual_body = json.loads(request.read())
assert actual_body["state"] == "deferred"
- assert actual_body["trigger_kwargs"] == {
- "moment": "2024-11-07T12:34:59Z",
- "end_from_trigger": False,
- }
+ assert actual_body["trigger_kwargs"] == msg.trigger_kwargs
assert (
actual_body["classpath"] ==
"airflow.providers.standard.triggers.temporal.DateTimeTrigger"
)
@@ -286,11 +296,6 @@ class TestTaskInstanceOperations:
return httpx.Response(status_code=400, json={"detail": "Bad
Request"})
client = make_client(transport=httpx.MockTransport(handle_request))
- msg = DeferTask(
-
classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
- trigger_kwargs={"moment": "2024-11-07T12:34:59Z",
"end_from_trigger": False},
- next_method="execute_complete",
- )
client.task_instances.defer(ti_id, msg)
def test_task_instance_reschedule(self):
diff --git a/task_sdk/tests/dags/super_basic_deferred_run.py
b/task_sdk/tests/dags/super_basic_deferred_run.py
index 453d9e5f6c7..934930f4bff 100644
--- a/task_sdk/tests/dags/super_basic_deferred_run.py
+++ b/task_sdk/tests/dags/super_basic_deferred_run.py
@@ -28,7 +28,7 @@ from airflow.utils import timezone
def super_basic_deferred_run():
DateTimeSensorAsync(
task_id="async",
- target_time=str(timezone.utcnow() + datetime.timedelta(seconds=3)),
+ target_time=timezone.utcnow() + datetime.timedelta(seconds=3),
poke_interval=60,
timeout=600,
)
diff --git a/task_sdk/tests/execution_time/test_supervisor.py
b/task_sdk/tests/execution_time/test_supervisor.py
index ececcbe31d8..fde63fc7091 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -347,10 +347,18 @@ class TestWatchedSubprocess:
mock_client.task_instances.heartbeat.assert_called_once_with(ti.id,
pid=mocker.ANY)
mock_client.task_instances.defer.assert_called_once_with(
ti.id,
+ # Since the message as serialized in the client upon sending, we
expect it to be already encoded
DeferTask(
classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
- trigger_kwargs={"moment": "2024-11-07T12:34:59Z",
"end_from_trigger": False},
next_method="execute_complete",
+ trigger_kwargs={
+ "__type": "dict",
+ "__var": {
+ "moment": {"__type": "datetime", "__var":
1730982899.0},
+ "end_from_trigger": False,
+ },
+ },
+ next_kwargs={"__type": "dict", "__var": {}},
),
)
diff --git a/task_sdk/tests/execution_time/test_task_runner.py
b/task_sdk/tests/execution_time/test_task_runner.py
index 66a585d7b01..92f64fae030 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -24,6 +24,7 @@ import uuid
from datetime import datetime, timedelta
from pathlib import Path
from socket import socketpair
+from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import patch
@@ -86,6 +87,9 @@ from airflow.utils.state import TaskInstanceState
from tests_common.test_utils.mock_operators import AirflowLink
+if TYPE_CHECKING:
+ from kgb import SpyAgency
+
FAKE_BUNDLE = BundleInfo(name="anything", version="any")
@@ -183,15 +187,13 @@ def test_parse(test_dags_dir: Path, make_ti_context):
def test_run_deferred_basic(time_machine, create_runtime_ti,
mock_supervisor_comms):
"""Test that a task can transition to a deferred state."""
- import datetime
-
from airflow.providers.standard.sensors.date_time import
DateTimeSensorAsync
# Use the time machine to set the current time
instant = timezone.datetime(2024, 11, 22)
task = DateTimeSensorAsync(
task_id="async",
- target_time=str(instant + datetime.timedelta(seconds=3)),
+ target_time=str(instant + timedelta(seconds=3)),
poke_interval=60,
timeout=600,
)
@@ -201,22 +203,51 @@ def test_run_deferred_basic(time_machine,
create_runtime_ti, mock_supervisor_com
expected_defer_task = DeferTask(
state="deferred",
classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
+ # Since we are in the task process here, we expect this to have not
been encoded by serde yet
trigger_kwargs={
"end_from_trigger": False,
"moment": instant + timedelta(seconds=3),
},
- next_method="execute_complete",
trigger_timeout=None,
+ next_method="execute_complete",
+ next_kwargs={},
)
# Run the task
ti = create_runtime_ti(dag_id="basic_deferred_run", task=task)
- run(ti, log=mock.MagicMock())
+ state, msg, err = run(ti, log=mock.MagicMock())
# send_request will only be called when the TaskDeferred exception is
raised
mock_supervisor_comms.send_request.assert_any_call(msg=expected_defer_task,
log=mock.ANY)
+def test_resume_from_deferred(time_machine, create_runtime_ti,
mock_supervisor_comms, spy_agency: SpyAgency):
+ from airflow.providers.standard.sensors.date_time import
DateTimeSensorAsync
+
+ instant_str = "2024-09-30T12:00:00Z"
+ instant = timezone.parse(instant_str)
+ task = DateTimeSensorAsync(
+ task_id="async",
+ target_time=instant + timedelta(seconds=3),
+ poke_interval=60,
+ timeout=600,
+ )
+
+ ti = create_runtime_ti(dag_id="basic_deferred_run", task=task)
+ ti._ti_context_from_server.next_method = "execute_complete"
+ ti._ti_context_from_server.next_kwargs = {
+ "__type": "dict",
+ "__var": {"event": {"__type": "datetime", "__var": 1727697600.0}},
+ }
+
+ spy = spy_agency.spy_on(task.execute_complete)
+ state, msg, err = run(ti, log=mock.MagicMock())
+ assert err is None
+ assert state == TaskInstanceState.SUCCESS
+
+ spy_agency.assert_spy_called_with(spy, mock.ANY, event=instant)
+
+
def test_run_basic_skipped(time_machine, create_runtime_ti,
mock_supervisor_comms):
"""Test running a basic task that marks itself skipped."""
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py
b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index 5542adb9844..5477869a6cb 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -77,7 +77,6 @@ class TestTIRunState:
session=session,
start_date=instant,
)
-
session.commit()
response = client.patch(
@@ -119,6 +118,49 @@ class TestTIRunState:
assert ti.unixname == "random-unixname"
assert ti.pid == 100
+ def test_next_kwargs_still_encoded(self, client, session,
create_task_instance, time_machine):
+ instant_str = "2024-09-30T12:00:00Z"
+ instant = timezone.parse(instant_str)
+ time_machine.move_to(instant, tick=False)
+
+ ti = create_task_instance(
+ task_id="test_ti_run_state_to_running",
+ state=State.QUEUED,
+ session=session,
+ start_date=instant,
+ )
+
+ ti.next_method = "execute_complete"
+ # ti.next_kwargs under the hood applies the serde encoding for us
+ ti.next_kwargs = {"moment": instant}
+
+ session.commit()
+
+ response = client.patch(
+ f"/execution/task-instances/{ti.id}/run",
+ json={
+ "state": "running",
+ "hostname": "random-hostname",
+ "unixname": "random-unixname",
+ "pid": 100,
+ "start_date": instant_str,
+ },
+ )
+
+ assert response.status_code == 200
+ assert response.json() == {
+ "dag_run": mock.ANY,
+ "task_reschedule_count": 0,
+ "max_tries": 0,
+ "variables": [],
+ "connections": [],
+ "next_method": "execute_complete",
+ "next_kwargs": {
+ "__type": "dict",
+ "__var": {"moment": {"__type": "datetime", "__var":
1727697600.0}},
+ },
+ }
+
@pytest.mark.parametrize("initial_ti_state", [s for s in TaskInstanceState
if s != State.QUEUED])
def test_ti_run_state_conflict_if_not_queued(
self, client, session, create_task_instance, initial_ti_state
@@ -209,9 +251,9 @@ class TestTIRunState:
payload = {
"state": "deferred",
"trigger_kwargs": {"key": "value", "moment":
"2024-12-18T00:00:00Z"},
+ "trigger_timeout": "P1D", # 1 day
"classpath": "my-classpath",
"next_method": "execute_callback",
- "trigger_timeout": "P1D", # 1 day
}
response = client.patch(f"/execution/task-instances/{ti.id}/state",
json=payload)
@@ -452,10 +494,18 @@ class TestTIUpdateState:
payload = {
"state": "deferred",
- "trigger_kwargs": {"key": "value", "moment":
"2024-12-18T00:00:00Z"},
+ # Raw payload is already "encoded", but not encrypted
+ "trigger_kwargs": {
+ "__type": "dict",
+ "__var": {"key": "value", "moment": {"__type": "datetime",
"__var": 1734480001.0}},
+ },
+ "trigger_timeout": "P1D", # 1 day
"classpath": "my-classpath",
"next_method": "execute_callback",
- "trigger_timeout": "P1D", # 1 day
+ "next_kwargs": {
+ "__type": "dict",
+ "__var": {"foo": {"__type": "datetime", "__var":
1734480000.0}, "bar": "abc"},
+ },
}
response = client.patch(f"/execution/task-instances/{ti.id}/state",
json=payload)
@@ -471,8 +521,8 @@ class TestTIUpdateState:
assert tis[0].state == TaskInstanceState.DEFERRED
assert tis[0].next_method == "execute_callback"
assert tis[0].next_kwargs == {
- "key": "value",
- "moment": datetime(2024, 12, 18, 00, 00, 00, tzinfo=timezone.utc),
+ "bar": "abc",
+ "foo": datetime(2024, 12, 18, 00, 00, 00, tzinfo=timezone.utc),
}
assert tis[0].trigger_timeout == timezone.make_aware(datetime(2024,
11, 23), timezone=timezone.utc)
@@ -482,7 +532,7 @@ class TestTIUpdateState:
assert t[0].classpath == "my-classpath"
assert t[0].kwargs == {
"key": "value",
- "moment": datetime(2024, 12, 18, 00, 00, 00, tzinfo=timezone.utc),
+ "moment": datetime(2024, 12, 18, 00, 00, 1, tzinfo=timezone.utc),
}
def test_ti_update_state_to_reschedule(self, client, session,
create_task_instance, time_machine):