This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 490bb579767 Implement on_execute_callback in task sdk (#47989)
490bb579767 is described below
commit 490bb5797679636956d970deaf40a2b4403c783c
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Mar 20 19:22:30 2025 +0800
Implement on_execute_callback in task sdk (#47989)
The callback is run *outside* of the timeout block. Exceptions raised by
any callback function are logged without failing the entire
task. This matches Airflow 2's behavior.
Submitting this first to get feedback. Other callbacks will be
implemented in later PRs if I'm on the right track.
---
airflow/models/baseoperator.py | 4 --
.../src/airflow/sdk/definitions/baseoperator.py | 13 +++++--
.../src/airflow/sdk/execution_time/task_runner.py | 10 ++++-
.../task_sdk/execution_time/test_task_runner.py | 45 ++++++++++++++++++++++
tests/serialization/test_dag_serialization.py | 2 +-
5 files changed, 64 insertions(+), 10 deletions(-)
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index a37c28ef6a3..846a6e75502 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -339,7 +339,6 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator):
start_trigger_args: StartTriggerArgs | None = None
start_from_trigger: bool = False
- on_execute_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None
on_failure_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None
on_success_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None
on_retry_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None
@@ -349,7 +348,6 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator):
self,
pre_execute=None,
post_execute=None,
- on_execute_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
on_failure_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
on_success_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
on_retry_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
@@ -364,7 +362,6 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator):
super().__init__(**kwargs)
self._pre_execute_hook = pre_execute
self._post_execute_hook = post_execute
- self.on_execute_callback = on_execute_callback
self.on_failure_callback = on_failure_callback
self.on_success_callback = on_success_callback
self.on_skipped_callback = on_skipped_callback
@@ -393,7 +390,6 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator):
return TaskSDKBaseOperator.get_serialized_fields() | {
"start_trigger_args",
"start_from_trigger",
- "on_execute_callback",
"on_failure_callback",
"on_success_callback",
"on_retry_callback",
diff --git a/task-sdk/src/airflow/sdk/definitions/baseoperator.py
b/task-sdk/src/airflow/sdk/definitions/baseoperator.py
index b02eff27efb..ceceb7e26ab 100644
--- a/task-sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -69,6 +69,7 @@ from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import db_safe_priority
+C = TypeVar("C", bound=Callable)
T = TypeVar("T", bound=FunctionType)
if TYPE_CHECKING:
@@ -382,6 +383,12 @@ if "airflow.configuration" in sys.modules:
ExecutorSafeguard.test_mode = conf.getboolean("core", "unit_test_mode")
+def _collect_callbacks(callbacks: C | Collection[C] | None) -> list[C]:
+    """
+    Normalize a callback argument into a new list of callbacks.
+
+    ``None`` (the default these callback arguments previously had) yields an
+    empty list, so the task runner never tries to invoke ``None`` as a
+    callback. A collection is copied into a new list; any other value is
+    treated as a single callback and wrapped in a one-item list.
+    """
+    if callbacks is None:
+        return []
+    if isinstance(callbacks, Collection):
+        return list(callbacks)
+    return [callbacks]
+
+
class BaseOperatorMeta(abc.ABCMeta):
"""Metaclass of BaseOperator."""
@@ -805,7 +812,7 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
pool: str = DEFAULT_POOL_NAME
pool_slots: int = DEFAULT_POOL_SLOTS
execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT
- # on_execute_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None
+ on_execute_callback: Sequence[TaskStateChangeCallback] = ()
# on_failure_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None
# on_success_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None
# on_retry_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None
@@ -959,7 +966,7 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
pool_slots: int = DEFAULT_POOL_SLOTS,
sla: timedelta | None = None,
execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
- # on_execute_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
+ on_execute_callback: TaskStateChangeCallback |
Collection[TaskStateChangeCallback] = (),
# on_failure_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
# on_success_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
# on_retry_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
@@ -1037,7 +1044,7 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
self.execution_timeout = execution_timeout
# TODO:
- # self.on_execute_callback = on_execute_callback
+ self.on_execute_callback = _collect_callbacks(on_execute_callback)
# self.on_failure_callback = on_failure_callback
# self.on_success_callback = on_success_callback
# self.on_retry_callback = on_retry_callback
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 7614ccbf113..fcd401ab5db 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -609,7 +609,7 @@ def run(
state = TerminalTIState.FAILED
return state, msg, error
- result = _execute_task(context, ti)
+ result = _execute_task(context, ti, log)
_push_xcom_if_needed(result, ti, log)
@@ -789,7 +789,7 @@ def _handle_trigger_dag_run(
return msg, state
-def _execute_task(context: Context, ti: RuntimeTaskInstance):
+def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
"""Execute Task (optionally with a Timeout) and push Xcom results."""
from airflow.exceptions import AirflowTaskTimeout
@@ -807,6 +807,12 @@ def _execute_task(context: Context, ti:
RuntimeTaskInstance):
# Populate the context var so ExecutorSafeguard doesn't complain
ctx.run(ExecutorSafeguard.tracker.set, task)
+ for i, callback in enumerate(task.on_execute_callback):
+ try:
+ callback(context)
+ except Exception:
+ log.exception("Failed to run on-execute callback", index=i,
callback=callback)
+
if task.execution_timeout:
# TODO: handle timeout in case of deferral
from airflow.utils.timeout import timeout
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 25696297811..10eb2bdf719 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -1928,6 +1928,51 @@ class TestTaskRunnerCallsListeners:
assert listener.error == error
+@pytest.mark.usefixtures("mock_supervisor_comms")
+class TestTaskRunnerCallsCallbacks:
+    """Verify that the task runner invokes ``on_execute_callback`` before ``execute()``."""
+
+    def test_task_runner_calls_execute_callback(self, create_runtime_ti):
+        """The execute callback runs before ``execute()`` and the task succeeds."""
+        results = []
+
+        def custom_callback(context):
+            results.append("callback")
+
+        class CustomOperator(BaseOperator):
+            def execute(self, context):
+                results.append("execute")
+
+        task = CustomOperator(task_id="task", on_execute_callback=custom_callback)
+        runtime_ti = create_runtime_ti(dag_id="dag", task=task)
+        log = mock.MagicMock()
+        state, _, _ = run(runtime_ti, log)
+
+        assert state == TerminalTIState.SUCCESS
+        assert results == ["callback", "execute"]
+
+    def test_task_runner_not_fail_on_failed_execute_callback(self, create_runtime_ti):
+        """A raising callback is logged with its index; ``execute()`` still runs and the task succeeds."""
+        results = []
+
+        def custom_callback_1(context):
+            results.append("callback 1")
+
+        def custom_callback_2(context):
+            raise Exception("sorry!")
+
+        class CustomOperator(BaseOperator):
+            def execute(self, context):
+                results.append("execute")
+
+        task = CustomOperator(task_id="task", on_execute_callback=[custom_callback_1, custom_callback_2])
+        runtime_ti = create_runtime_ti(dag_id="dag", task=task)
+        log = mock.MagicMock()
+        state, _, _ = run(runtime_ti, log)
+
+        assert state == TerminalTIState.SUCCESS
+        assert results == ["callback 1", "execute"]
+        # Only the second callback (index 1) raised, so exactly one exception is logged.
+        assert log.exception.mock_calls == [
+            mock.call("Failed to run on-execute callback", index=1, callback=custom_callback_2),
+        ]
+
+
class TestTriggerDagRunOperator:
"""Tests to verify various aspects of TriggerDagRunOperator"""
diff --git a/tests/serialization/test_dag_serialization.py
b/tests/serialization/test_dag_serialization.py
index 8acde585817..d69ab7d0b81 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1357,7 +1357,7 @@ class TestStringifiedDAGs:
"max_active_tis_per_dag": None,
"max_active_tis_per_dagrun": None,
"max_retry_delay": None,
- "on_execute_callback": None,
+ "on_execute_callback": [],
"on_failure_fail_dagrun": False,
"on_failure_callback": None,
"on_retry_callback": None,