This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new cfc2157d3cb AIP-72: Handling `failed` TI state for
`AirflowFailException` & `AirflowSensorTimeout` (#44954)
cfc2157d3cb is described below
commit cfc2157d3cbefe7b26bbccc604e1ca3357dd4c87
Author: Amogh Desai <[email protected]>
AuthorDate: Wed Dec 18 13:04:44 2024 +0530
AIP-72: Handling `failed` TI state for `AirflowFailException` &
`AirflowSensorTimeout` (#44954)
related: https://github.com/apache/airflow/issues/44414
We already have (almost complete) support for handling terminal states on
both the task execution side and the task SDK client side, and the `failed`
state is one of the terminal states.
This PR extends the task runner's run function to handle the cases where we
have to fail a task: `AirflowFailException` and `AirflowSensorTimeout`. It is
functionally very similar to #44786.
As part of failing a task, several other things also need to be done, such as:
- Callbacks: which will eventually be converted to teardown tasks
- Retries: Handled in https://github.com/apache/airflow/issues/44351
- unmapping TIs: https://github.com/apache/airflow/issues/44351
- Handling task history: will be handled by
https://github.com/apache/airflow/issues/44952
- Handling downstream tasks and non teardown tasks: will be handled by
https://github.com/apache/airflow/issues/44951
### Testing performed
#### End to End with Postman
1. Run airflow with breeze and run any DAG

2. Log in to the metadata DB and get the "id" of your task instance from the
TI table

3. Send a request to `fail` your task

Or using curl:
```
curl --location --request PATCH \
'http://localhost:29091/execution/task-instances/0193cec2-f46b-7348-9c27-9869d835dc7b/state' \
--header 'Content-Type: application/json' \
--data '{
"state": "failed",
"end_date": "2024-10-31T12:00:00Z"
}'
```
4. Refresh the Airflow UI to confirm that the task is in the `failed` state.

---
.../execution_api/routes/task_instances.py | 4 ++
.../src/airflow/sdk/execution_time/task_runner.py | 11 ++++-
task_sdk/tests/execution_time/test_supervisor.py | 2 +
task_sdk/tests/execution_time/test_task_runner.py | 51 +++++++++++++++++++++-
.../execution_api/routes/test_task_instances.py | 30 +++++++++++++
5 files changed, 96 insertions(+), 2 deletions(-)
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 3a1545283e8..ac3f80092a9 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -201,6 +201,10 @@ def ti_update_state(
if isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date,
query, session.bind)
+ query = query.values(state=ti_patch_payload.state)
+ if ti_patch_payload.state == State.FAILED:
+ # clear the next_method and next_kwargs
+ query = query.values(next_method=None, next_kwargs=None)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 92f400d46e2..11341e76356 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -283,7 +283,16 @@ def run(ti: RuntimeTaskInstance, log: Logger):
...
except (AirflowFailException, AirflowSensorTimeout):
# If AirflowFailException is raised, task should not retry.
- ...
+ # If a sensor in reschedule mode reaches timeout, task should not
retry.
+
+ # TODO: Handle fail_stop here:
https://github.com/apache/airflow/issues/44951
+ # TODO: Handle addition to Log table:
https://github.com/apache/airflow/issues/44952
+ msg = TaskState(
+ state=TerminalTIState.FAILED,
+ end_date=datetime.now(tz=timezone.utc),
+ )
+
+ # TODO: Run task failure callbacks here
except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated):
...
except SystemExit:
diff --git a/task_sdk/tests/execution_time/test_supervisor.py
b/task_sdk/tests/execution_time/test_supervisor.py
index 70f9e264864..51a31b8982f 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -854,6 +854,8 @@ class TestHandleRequest:
{"ok": True},
id="set_xcom_with_map_index",
),
+ # we aren't adding all states under TerminalTIState here, because
this test's scope is only to check
+ # if it can handle TaskState message
pytest.param(
TaskState(state=TerminalTIState.SKIPPED,
end_date=timezone.parse("2024-10-31T12:00:00Z")),
b"",
diff --git a/task_sdk/tests/execution_time/test_task_runner.py
b/task_sdk/tests/execution_time/test_task_runner.py
index 2b812c92a73..35ff65414f8 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -26,7 +26,7 @@ from unittest import mock
import pytest
from uuid6 import uuid7
-from airflow.exceptions import AirflowSkipException
+from airflow.exceptions import AirflowFailException, AirflowSensorTimeout,
AirflowSkipException
from airflow.sdk import DAG, BaseOperator
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import DeferTask, SetRenderedFields,
StartupDetails, TaskState
@@ -333,6 +333,55 @@ def test_startup_dag_with_templated_fields(
)
+@pytest.mark.parametrize(
+ ["dag_id", "task_id", "fail_with_exception"],
+ [
+ pytest.param(
+ "basic_failed", "fail-exception", AirflowFailException("Oops.
Failing by AirflowFailException!")
+ ),
+ pytest.param(
+ "basic_failed2",
+ "sensor-timeout-exception",
+ AirflowSensorTimeout("Oops. Failing by AirflowSensorTimeout!"),
+ ),
+ ],
+)
+def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id,
fail_with_exception, make_ti_context):
+ """Test running a basic task that marks itself as failed by raising
exception."""
+
+ class CustomOperator(BaseOperator):
+ def __init__(self, e, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.e = e
+
+ def execute(self, context):
+ print(f"raising exception {self.e}")
+ raise self.e
+
+ task = CustomOperator(task_id=task_id, e=fail_with_exception)
+
+ what = StartupDetails(
+ ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id,
run_id="c", try_number=1),
+ file="",
+ requests_fd=0,
+ ti_context=make_ti_context(),
+ )
+
+ ti = mocked_parse(what, dag_id, task)
+
+ instant = timezone.datetime(2024, 12, 3, 10, 0)
+ time_machine.move_to(instant, tick=False)
+
+ with mock.patch(
+ "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+ ) as mock_supervisor_comms:
+ run(ti, log=mock.MagicMock())
+
+ mock_supervisor_comms.send_request.assert_called_once_with(
+ msg=TaskState(state=TerminalTIState.FAILED, end_date=instant),
log=mock.ANY
+ )
+
+
class TestRuntimeTaskInstance:
def test_get_context_without_ti_context_from_server(self, mocked_parse,
make_ti_context):
"""Test get_template_context without ti_context_from_server."""
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py
b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index e67d82a718c..85b6d11ee3b 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -455,6 +455,36 @@ class TestTIHealthEndpoint:
session.refresh(ti)
assert ti.last_heartbeat_at == time_now.add(minutes=10)
+ def test_ti_update_state_to_failed_table_check(self, client, session,
create_task_instance):
+ from math import ceil
+
+ ti = create_task_instance(
+ task_id="test_ti_update_state_to_failed_table_check",
+ state=State.RUNNING,
+ )
+ ti.start_date = DEFAULT_START_DATE
+ session.commit()
+
+ response = client.patch(
+ f"/execution/task-instances/{ti.id}/state",
+ json={
+ "state": State.FAILED,
+ "end_date": DEFAULT_END_DATE.isoformat(),
+ },
+ )
+
+ assert response.status_code == 204
+ assert response.text == ""
+
+ session.expire_all()
+
+ ti = session.get(TaskInstance, ti.id)
+ assert ti.state == State.FAILED
+ assert ti.next_method is None
+ assert ti.next_kwargs is None
+ # TODO: remove/amend this once
https://github.com/apache/airflow/pull/45002 is merged
+ assert ceil(ti.duration) == 3600.00
+
class TestTIPutRTIF:
def setup_method(self):