This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 581e2e42e9 Change AirflowTaskTimeout to inherit BaseException (#35653)
581e2e42e9 is described below
commit 581e2e42e947fc8f23ecccb89fbabccec9e8e26b
Author: HTErik <[email protected]>
AuthorDate: Wed Feb 21 18:43:17 2024 +0100
Change AirflowTaskTimeout to inherit BaseException (#35653)
Code that normally catches Exception should not implicitly ignore
interrupts from AirflowTaskTimeout.
Fixes #35644 #35474
---
airflow/exceptions.py | 5 ++++-
airflow/models/taskinstance.py | 16 ++++++++--------
.../celery/executors/celery_executor_utils.py | 6 +++---
airflow/utils/context.pyi | 2 +-
newsfragments/35653.significant.rst | 21 +++++++++++++++++++++
tests/core/test_core.py | 9 ++++++++-
.../providers/microsoft/azure/hooks/test_synapse.py | 3 ++-
7 files changed, 47 insertions(+), 15 deletions(-)
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index f747640c77..f2fae6e8d4 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -79,7 +79,10 @@ class InvalidStatsNameException(AirflowException):
"""Raise when name of the stats is invalid."""
-class AirflowTaskTimeout(AirflowException):
+# Important to inherit BaseException instead of AirflowException->Exception,
since this Exception is used
+# to explicitly interrupt ongoing task. Code that does normal error-handling
should not treat
+# such interrupt as an error that can be handled normally. (Compare with
KeyboardInterrupt)
+class AirflowTaskTimeout(BaseException):
"""Raise when the task execution times-out."""
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 5d026a0667..cf5c97922e 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -812,7 +812,7 @@ def _is_eligible_to_retry(*, task_instance: TaskInstance |
TaskInstancePydantic)
def _handle_failure(
*,
task_instance: TaskInstance | TaskInstancePydantic,
- error: None | str | Exception | KeyboardInterrupt,
+ error: None | str | BaseException,
session: Session,
test_mode: bool | None = None,
context: Context | None = None,
@@ -2411,7 +2411,7 @@ class TaskInstance(Base, LoggingMixin):
self.handle_failure(e, test_mode, context, force_fail=True,
session=session)
session.commit()
raise
- except AirflowException as e:
+ except (AirflowTaskTimeout, AirflowException) as e:
if not test_mode:
self.refresh_from_db(lock_for_update=True, session=session)
# for case when task is marked as success/failed externally
@@ -2426,10 +2426,6 @@ class TaskInstance(Base, LoggingMixin):
self.handle_failure(e, test_mode, context, session=session)
session.commit()
raise
- except (Exception, KeyboardInterrupt) as e:
- self.handle_failure(e, test_mode, context, session=session)
- session.commit()
- raise
except SystemExit as e:
# We have already handled SystemExit with success codes (0 and
None) in the `_execute_task`.
# Therefore, here we must handle only error codes.
@@ -2437,6 +2433,10 @@ class TaskInstance(Base, LoggingMixin):
self.handle_failure(msg, test_mode, context, session=session)
session.commit()
raise Exception(msg)
+ except BaseException as e:
+ self.handle_failure(e, test_mode, context, session=session)
+ session.commit()
+ raise
finally:
Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}",
tags=self.stats_tags)
# Same metric with tagging
@@ -2743,7 +2743,7 @@ class TaskInstance(Base, LoggingMixin):
def fetch_handle_failure_context(
cls,
ti: TaskInstance | TaskInstancePydantic,
- error: None | str | Exception | KeyboardInterrupt,
+ error: None | str | BaseException,
test_mode: bool | None = None,
context: Context | None = None,
force_fail: bool = False,
@@ -2838,7 +2838,7 @@ class TaskInstance(Base, LoggingMixin):
@provide_session
def handle_failure(
self,
- error: None | str | Exception | KeyboardInterrupt,
+ error: None | str | BaseException,
test_mode: bool | None = None,
context: Context | None = None,
force_fail: bool = False,
diff --git a/airflow/providers/celery/executors/celery_executor_utils.py
b/airflow/providers/celery/executors/celery_executor_utils.py
index 292bbc0c70..bd1725e6d3 100644
--- a/airflow/providers/celery/executors/celery_executor_utils.py
+++ b/airflow/providers/celery/executors/celery_executor_utils.py
@@ -41,7 +41,7 @@ from sqlalchemy import select
import airflow.settings as settings
from airflow.configuration import conf
-from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning, AirflowTaskTimeout
from airflow.executors.base_executor import BaseExecutor
from airflow.stats import Stats
from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
@@ -198,7 +198,7 @@ class ExceptionWithTraceback:
:param exception_traceback: The stacktrace to wrap
"""
- def __init__(self, exception: Exception, exception_traceback: str):
+ def __init__(self, exception: BaseException, exception_traceback: str):
self.exception = exception
self.traceback = exception_traceback
@@ -211,7 +211,7 @@ def send_task_to_executor(
try:
with timeout(seconds=OPERATION_TIMEOUT):
result = task_to_run.apply_async(args=[command], queue=queue)
- except Exception as e:
+ except (Exception, AirflowTaskTimeout) as e:
exception_traceback = f"Celery Task ID:
{key}\n{traceback.format_exc()}"
result = ExceptionWithTraceback(e, exception_traceback)
diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi
index 256823dd0b..9fecccfb1d 100644
--- a/airflow/utils/context.pyi
+++ b/airflow/utils/context.pyi
@@ -65,7 +65,7 @@ class Context(TypedDict, total=False):
data_interval_start: DateTime
ds: str
ds_nodash: str
- exception: KeyboardInterrupt | Exception | str | None
+ exception: BaseException | str | None
execution_date: DateTime
expanded_ti_count: int | None
inlets: list
diff --git a/newsfragments/35653.significant.rst
b/newsfragments/35653.significant.rst
new file mode 100644
index 0000000000..ea93c83343
--- /dev/null
+++ b/newsfragments/35653.significant.rst
@@ -0,0 +1,21 @@
+``AirflowTaskTimeout`` is no longer ``except``ed by default through
``Exception``
+
+The ``AirflowTaskTimeout`` is now inheriting ``BaseException`` instead of
+``AirflowException``->``Exception``.
+See https://docs.python.org/3/library/exceptions.html#exception-hierarchy
+
+This prevents code catching ``Exception`` from accidentally
+catching ``AirflowTaskTimeout`` and continuing to run.
+``AirflowTaskTimeout`` is an explicit intent to cancel the task, and should
not
+be caught in attempts to handle the error and return some default value.
+
+Catching ``AirflowTaskTimeout`` is still possible by explicitly ``except``ing
+``AirflowTaskTimeout`` or ``BaseException``.
+This is discouraged, as it may allow the code to continue running even after
+such cancellation requests.
+Code that previously depended on performing strict cleanup in every situation
+after catching ``Exception`` is advised to use ``finally`` blocks or
+context managers to perform only the cleanup and then automatically
+re-raise the exception.
+See similar considerations about catching ``KeyboardInterrupt`` in
+https://docs.python.org/3/library/exceptions.html#KeyboardInterrupt
diff --git a/tests/core/test_core.py b/tests/core/test_core.py
index 5f37cb2db0..c687a352bd 100644
--- a/tests/core/test_core.py
+++ b/tests/core/test_core.py
@@ -71,11 +71,18 @@ class TestCore:
op.dry_run()
def test_timeout(self, dag_maker):
+ def sleep_and_catch_other_exceptions():
+ try:
+ sleep(5)
+ # Catching Exception should NOT catch AirflowTaskTimeout
+ except Exception:
+ pass
+
with dag_maker():
op = PythonOperator(
task_id="test_timeout",
execution_timeout=timedelta(seconds=1),
- python_callable=lambda: sleep(5),
+ python_callable=sleep_and_catch_other_exceptions,
)
dag_maker.create_dagrun()
with pytest.raises(AirflowTaskTimeout):
diff --git a/tests/providers/microsoft/azure/hooks/test_synapse.py
b/tests/providers/microsoft/azure/hooks/test_synapse.py
index d66268798d..9b116cd054 100644
--- a/tests/providers/microsoft/azure/hooks/test_synapse.py
+++ b/tests/providers/microsoft/azure/hooks/test_synapse.py
@@ -21,6 +21,7 @@ from unittest.mock import MagicMock, patch
import pytest
from azure.synapse.spark import SparkClient
+from airflow.exceptions import AirflowTaskTimeout
from airflow.models.connection import Connection
from airflow.providers.microsoft.azure.hooks.synapse import AzureSynapseHook,
AzureSynapseSparkBatchRunStatus
@@ -172,7 +173,7 @@ def test_wait_for_job_run_status(hook, job_run_status,
expected_status, expected
if expected_output != "timeout":
assert hook.wait_for_job_run_status(**config) == expected_output
else:
- with pytest.raises(Exception):
+ with pytest.raises(AirflowTaskTimeout):
hook.wait_for_job_run_status(**config)