This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f5ff9671ec8 Raise ``TaskAlreadyRunningError`` when starting an already-running task instance (#60855)
f5ff9671ec8 is described below
commit f5ff9671ec81f9d259980a733766fef4c78ba91c
Author: Anish Giri <[email protected]>
AuthorDate: Wed Mar 25 16:25:03 2026 -0500
Raise ``TaskAlreadyRunningError`` when starting an already-running task instance (#60855)
* Fix task marked as failed on executor redelivery
Handle 409 CONFLICT (task already running) from the API server gracefully
by raising TaskAlreadyRunningError instead of letting it propagate as a
generic failure.
closes: #58441
* Fix test_get_not_found assertion to match unwrapped detail format
* Address review feedback
* Trigger CI re-run
* Trigger CI re-run
---
task-sdk/src/airflow/sdk/api/client.py | 20 ++++++-
task-sdk/src/airflow/sdk/exceptions.py | 4 ++
task-sdk/tests/task_sdk/api/test_client.py | 59 +++++++++++++++---
.../task_sdk/execution_time/test_supervisor.py | 70 +++++++++++-----------
4 files changed, 108 insertions(+), 45 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py
index 90374f76be5..19e691281f7 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -76,7 +76,7 @@ from airflow.sdk.api.datamodels._generated import (
XComSequenceSliceResponse,
)
from airflow.sdk.configuration import conf
-from airflow.sdk.exceptions import ErrorType
+from airflow.sdk.exceptions import ErrorType, TaskAlreadyRunningError
from airflow.sdk.execution_time.comms import (
CreateHITLDetailPayload,
DRCount,
@@ -216,7 +216,18 @@ class TaskInstanceOperations:
"""Tell the API server that this TI has started running."""
body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(),
unixname=getuser(), start_date=when)
- resp = self.client.patch(f"task-instances/{id}/run",
content=body.model_dump_json())
+ try:
+ resp = self.client.patch(f"task-instances/{id}/run",
content=body.model_dump_json())
+ except ServerResponseError as e:
+ if e.response.status_code == HTTPStatus.CONFLICT:
+ detail = e.detail
+ if (
+ isinstance(detail, dict)
+ and detail.get("reason") == "invalid_state"
+ and detail.get("previous_state") == "running"
+ ):
+ raise TaskAlreadyRunningError(f"Task instance {id} is
already running") from e
+ raise
return TIRunContext.model_validate_json(resp.read())
def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when:
datetime, rendered_map_index):
@@ -1034,7 +1045,7 @@ class Client(httpx.Client):
# This is only used for parsing. ServerResponseError is raised instead
class _ErrorBody(BaseModel):
- detail: list[RemoteValidationError] | str
+ detail: list[RemoteValidationError] | dict[str, Any] | str
def __repr__(self):
return repr(self.detail)
@@ -1068,6 +1079,9 @@ class ServerResponseError(httpx.HTTPStatusError):
if isinstance(body.detail, list):
detail = body.detail
msg = "Remote server returned validation error"
+ elif isinstance(body.detail, dict):
+ detail = body.detail
+ msg = "Server returned error"
else:
msg = body.detail or "Un-parseable error"
except Exception:
diff --git a/task-sdk/src/airflow/sdk/exceptions.py b/task-sdk/src/airflow/sdk/exceptions.py
index 9f8a5f11fbf..b69abe62265 100644
--- a/task-sdk/src/airflow/sdk/exceptions.py
+++ b/task-sdk/src/airflow/sdk/exceptions.py
@@ -330,6 +330,10 @@ class TaskNotFound(AirflowException):
"""Raise when a Task is not available in the system."""
+class TaskAlreadyRunningError(AirflowException):
+ """Raised when a task is already running on another worker."""
+
+
class FailFastDagInvalidTriggerRule(AirflowException):
"""Raise when a dag has 'fail_fast' enabled yet has a non-default trigger
rule."""
diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py
index 7d960b76570..0df8839c55f 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -47,7 +47,7 @@ from airflow.sdk.api.datamodels._generated import (
VariableResponse,
XComResponse,
)
-from airflow.sdk.exceptions import ErrorType
+from airflow.sdk.exceptions import ErrorType, TaskAlreadyRunningError
from airflow.sdk.execution_time.comms import (
DeferTask,
ErrorResponse,
@@ -161,7 +161,7 @@ class TestClient:
err = exc_info.value
assert err.args == ("Server returned error",)
- assert err.detail == {"detail": {"message": "Invalid input"}}
+ assert err.detail == {"message": "Invalid input"}
# Check that the error is picklable
pickled = pickle.dumps(err)
@@ -171,7 +171,7 @@ class TestClient:
# Test that unpickled error has the same attributes as the original
assert unpickled.response.json() == {"detail": {"message": "Invalid
input"}}
- assert unpickled.detail == {"detail": {"message": "Invalid input"}}
+ assert unpickled.detail == {"message": "Invalid input"}
assert unpickled.response.status_code == 404
assert unpickled.request.url == "http://error"
@@ -333,6 +333,53 @@ class TestTaskInstanceOperations:
assert resp == ti_context
assert call_count == 3
+ def test_task_instance_start_already_running(self):
+ """Test that start() raises TaskAlreadyRunningError when TI is already
running."""
+ ti_id = uuid6.uuid7()
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == f"/task-instances/{ti_id}/run":
+ return httpx.Response(
+ 409,
+ json={
+ "detail": {
+ "reason": "invalid_state",
+ "message": "TI was not in a state where it could
be marked as running",
+ "previous_state": "running",
+ }
+ },
+ )
+ return httpx.Response(status_code=204)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+
+ with pytest.raises(TaskAlreadyRunningError, match="already running"):
+ client.task_instances.start(ti_id, 100, datetime(2024, 10, 31,
tzinfo=timezone.utc))
+
+ @pytest.mark.parametrize("previous_state", ["failed", "success",
"skipped"])
+ def test_task_instance_start_other_invalid_states(self, previous_state):
+ """Test that start() raises ServerResponseError for non-running
invalid states."""
+ ti_id = uuid6.uuid7()
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == f"/task-instances/{ti_id}/run":
+ return httpx.Response(
+ 409,
+ json={
+ "detail": {
+ "reason": "invalid_state",
+ "message": "TI was not in a state where it could
be marked as running",
+ "previous_state": previous_state,
+ }
+ },
+ )
+ return httpx.Response(status_code=204)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+
+ with pytest.raises(ServerResponseError):
+ client.task_instances.start(ti_id, 100, datetime(2024, 10, 31,
tzinfo=timezone.utc))
+
@pytest.mark.parametrize(
"state", [state for state in TerminalTIState if state !=
TerminalTIState.SUCCESS]
)
@@ -1627,10 +1674,8 @@ class TestDagsOperations:
assert exc_info.value.response.status_code == 404
assert exc_info.value.detail == {
- "detail": {
- "message": "The Dag with dag_id: `missing_dag` was not found",
- "reason": "not_found",
- }
+ "message": "The Dag with dag_id: `missing_dag` was not found",
+ "reason": "not_found",
}
def test_get_server_error(self):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index ef6cd19b8d7..b486ce77766 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -63,7 +63,7 @@ from airflow.sdk.api.datamodels._generated import (
TaskInstance,
TaskInstanceState,
)
-from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
+from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType,
TaskAlreadyRunningError
from airflow.sdk.execution_time import task_runner
from airflow.sdk.execution_time.comms import (
AssetEventsResult,
@@ -731,40 +731,6 @@ class TestWatchedSubprocess:
"task_instance_id": str(ti.id),
} in captured_logs
- def test_supervisor_handles_already_running_task(self):
- """Test that Supervisor prevents starting a Task Instance that is
already running."""
- ti = TaskInstance(
- id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1,
dag_version_id=uuid7()
- )
-
- # Mock API Server response indicating the TI is already running
- # The API Server would return a 409 Conflict status code if the TI is
not
- # in a "queued" state.
- def handle_request(request: httpx.Request) -> httpx.Response:
- if request.url.path == f"/task-instances/{ti.id}/run":
- return httpx.Response(
- 409,
- json={
- "reason": "invalid_state",
- "message": "TI was not in a state where it could be
marked as running",
- "previous_state": "running",
- },
- )
-
- return httpx.Response(status_code=204)
-
- client = make_client(transport=httpx.MockTransport(handle_request))
-
- with pytest.raises(ServerResponseError, match="Server returned error")
as err:
- ActivitySubprocess.start(dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE, what=ti, client=client)
-
- assert err.value.response.status_code == 409
- assert err.value.detail == {
- "reason": "invalid_state",
- "message": "TI was not in a state where it could be marked as
running",
- "previous_state": "running",
- }
-
@pytest.mark.parametrize("captured_logs", [logging.ERROR], indirect=True,
ids=["log_level=error"])
def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch,
mocker, make_ti_context_dict):
"""
@@ -865,6 +831,40 @@ class TestWatchedSubprocess:
},
]
+ def test_start_raises_task_already_running_and_kills_subprocess(self):
+ """Test that ActivitySubprocess.start() raises TaskAlreadyRunningError
and kills the child
+ when the API returns 409 with previous_state='running'."""
+ ti_id = uuid7()
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == f"/task-instances/{ti_id}/run":
+ return httpx.Response(
+ 409,
+ json={
+ "detail": {
+ "reason": "invalid_state",
+ "message": "TI was not in a state where it could
be marked as running",
+ "previous_state": "running",
+ }
+ },
+ )
+ return httpx.Response(status_code=204)
+
+ def subprocess_main():
+ # Ensure we follow the "protocol" and get the startup message
before we do anything
+ CommsDecoder()._get_response()
+
+ with pytest.raises(TaskAlreadyRunningError, match="already running"):
+ ActivitySubprocess.start(
+ dag_rel_path=os.devnull,
+ bundle_info=FAKE_BUNDLE,
+ what=TaskInstance(
+ id=ti_id, task_id="b", dag_id="c", run_id="d",
try_number=1, dag_version_id=uuid7()
+ ),
+
client=make_client(transport=httpx.MockTransport(handle_request)),
+ target=subprocess_main,
+ )
+
@pytest.mark.parametrize("captured_logs", [logging.WARNING], indirect=True)
def test_heartbeat_failures_handling(self, monkeypatch, mocker,
captured_logs, time_machine):
"""