This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 309d658204b Refactor task runner for spans (#62589)
309d658204b is described below
commit 309d658204b418913c6ed23341370e222240c6c7
Author: Daniel Standish <[email protected]>
AuthorDate: Mon Mar 2 19:12:52 2026 -0800
Refactor task runner for spans (#62589)
We need to get the ti details object in order to obtain the context carrier,
which is required to open the span.
This small adjustment lets us receive the ti details message separately from
the rest of startup, which makes it straightforward to enclose all of it in a span.
---
task-sdk/src/airflow/sdk/execution_time/task_runner.py | 9 +++++++--
.../tests/task_sdk/execution_time/test_task_runner.py | 17 ++++++++++-------
2 files changed, 17 insertions(+), 9 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 7355f886ee8..0ef7a74a24b 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -843,7 +843,7 @@ def _verify_bundle_access(bundle_instance: BaseDagBundle,
log: Logger) -> None:
)
-def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
+def get_startup_details() -> StartupDetails:
# The parent sends us a StartupDetails message un-prompted. After this,
every single message is only sent
# in response to us sending a request.
log = structlog.get_logger(logger_name="task")
@@ -867,7 +867,11 @@ def startup() -> tuple[RuntimeTaskInstance, Context,
Logger]:
if not isinstance(msg, StartupDetails):
raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")
+ return msg
+
+def startup(msg: StartupDetails) -> tuple[RuntimeTaskInstance, Context,
Logger]:
+ log = structlog.get_logger("task")
# setproctitle causes issue on Mac OS:
https://github.com/benoitc/gunicorn/issues/3021
os_type = sys.platform
if os_type == "darwin":
@@ -1803,7 +1807,8 @@ def main():
try:
try:
- ti, context, log = startup()
+ startup_details = get_startup_details()
+ ti, context, log = startup(msg=startup_details)
except AirflowRescheduleException as reschedule:
log.warning("Rescheduling task during startup, marking task as
UP_FOR_RESCHEDULE")
SUPERVISOR_COMMS.send(
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index f40b4399cff..2a495e557f7 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -130,6 +130,7 @@ from airflow.sdk.execution_time.task_runner import (
_push_xcom_if_needed,
_xcom_push,
finalize,
+ get_startup_details,
parse,
run,
startup,
@@ -363,9 +364,10 @@ def
test_parse_not_found_does_not_reschedule_when_max_attempts_reached(test_dags
@mock.patch("builtins.exit", side_effect=lambda code: (_ for _ in
()).throw(SystemExit(code)))
@mock.patch("airflow.sdk.execution_time.task_runner.startup")
[email protected]("airflow.sdk.execution_time.task_runner.get_startup_details")
@mock.patch("airflow.sdk.execution_time.task_runner.CommsDecoder")
def test_main_sends_reschedule_task_when_startup_reschedules(
- mock_comms_decoder_cls, mock_startup, mock_exit, time_machine
+ mock_comms_decoder_cls, mock_get_startup_details, mock_startup, mock_exit,
time_machine
):
"""
If startup raises AirflowRescheduleException, the task runner should
report a RescheduleTask
@@ -377,6 +379,7 @@ def
test_main_sends_reschedule_task_when_startup_reschedules(
mock_comms_instance = mock.Mock()
mock_comms_instance.socket = None
mock_comms_decoder_cls.__getitem__.return_value.return_value =
mock_comms_instance
+ mock_get_startup_details.return_value = mock.Mock()
mock_startup.side_effect =
AirflowRescheduleException(reschedule_date=reschedule_date)
# Move time
@@ -927,7 +930,7 @@ def test_startup_and_run_dag_with_rtif(
mock_supervisor_comms._get_response.return_value = what
- run(*startup())
+ run(*startup(get_startup_details()))
expected_calls = [
mock.call.send(SetRenderedFields(rendered_fields=expected_rendered_fields)),
mock.call.send(
@@ -977,7 +980,7 @@ def test_task_run_with_user_impersonation(
mock_supervisor_comms.socket.fileno.return_value = 42
with mock.patch.dict(os.environ, {}, clear=True):
- startup()
+ startup(get_startup_details())
assert os.environ["_AIRFLOW__REEXECUTED_PROCESS"] == "1"
assert "_AIRFLOW__STARTUP_MSG" in os.environ
@@ -1026,7 +1029,7 @@ def test_task_run_with_user_impersonation_default_user(
mock_get_user.return_value = "default_user"
with mock.patch.dict(os.environ, {}, clear=True):
- startup()
+ startup(get_startup_details())
assert "_AIRFLOW__REEXECUTED_PROCESS" not in os.environ
assert "_AIRFLOW__STARTUP_MSG" not in os.environ
@@ -1069,7 +1072,7 @@ def
test_task_run_with_user_impersonation_remove_krb5ccname_on_reexecuted_proces
"_AIRFLOW__STARTUP_MSG": what.model_dump_json(),
}
with mock.patch.dict("os.environ", mock_os_env, clear=True):
- startup()
+ startup(get_startup_details())
assert os.environ["_AIRFLOW__REEXECUTED_PROCESS"] == "1"
assert "_AIRFLOW__STARTUP_MSG" in os.environ
@@ -1241,7 +1244,7 @@ def test_dag_parsing_context(make_ti_context,
mock_supervisor_comms, monkeypatch
)
monkeypatch.setenv("AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST",
dag_bundle_val)
- ti, _, _ = startup()
+ ti, _, _ = startup(get_startup_details())
# Presence of `conditional_task` below means Dag ID is properly set in the
parsing context!
# Check the dag file for the actual logic!
@@ -3596,7 +3599,7 @@ class TestTaskRunnerCallsListeners:
mock_supervisor_comms._get_response.return_value = what
mocked_parse(what, "basic_dag", task)
- runtime_ti, context, log = startup()
+ runtime_ti, context, log = startup(get_startup_details())
assert runtime_ti is not None
assert isinstance(listener.component, TaskRunnerMarker)
del listener.component