This is an automated email from the ASF dual-hosted git repository. dstandish pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new b459af3ee0 Replace State usages with strong-typed enums (#31735) b459af3ee0 is described below commit b459af3ee0f94cd246e3d401ca3eec18ffd85db0 Author: Tzu-ping Chung <uranu...@gmail.com> AuthorDate: Sat Jun 17 04:48:57 2023 +0800 Replace State usages with strong-typed enums (#31735) Only in the main Airflow code base. There are many more in tests that I might tackle some day. Additionally, there are some cases where TI state is used for "job" state. We may deal with this later by introducing a new type ExecutorState. --- airflow/api/common/delete_dag.py | 4 +- airflow/api/common/mark_tasks.py | 30 +++++++++--- .../endpoints/task_instance_endpoint.py | 4 +- airflow/api_connexion/schemas/enum_schemas.py | 4 +- airflow/dag_processing/processor.py | 16 +++++-- airflow/executors/base_executor.py | 10 ++-- airflow/executors/celery_executor.py | 8 ++-- airflow/executors/celery_executor_utils.py | 8 ++-- airflow/executors/debug_executor.py | 28 +++++------ airflow/executors/kubernetes_executor.py | 24 ++++++---- airflow/executors/local_executor.py | 16 +++---- airflow/executors/sequential_executor.py | 6 +-- airflow/jobs/backfill_job_runner.py | 6 +-- airflow/jobs/job.py | 7 ++- airflow/jobs/local_task_job_runner.py | 6 +-- airflow/jobs/scheduler_job_runner.py | 28 +++++------ airflow/listeners/spec/taskinstance.py | 6 +-- airflow/models/dag.py | 18 ++++---- airflow/models/dagrun.py | 42 ++++++++--------- airflow/models/pool.py | 8 ++-- airflow/models/skipmixin.py | 4 +- airflow/models/taskinstance.py | 54 +++++++++++----------- airflow/operators/subdag.py | 20 ++++---- airflow/operators/trigger_dagrun.py | 6 +-- airflow/sensors/external_task.py | 28 +++++------ airflow/sentry.py | 4 +- airflow/ti_deps/dependencies_states.py | 30 ++++++------ airflow/ti_deps/deps/dagrun_exists_dep.py | 4 +- airflow/ti_deps/deps/not_in_retry_period_dep.py | 4 +- airflow/ti_deps/deps/not_previously_skipped_dep.py | 4 +- airflow/ti_deps/deps/prev_dagrun_dep.py | 4 +- airflow/ti_deps/deps/ready_to_reschedule.py | 4 +- airflow/ti_deps/deps/task_not_running_dep.py | 4 +- airflow/utils/dot_renderer.py | 2 +- airflow/utils/log/file_task_handler.py | 5 +- airflow/utils/log/log_reader.py | 23 +++++---- airflow/utils/state.py | 9 +--- airflow/www/utils.py | 10 ++-- airflow/www/views.py | 26 +++++------ 39 files changed, 280 insertions(+), 244 deletions(-) diff --git a/airflow/api/common/delete_dag.py b/airflow/api/common/delete_dag.py index 45611729ea..1d879a667a 100644 --- a/airflow/api/common/delete_dag.py +++ b/airflow/api/common/delete_dag.py @@ -29,7 +29,7 @@ from airflow.models import DagModel, TaskFail from airflow.models.serialized_dag import SerializedDagModel from airflow.utils.db import get_sqla_model_classes from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState log = logging.getLogger(__name__) @@ -50,7 +50,7 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session = running_tis = session.scalar( select(models.TaskInstance.state) .where(models.TaskInstance.dag_id == dag_id) - .where(models.TaskInstance.state == State.RUNNING) + .where(models.TaskInstance.state == TaskInstanceState.RUNNING) .limit(1) ) if running_tis: diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py index b965237bdf..184251b515 100644 --- a/airflow/api/common/mark_tasks.py +++ b/airflow/api/common/mark_tasks.py @@ -155,7 +155,7 @@ def set_state( for task_instance in tis_altered: # The try_number was decremented when setting to up_for_reschedule and deferred. # Increment it back when changing the state again - if task_instance.state in [State.DEFERRED, State.UP_FOR_RESCHEDULE]: + if task_instance.state in (TaskInstanceState.DEFERRED, TaskInstanceState.UP_FOR_RESCHEDULE): task_instance._try_number += 1 task_instance.set_state(state, session=session) session.flush() @@ -362,7 +362,7 @@ def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SA select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id) ).scalar_one() dag_run.state = state - if state == State.RUNNING: + if state == DagRunState.RUNNING: dag_run.start_date = timezone.utcnow() dag_run.end_date = None else: @@ -415,7 +415,13 @@ def set_dag_run_state_to_success( # Mark all task instances of the dag run to success. for task in dag.tasks: task.dag = dag - return set_state(tasks=dag.tasks, run_id=run_id, state=State.SUCCESS, commit=commit, session=session) + return set_state( + tasks=dag.tasks, + run_id=run_id, + state=TaskInstanceState.SUCCESS, + commit=commit, + session=session, + ) @provide_session @@ -468,7 +474,9 @@ def set_dag_run_state_to_failed( TaskInstance.dag_id == dag.dag_id, TaskInstance.run_id == run_id, TaskInstance.task_id.in_(task_ids), - TaskInstance.state.in_([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]), + TaskInstance.state.in_( + (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED, TaskInstanceState.UP_FOR_RESCHEDULE), + ), ) ) @@ -487,16 +495,24 @@ def set_dag_run_state_to_failed( TaskInstance.dag_id == dag.dag_id, TaskInstance.run_id == run_id, TaskInstance.state.not_in(State.finished), - TaskInstance.state.not_in([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]), + TaskInstance.state.not_in( + (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED, TaskInstanceState.UP_FOR_RESCHEDULE), + ), ) ) tis = [ti for ti in tis] if commit: for ti in tis: - ti.set_state(State.SKIPPED) + ti.set_state(TaskInstanceState.SKIPPED) - return tis + set_state(tasks=tasks, run_id=run_id, state=State.FAILED, commit=commit, session=session) + return tis + set_state( + tasks=tasks, + run_id=run_id, + state=TaskInstanceState.FAILED, + commit=commit, + session=session, + ) def __set_dag_run_state_to_running_or_queued( diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index 533d97c858..7ccc6b8848 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -50,7 +50,7 @@ from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.state import DagRunState, State +from airflow.utils.state import DagRunState T = TypeVar("T") @@ -264,7 +264,7 @@ def get_mapped_task_instances( def _convert_state(states: Iterable[str] | None) -> list[str | None] | None: if not states: return None - return [State.NONE if s == "none" else s for s in states] + return [None if s == "none" else s for s in states] def _apply_array_filter(query: Query, key: ClauseElement, values: Iterable[Any] | None) -> Query: diff --git a/airflow/api_connexion/schemas/enum_schemas.py b/airflow/api_connexion/schemas/enum_schemas.py index 981a3669b1..ba82010783 100644 --- a/airflow/api_connexion/schemas/enum_schemas.py +++ b/airflow/api_connexion/schemas/enum_schemas.py @@ -26,7 +26,7 @@ class DagStateField(fields.String): def __init__(self, **metadata): super().__init__(**metadata) - self.validators = [validate.OneOf(State.dag_states)] + list(self.validators) + self.validators = [validate.OneOf(State.dag_states), *self.validators] class TaskInstanceStateField(fields.String): @@ -34,4 +34,4 @@ class TaskInstanceStateField(fields.String): def __init__(self, **metadata): super().__init__(**metadata) - self.validators = [validate.OneOf(State.task_states)] + list(self.validators) + self.validators = [validate.OneOf(State.task_states), *self.validators] diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 9ee865bbf8..d742c5314a 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -56,7 +56,7 @@ from airflow.utils.file import iter_airflow_imports, might_contain_dag from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context from airflow.utils.mixins import MultiprocessingStartMethodMixin from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: from airflow.models.operator import Operator @@ -432,9 +432,11 @@ class DagFileProcessor(LoggingMixin): qry = ( session.query(TI.task_id, func.max(DR.execution_date).label("max_ti")) .join(TI.dag_run) - .filter(TI.dag_id == dag.dag_id) - .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED)) - .filter(TI.task_id.in_(dag.task_ids)) + .filter( + TI.dag_id == dag.dag_id, + or_(TI.state == TaskInstanceState.SUCCESS, TI.state == TaskInstanceState.SKIPPED), + TI.task_id.in_(dag.task_ids), + ) .group_by(TI.task_id) .subquery("sq") ) @@ -500,7 +502,11 @@ class DagFileProcessor(LoggingMixin): sla_dates: list[datetime] = [sla.execution_date for sla in slas] fetched_tis: list[TI] = ( session.query(TI) - .filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id) + .filter( + TI.state != TaskInstanceState.SUCCESS, + TI.execution_date.in_(sla_dates), + TI.dag_id == dag.dag_id, + ) .all() ) blocking_tis: list[TI] = [] diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index 9599adabdf..6ae80dced1 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -31,7 +31,7 @@ from airflow.configuration import conf from airflow.exceptions import RemovedInAirflow3Warning from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState PARALLELISM: int = conf.getint("core", "PARALLELISM") @@ -54,7 +54,7 @@ if TYPE_CHECKING: # Event_buffer dict value type # Tuple of: state, info - EventBufferValueType = Tuple[Optional[str], Any] + EventBufferValueType = Tuple[Optional[TaskInstanceState], Any] # Task tuple to send to be executed TaskTuple = Tuple[TaskInstanceKey, CommandType, Optional[str], Optional[Any]] @@ -298,7 +298,7 @@ class BaseExecutor(LoggingMixin): self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) self.running.add(key) - def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None: + def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None: """ Changes state of the task. @@ -320,7 +320,7 @@ class BaseExecutor(LoggingMixin): :param info: Executor information for the task instance :param key: Unique key for the task instance """ - self.change_state(key, State.FAILED, info) + self.change_state(key, TaskInstanceState.FAILED, info) def success(self, key: TaskInstanceKey, info=None) -> None: """ @@ -329,7 +329,7 @@ class BaseExecutor(LoggingMixin): :param info: Executor information for the task instance :param key: Unique key for the task instance """ - self.change_state(key, State.SUCCESS, info) + self.change_state(key, TaskInstanceState.SUCCESS, info) def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]: """ diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index de59804a04..c9f4ded309 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -38,7 +38,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowTaskTimeout from airflow.executors.base_executor import BaseExecutor from airflow.stats import Stats -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState log = logging.getLogger(__name__) @@ -150,7 +150,7 @@ class CeleryExecutor(BaseExecutor): self.task_publish_retries.pop(key, None) if isinstance(result, ExceptionWithTraceback): self.log.error(CELERY_SEND_ERR_MSG_HEADER + ": %s\n%s\n", result.exception, result.traceback) - self.event_buffer[key] = (State.FAILED, None) + self.event_buffer[key] = (TaskInstanceState.FAILED, None) elif result is not None: result.backend = cached_celery_backend self.running.add(key) @@ -159,7 +159,7 @@ class CeleryExecutor(BaseExecutor): # Store the Celery task_id in the event buffer. This will get "overwritten" if the task # has another event, but that is fine, because the only other events are success/failed at # which point we don't need the ID anymore anyway - self.event_buffer[key] = (State.QUEUED, result.task_id) + self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id) # If the task runs _really quickly_ we may already have a result! self.update_task_state(key, result.state, getattr(result, "info", None)) @@ -206,7 +206,7 @@ class CeleryExecutor(BaseExecutor): if state: self.update_task_state(key, state, info) - def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None: + def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None: super().change_state(key, state, info) self.tasks.pop(key, None) diff --git a/airflow/executors/celery_executor_utils.py b/airflow/executors/celery_executor_utils.py index 2c8af4cf91..80460c3c8a 100644 --- a/airflow/executors/celery_executor_utils.py +++ b/airflow/executors/celery_executor_utils.py @@ -46,6 +46,7 @@ from airflow.stats import Stats from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname +from airflow.utils.state import TaskInstanceState from airflow.utils.timeout import timeout log = logging.getLogger(__name__) @@ -192,9 +193,10 @@ def send_task_to_executor( return key, command, result -def fetch_celery_task_state(async_result: AsyncResult) -> tuple[str, str | ExceptionWithTraceback, Any]: - """ - Fetch and return the state of the given celery task. +def fetch_celery_task_state( + async_result: AsyncResult, +) -> tuple[str, TaskInstanceState | ExceptionWithTraceback, Any]: + """Fetch and return the state of the given celery task. The scope of this function is global so that it can be called by subprocesses in the pool. diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py index ca23b09a67..8a46d6cda0 100644 --- a/airflow/executors/debug_executor.py +++ b/airflow/executors/debug_executor.py @@ -29,7 +29,7 @@ import time from typing import TYPE_CHECKING, Any from airflow.executors.base_executor import BaseExecutor -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstance @@ -68,15 +68,15 @@ class DebugExecutor(BaseExecutor): while self.tasks_to_run: ti = self.tasks_to_run.pop(0) if self.fail_fast and not task_succeeded: - self.log.info("Setting %s to %s", ti.key, State.UPSTREAM_FAILED) - ti.set_state(State.UPSTREAM_FAILED) - self.change_state(ti.key, State.UPSTREAM_FAILED) + self.log.info("Setting %s to %s", ti.key, TaskInstanceState.UPSTREAM_FAILED) + ti.set_state(TaskInstanceState.UPSTREAM_FAILED) + self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED) continue if self._terminated.is_set(): - self.log.info("Executor is terminated! Stopping %s to %s", ti.key, State.FAILED) - ti.set_state(State.FAILED) - self.change_state(ti.key, State.FAILED) + self.log.info("Executor is terminated! Stopping %s to %s", ti.key, TaskInstanceState.FAILED) + ti.set_state(TaskInstanceState.FAILED) + self.change_state(ti.key, TaskInstanceState.FAILED) continue task_succeeded = self._run_task(ti) @@ -87,11 +87,11 @@ class DebugExecutor(BaseExecutor): try: params = self.tasks_params.pop(ti.key, {}) ti.run(job_id=ti.job_id, **params) - self.change_state(key, State.SUCCESS) + self.change_state(key, TaskInstanceState.SUCCESS) return True except Exception as e: - ti.set_state(State.FAILED) - self.change_state(key, State.FAILED) + ti.set_state(TaskInstanceState.FAILED) + self.change_state(key, TaskInstanceState.FAILED) self.log.exception("Failed to execute task: %s.", str(e)) return False @@ -148,14 +148,14 @@ class DebugExecutor(BaseExecutor): def end(self) -> None: """Set states of queued tasks to UPSTREAM_FAILED marking them as not executed.""" for ti in self.tasks_to_run: - self.log.info("Setting %s to %s", ti.key, State.UPSTREAM_FAILED) - ti.set_state(State.UPSTREAM_FAILED) - self.change_state(ti.key, State.UPSTREAM_FAILED) + self.log.info("Setting %s to %s", ti.key, TaskInstanceState.UPSTREAM_FAILED) + ti.set_state(TaskInstanceState.UPSTREAM_FAILED) + self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED) def terminate(self) -> None: self._terminated.set() - def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None: + def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None: self.log.debug("Popping %s from executor task queue.", key) self.running.remove(key) self.event_buffer[key] = state, info diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index 8707bf9699..092a84f470 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -53,7 +53,7 @@ from airflow.kubernetes.pod_generator import PodGenerator from airflow.utils.event_scheduler import EventScheduler from airflow.utils.log.logging_mixin import LoggingMixin, remove_escape_codes from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.state import State, TaskInstanceState +from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: from airflow.executors.base_executor import CommandType @@ -228,12 +228,16 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): # since kube server have received request to delete pod set TI state failed if event["type"] == "DELETED" and pod.metadata.deletion_timestamp: self.log.info("Event: Failed to start pod %s, annotations: %s", pod_name, annotations_string) - self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version)) + self.watcher_queue.put( + (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version), + ) else: self.log.debug("Event: %s Pending, annotations: %s", pod_name, annotations_string) elif status == "Failed": self.log.error("Event: %s Failed, annotations: %s", pod_name, annotations_string) - self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version)) + self.watcher_queue.put( + (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version), + ) elif status == "Succeeded": # We get multiple events once the pod hits a terminal state, and we only want to # send it along to the scheduler once. @@ -261,7 +265,9 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): pod_name, annotations_string, ) - self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version)) + self.watcher_queue.put( + (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version), + ) else: self.log.info("Event: %s is Running, annotations: %s", pod_name, annotations_string) else: @@ -700,7 +706,7 @@ class KubernetesExecutor(BaseExecutor): last_resource_version[namespace] = resource_version self.log.info("Changing state of %s to %s", results, state) try: - self._change_state(key, state, pod_name, namespace) + self._change_state(key, TaskInstanceState(state), pod_name, namespace) except Exception as e: self.log.exception( "Exception: %s when attempting to change state of %s to %s, re-queueing.", @@ -767,7 +773,7 @@ class KubernetesExecutor(BaseExecutor): def _change_state( self, key: TaskInstanceKey, - state: str | None, + state: TaskInstanceState | None, pod_name: str, namespace: str, session: Session = NEW_SESSION, @@ -776,12 +782,12 @@ class KubernetesExecutor(BaseExecutor): assert self.kube_scheduler from airflow.models.taskinstance import TaskInstance - if state == State.RUNNING: + if state == TaskInstanceState.RUNNING: self.event_buffer[key] = state, None return if self.kube_config.delete_worker_pods: - if state != State.FAILED or self.kube_config.delete_worker_pods_on_failure: + if state != TaskInstanceState.FAILED or self.kube_config.delete_worker_pods_on_failure: self.kube_scheduler.delete_pod(pod_name=pod_name, namespace=namespace) self.log.info("Deleted pod: %s in namespace %s", str(key), str(namespace)) else: @@ -1011,7 +1017,7 @@ class KubernetesExecutor(BaseExecutor): "Changing state of %s to %s : resource_version=%d", results, state, resource_version ) try: - self._change_state(key, state, pod_name, namespace) + self._change_state(key, TaskInstanceState(state), pod_name, namespace) except Exception as e: self.log.exception( "Ignoring exception: %s when attempting to change state of %s to %s.", diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py index 715bdb42ae..550a6519e1 100644 --- a/airflow/executors/local_executor.py +++ b/airflow/executors/local_executor.py @@ -39,7 +39,7 @@ from airflow import settings from airflow.exceptions import AirflowException from airflow.executors.base_executor import PARALLELISM, BaseExecutor from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: from airflow.executors.base_executor import CommandType @@ -94,20 +94,20 @@ class LocalWorkerBase(Process, LoggingMixin): # Remove the command since the worker is done executing the task setproctitle("airflow worker -- LocalExecutor") - def _execute_work_in_subprocess(self, command: CommandType) -> str: + def _execute_work_in_subprocess(self, command: CommandType) -> TaskInstanceState: try: subprocess.check_call(command, close_fds=True) - return State.SUCCESS + return TaskInstanceState.SUCCESS except subprocess.CalledProcessError as e: self.log.error("Failed to execute task %s.", str(e)) - return State.FAILED + return TaskInstanceState.FAILED - def _execute_work_in_fork(self, command: CommandType) -> str: + def _execute_work_in_fork(self, command: CommandType) -> TaskInstanceState: pid = os.fork() if pid: # In parent, wait for the child pid, ret = os.waitpid(pid, 0) - return State.SUCCESS if ret == 0 else State.FAILED + return TaskInstanceState.SUCCESS if ret == 0 else TaskInstanceState.FAILED from airflow.sentry import Sentry @@ -130,10 +130,10 @@ class LocalWorkerBase(Process, LoggingMixin): args.func(args) ret = 0 - return State.SUCCESS + return TaskInstanceState.SUCCESS except Exception as e: self.log.exception("Failed to execute task %s.", e) - return State.FAILED + return TaskInstanceState.FAILED finally: Sentry.flush() logging.shutdown() diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py index 28f88c6b87..2715edad6e 100644 --- a/airflow/executors/sequential_executor.py +++ b/airflow/executors/sequential_executor.py @@ -28,7 +28,7 @@ import subprocess from typing import TYPE_CHECKING, Any from airflow.executors.base_executor import BaseExecutor -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: from airflow.executors.base_executor import CommandType @@ -75,9 +75,9 @@ class SequentialExecutor(BaseExecutor): try: subprocess.check_call(command, close_fds=True) - self.change_state(key, State.SUCCESS) + self.change_state(key, TaskInstanceState.SUCCESS) except subprocess.CalledProcessError as e: - self.change_state(key, State.FAILED) + self.change_state(key, TaskInstanceState.FAILED) self.log.error("Failed to execute task %s.", str(e)) self.commands_to_run = [] diff --git a/airflow/jobs/backfill_job_runner.py b/airflow/jobs/backfill_job_runner.py index 40c31c4451..a0efcb17af 100644 --- a/airflow/jobs/backfill_job_runner.py +++ b/airflow/jobs/backfill_job_runner.py @@ -68,7 +68,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin): job_type = "BackfillJob" - STATES_COUNT_AS_RUNNING = (State.RUNNING, State.QUEUED) + STATES_COUNT_AS_RUNNING = (TaskInstanceState.RUNNING, TaskInstanceState.QUEUED) @attr.define class _DagRunTaskStatus: @@ -219,7 +219,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin): # is changed externally, e.g. by clearing tasks from the ui. We need to cover # for that as otherwise those tasks would fall outside the scope of # the backfill suddenly. - elif ti.state == State.NONE: + elif ti.state is None: self.log.warning( "FIXME: task instance %s state was set to none externally or " "reaching concurrency limits. Re-adding task to queue.", @@ -1004,7 +1004,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin): ).all() for ti in reset_tis: - ti.state = State.NONE + ti.state = None session.merge(ti) return result + reset_tis diff --git a/airflow/jobs/job.py b/airflow/jobs/job.py index fe75508f63..903cfb9aca 100644 --- a/airflow/jobs/job.py +++ b/airflow/jobs/job.py @@ -19,7 +19,7 @@ from __future__ import annotations from functools import cached_property from time import sleep -from typing import Callable, NoReturn +from typing import TYPE_CHECKING, Callable, NoReturn from sqlalchemy import Column, Index, Integer, String, case, select from sqlalchemy.exc import OperationalError @@ -42,6 +42,9 @@ from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.sqlalchemy import UtcDateTime from airflow.utils.state import State +if TYPE_CHECKING: + from airflow.executors.base_executor import BaseExecutor + def _resolve_dagrun_model(): from airflow.models.dagrun import DagRun @@ -117,7 +120,7 @@ class Job(Base, LoggingMixin): super().__init__(**kwargs) @cached_property - def executor(self): + def executor(self) -> BaseExecutor: return ExecutorLoader.get_default_executor() def is_alive(self, grace_multiplier=2.1): diff --git a/airflow/jobs/local_task_job_runner.py b/airflow/jobs/local_task_job_runner.py index 9f6a4b55e8..fd234e4150 100644 --- a/airflow/jobs/local_task_job_runner.py +++ b/airflow/jobs/local_task_job_runner.py @@ -35,7 +35,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname from airflow.utils.platform import IS_WINDOWS from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState SIGSEGV_MESSAGE = """ ******************************************* Received SIGSEGV ******************************************* @@ -243,7 +243,7 @@ class LocalTaskJobRunner(BaseJobRunner["Job | JobPydantic"], LoggingMixin): self.task_instance.refresh_from_db() ti = self.task_instance - if ti.state == State.RUNNING: + if ti.state == TaskInstanceState.RUNNING: fqdn = get_hostname() same_hostname = fqdn == ti.hostname if not same_hostname: @@ -273,7 +273,7 @@ class LocalTaskJobRunner(BaseJobRunner["Job | JobPydantic"], LoggingMixin): ) raise AirflowException("PID of job runner does not match") elif self.task_runner.return_code() is None and hasattr(self.task_runner, "process"): - if ti.state == State.SKIPPED: + if ti.state == TaskInstanceState.SKIPPED: # A DagRun timeout will cause tasks to be externally marked as skipped. dagrun = ti.get_dagrun(session=session) execution_time = (dagrun.end_date or timezone.utcnow()) - dagrun.start_date diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index a85a0925fb..16764b1980 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -610,7 +610,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin): ) for ti in executable_tis: - ti.emit_state_change_metric(State.QUEUED) + ti.emit_state_change_metric(TaskInstanceState.QUEUED) for ti in executable_tis: make_transient(ti) @@ -626,7 +626,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin): # actually enqueue them for ti in task_instances: if ti.dag_run.state in State.finished: - ti.set_state(State.NONE, session=session) + ti.set_state(None, session=session) continue command = ti.command_as_list( local=True, @@ -682,13 +682,13 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin): # Report execution for ti_key, value in event_buffer.items(): - state: str + state: TaskInstanceState | None state, _ = value # We create map (dag_id, task_id, execution_date) -> in-memory try_number ti_primary_key_to_try_number_map[ti_key.primary] = ti_key.try_number self.log.info("Received executor event with state %s for task instance %s", state, ti_key) - if state in (State.FAILED, State.SUCCESS, State.QUEUED): + if state in (TaskInstanceState.FAILED, TaskInstanceState.SUCCESS, TaskInstanceState.QUEUED): tis_with_right_state.append(ti_key) # Return if no finished tasks @@ -712,7 +712,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin): buffer_key = ti.key.with_try_number(try_number) state, info = event_buffer.pop(buffer_key) - if state == State.QUEUED: + if state == TaskInstanceState.QUEUED: ti.external_executor_id = info self.log.info("Setting external_id for %s to %s", ti, info) continue @@ -1535,7 +1535,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin): tasks_stuck_in_queued = session.scalars( select(TI).where( - TI.state == State.QUEUED, + TI.state == TaskInstanceState.QUEUED, TI.queued_dttm < (timezone.utcnow() - timedelta(seconds=self._task_queued_timeout)), TI.queued_by_job_id == self.job.id, ) @@ -1602,20 +1602,19 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin): self.log.info("Marked %d SchedulerJob instances as failed", num_failed) Stats.incr(self.__class__.__name__.lower() + "_end", num_failed) - resettable_states = [TaskInstanceState.QUEUED, TaskInstanceState.RUNNING] query = ( select(TI) - .where(TI.state.in_(resettable_states)) + .where(TI.state.in_((TaskInstanceState.QUEUED, TaskInstanceState.RUNNING))) # outerjoin is because we didn't use to have queued_by_job # set, so we need to pick up anything pre upgrade. This (and the # "or queued_by_job_id IS NONE") can go as soon as scheduler HA is # released. .outerjoin(TI.queued_by_job) - .where(or_(TI.queued_by_job_id.is_(None), Job.state != State.RUNNING)) + .where(or_(TI.queued_by_job_id.is_(None), Job.state != TaskInstanceState.RUNNING)) .join(TI.dag_run) .where( DagRun.run_type != DagRunType.BACKFILL_JOB, - DagRun.state == State.RUNNING, + DagRun.state == DagRunState.RUNNING, ) .options(load_only(TI.dag_id, TI.task_id, TI.run_id)) ) @@ -1630,7 +1629,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin): reset_tis_message = [] for ti in to_reset: reset_tis_message.append(repr(ti)) - ti.state = State.NONE + ti.state = None ti.queued_by_job_id = None for ti in set(tis_to_reset_or_adopt) - set(to_reset): @@ -1697,12 +1696,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin): .join(Job, TI.job_id == Job.id) .join(DM, TI.dag_id == DM.dag_id) .where(TI.state == TaskInstanceState.RUNNING) - .where( - or_( - Job.state != State.RUNNING, - Job.latest_heartbeat < limit_dttm, - ) - ) + .where(or_(Job.state != State.RUNNING, Job.latest_heartbeat < limit_dttm)) .where(Job.job_type == "LocalTaskJob") .where(TI.queued_by_job_id == self.job.id) ) diff --git a/airflow/listeners/spec/taskinstance.py b/airflow/listeners/spec/taskinstance.py index 78de8a5f62..56b4cb7322 100644 --- a/airflow/listeners/spec/taskinstance.py +++ b/airflow/listeners/spec/taskinstance.py @@ -34,18 +34,18 @@ hookspec = HookspecMarker("airflow") def on_task_instance_running( previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session | None ): - """Called when task state changes to RUNNING. Previous_state can be State.NONE.""" + """Called when task state changes to RUNNING. Previous state can be None.""" @hookspec def on_task_instance_success( previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session | None ): - """Called when task state changes to SUCCESS. Previous_state can be State.NONE.""" + """Called when task state changes to SUCCESS. Previous state can be None.""" @hookspec def on_task_instance_failed( previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session | None ): - """Called when task state changes to FAIL. Previous_state can be State.NONE.""" + """Called when task state changes to FAIL. Previous state can be None.""" diff --git a/airflow/models/dag.py b/airflow/models/dag.py index b7361d6f3b..57b15e5f32 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -103,7 +103,7 @@ from airflow.utils.helpers import at_most_one, exactly_one, validate_key from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks -from airflow.utils.state import DagRunState, State, TaskInstanceState +from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType if TYPE_CHECKING: @@ -1281,7 +1281,7 @@ class DAG(LoggingMixin): TI = TaskInstance qry = session.query(func.count(TI.task_id)).filter( TI.dag_id == self.dag_id, - TI.state == State.RUNNING, + TI.state == TaskInstanceState.RUNNING, ) return qry.scalar() >= self.max_active_tasks @@ -1368,7 +1368,7 @@ class DAG(LoggingMixin): :return: List of execution dates """ - runs = DagRun.find(dag_id=self.dag_id, state=State.RUNNING) + runs = DagRun.find(dag_id=self.dag_id, state=DagRunState.RUNNING) active_dates = [] for run in runs: @@ -1388,9 +1388,9 @@ class DAG(LoggingMixin): # .count() is inefficient query = session.query(func.count()).filter(DagRun.dag_id == self.dag_id) if only_running: - query = query.filter(DagRun.state == State.RUNNING) + query = query.filter(DagRun.state == DagRunState.RUNNING) else: - query = query.filter(DagRun.state.in_({State.RUNNING, State.QUEUED})) + query = query.filter(DagRun.state.in_({DagRunState.RUNNING, DagRunState.QUEUED})) if external_trigger is not None: query = query.filter( @@ -2077,7 +2077,7 @@ class DAG(LoggingMixin): @provide_session def set_dag_runs_state( self, - state: str = State.RUNNING, + state: DagRunState = DagRunState.RUNNING, session: Session = NEW_SESSION, start_date: datetime | None = None, end_date: datetime | None = None, @@ -2160,10 +2160,10 @@ class DAG(LoggingMixin): state = [] if only_failed: - state += [State.FAILED, State.UPSTREAM_FAILED] + state += [TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED] if only_running: # Yes, having `+=` doesn't make sense, but this was the existing behaviour - state += [State.RUNNING] + state += [TaskInstanceState.RUNNING] tis = self._get_task_instances( task_ids=task_ids, @@ -2694,7 +2694,7 @@ class DAG(LoggingMixin): # Instead of starting a scheduler, we run the minimal loop possible to check # for task readiness and dependency management. This is notably faster # than creating a BackfillJob and allows us to surface logs to the user - while dr.state == State.RUNNING: + while dr.state == DagRunState.RUNNING: schedulable_tis, _ = dr.update_state(session=session) try: for ti in schedulable_tis: diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 45989be15a..bd94977599 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -111,7 +111,7 @@ class DagRun(Base, LoggingMixin): execution_date = Column(UtcDateTime, default=timezone.utcnow, nullable=False) start_date = Column(UtcDateTime) end_date = Column(UtcDateTime) - _state = Column("state", String(50), default=State.QUEUED) + _state = Column("state", String(50), default=DagRunState.QUEUED) run_id = Column(StringID(), nullable=False) creating_job_id = Column(Integer) external_trigger = Column(Boolean, default=True) @@ -218,7 +218,7 @@ class DagRun(Base, LoggingMixin): if state is not None: self.state = state if queued_at is NOTSET: - self.queued_at = timezone.utcnow() if state == State.QUEUED else None + self.queued_at = timezone.utcnow() if state == DagRunState.QUEUED else None else: self.queued_at = queued_at self.run_type = run_type @@ -256,7 +256,7 @@ class DagRun(Base, LoggingMixin): if self._state != state: self._state = state self.end_date = timezone.utcnow() if self._state in State.finished else None - if state == State.QUEUED: + if state == DagRunState.QUEUED: self.queued_at = timezone.utcnow() @declared_attr @@ -289,9 +289,9 @@ class DagRun(Base, LoggingMixin): # because SQLAlchemy doesn't accept a set here. query = query.filter(cls.dag_id.in_(set(dag_ids))) if only_running: - query = query.filter(cls.state == State.RUNNING) + query = query.filter(cls.state == DagRunState.RUNNING) else: - query = query.filter(cls.state.in_([State.RUNNING, State.QUEUED])) + query = query.filter(cls.state.in_((DagRunState.RUNNING, DagRunState.QUEUED))) query = query.group_by(cls.dag_id) return {dag_id: count for dag_id, count in query.all()} @@ -323,7 +323,7 @@ class DagRun(Base, LoggingMixin): .join(DagModel, DagModel.dag_id == cls.dag_id) .filter(DagModel.is_paused == false(), DagModel.is_active == true()) ) - if state == State.QUEUED: + if state == DagRunState.QUEUED: # For dag runs in the queued state, we check if they have reached the max_active_runs limit # and if so we drop them running_drs = ( @@ -462,7 +462,7 @@ class DagRun(Base, LoggingMixin): tis = tis.filter(TI.state == state) else: # this is required to deal with NULL values - if State.NONE in state: + if None in state: if all(x is None for x in state): tis = tis.filter(TI.state.is_(None)) else: @@ -734,9 +734,9 @@ class DagRun(Base, LoggingMixin): try: ti.task = dag.get_task(ti.task_id) except TaskNotFound: - if ti.state != State.REMOVED: + if ti.state != TaskInstanceState.REMOVED: self.log.error("Failed to get task for ti %s. Marking it as removed.", ti) - ti.state = State.REMOVED + ti.state = TaskInstanceState.REMOVED session.flush() else: yield ti @@ -945,7 +945,7 @@ class DagRun(Base, LoggingMixin): self.log.warning("Failed to record first_task_scheduling_delay metric:", exc_info=True) def _emit_duration_stats_for_finished_state(self): - if self.state == State.RUNNING: + if self.state == DagRunState.RUNNING: return if self.start_date is None: self.log.warning("Failed to record duration of %s: start_date is not set.", self) @@ -1017,22 +1017,22 @@ class DagRun(Base, LoggingMixin): try: task = dag.get_task(ti.task_id) - should_restore_task = (task is not None) and ti.state == State.REMOVED + should_restore_task = (task is not None) and ti.state == TaskInstanceState.REMOVED if should_restore_task: self.log.info("Restoring task '%s' which was previously removed from DAG '%s'", ti, dag) Stats.incr(f"task_restored_to_dag.{dag.dag_id}", tags=self.stats_tags) # Same metric with tagging Stats.incr("task_restored_to_dag", tags={**self.stats_tags, "dag_id": dag.dag_id}) - ti.state = State.NONE + ti.state = None except AirflowException: - if ti.state == State.REMOVED: + if ti.state == TaskInstanceState.REMOVED: pass # ti has already been removed, just ignore it - elif self.state != State.RUNNING and not dag.partial: + elif self.state != DagRunState.RUNNING and not dag.partial: self.log.warning("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag) Stats.incr(f"task_removed_from_dag.{dag.dag_id}", tags=self.stats_tags) # Same metric with tagging Stats.incr("task_removed_from_dag", tags={**self.stats_tags, "dag_id": dag.dag_id}) - ti.state = State.REMOVED + ti.state = TaskInstanceState.REMOVED continue try: @@ -1049,7 +1049,7 @@ class DagRun(Base, LoggingMixin): self.log.debug( "Removing the unmapped TI '%s' as the mapping can't be resolved yet", ti ) - ti.state = State.REMOVED + ti.state = TaskInstanceState.REMOVED continue # Upstreams finished, check there aren't any extras if ti.map_index >= total_length: @@ -1058,7 +1058,7 @@ class DagRun(Base, LoggingMixin): ti, total_length, ) - ti.state = State.REMOVED + ti.state = TaskInstanceState.REMOVED else: # Check if the number of mapped literals has changed, and we need to mark this TI as removed. if ti.map_index >= num_mapped_tis: @@ -1067,10 +1067,10 @@ class DagRun(Base, LoggingMixin): ti, num_mapped_tis, ) - ti.state = State.REMOVED + ti.state = TaskInstanceState.REMOVED elif ti.map_index < 0: self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti) - ti.state = State.REMOVED + ti.state = TaskInstanceState.REMOVED return task_ids @@ -1342,7 +1342,7 @@ class DagRun(Base, LoggingMixin): TI.run_id == self.run_id, tuple_in_condition((TI.task_id, TI.map_index), schedulable_ti_ids_chunk), ) - .update({TI.state: State.SCHEDULED}, synchronize_session=False) + .update({TI.state: TaskInstanceState.SCHEDULED}, synchronize_session=False) ) # Tasks using EmptyOperator should not be executed, mark them as success @@ -1358,7 +1358,7 @@ class DagRun(Base, LoggingMixin): ) .update( { - TI.state: State.SUCCESS, + TI.state: TaskInstanceState.SUCCESS, TI.start_date: timezone.utcnow(), TI.end_date: timezone.utcnow(), TI.duration: 0, diff --git a/airflow/models/pool.py b/airflow/models/pool.py index d1766d4a0a..cfad662691 100644 --- a/airflow/models/pool.py +++ b/airflow/models/pool.py @@ -28,7 +28,7 @@ from airflow.ti_deps.dependencies_states import EXECUTION_STATES from airflow.typing_compat import TypedDict from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import nowait, with_row_locks -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState class PoolStats(TypedDict): @@ -245,7 +245,7 @@ class Pool(Base): return int( session.query(func.sum(TaskInstance.pool_slots)) .filter(TaskInstance.pool == self.pool) - .filter(TaskInstance.state == State.RUNNING) + .filter(TaskInstance.state == TaskInstanceState.RUNNING) .scalar() or 0 ) @@ -263,7 +263,7 @@ class Pool(Base): return int( session.query(func.sum(TaskInstance.pool_slots)) .filter(TaskInstance.pool == self.pool) - .filter(TaskInstance.state == State.QUEUED) + .filter(TaskInstance.state == TaskInstanceState.QUEUED) .scalar() or 0 ) @@ -281,7 +281,7 @@ class Pool(Base): return int( session.query(func.sum(TaskInstance.pool_slots)) .filter(TaskInstance.pool == self.pool) - .filter(TaskInstance.state == State.SCHEDULED) + .filter(TaskInstance.state == TaskInstanceState.SCHEDULED) .scalar() or 0 ) diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index d75a4a0e4d..73a733bd03 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -26,7 +26,7 @@ from airflow.serialization.pydantic.dag_run import DagRunPydantic from airflow.utils import timezone from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, create_session, provide_session -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: from pendulum import DateTime @@ -72,7 +72,7 @@ class SkipMixin(LoggingMixin): TaskInstance.task_id.in_(d.task_id for d in tasks), ).update( { - TaskInstance.state: State.SKIPPED, + TaskInstance.state: TaskInstanceState.SKIPPED, TaskInstance.start_date: now, TaskInstance.end_date: now, }, diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 28dc168ec3..829c630786 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -591,7 +591,7 @@ class TaskInstance(Base, LoggingMixin): database, in all other cases this will be incremented. """ # This is designed so that task logs end up in the right file. - if self.state == State.RUNNING: + if self.state == TaskInstanceState.RUNNING: return self._try_number return self._try_number + 1 @@ -804,7 +804,7 @@ class TaskInstance(Base, LoggingMixin): :param session: SQLAlchemy ORM Session """ self.log.error("Recording the task instance as FAILED") - self.state = State.FAILED + self.state = TaskInstanceState.FAILED session.merge(self) session.commit() @@ -931,7 +931,7 @@ class TaskInstance(Base, LoggingMixin): self.log.debug("Setting task state for %s to %s", self, state) self.state = state self.start_date = self.start_date or current_time - if self.state in State.finished or self.state == State.UP_FOR_RETRY: + if self.state in State.finished or self.state == TaskInstanceState.UP_FOR_RETRY: self.end_date = self.end_date or current_time self.duration = (self.end_date - self.start_date).total_seconds() session.merge(self) @@ -944,7 +944,7 @@ class TaskInstance(Base, LoggingMixin): has elapsed. """ # is the task still in the retry waiting period? - return self.state == State.UP_FOR_RETRY and not self.ready_for_retry() + return self.state == TaskInstanceState.UP_FOR_RETRY and not self.ready_for_retry() @provide_session def are_dependents_done(self, session: Session = NEW_SESSION) -> bool: @@ -967,7 +967,7 @@ class TaskInstance(Base, LoggingMixin): TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(task.downstream_task_ids), TaskInstance.run_id == self.run_id, - TaskInstance.state.in_([State.SKIPPED, State.SUCCESS]), + TaskInstance.state.in_((TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)), ) count = ti[0][0] return count == len(task.downstream_task_ids) @@ -1204,7 +1204,7 @@ class TaskInstance(Base, LoggingMixin): Checks on whether the task instance is in the right state and timeframe to be retried. """ - return self.state == State.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow() + return self.state == TaskInstanceState.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow() @provide_session def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun: @@ -1273,7 +1273,7 @@ class TaskInstance(Base, LoggingMixin): self.hostname = get_hostname() self.pid = None - if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS: + if not ignore_all_deps and not ignore_ti_state and self.state == TaskInstanceState.SUCCESS: Stats.incr("previously_succeeded", tags=self.stats_tags) if not mark_success: @@ -1301,7 +1301,7 @@ class TaskInstance(Base, LoggingMixin): # start date that is recorded in task_reschedule table # If the task continues after being deferred (next_method is set), use the original start_date self.start_date = self.start_date if self.next_method else timezone.utcnow() - if self.state == State.UP_FOR_RESCHEDULE: + if self.state == TaskInstanceState.UP_FOR_RESCHEDULE: task_reschedule: TR = TR.query_for_task_instance(self, session=session).first() if task_reschedule: self.start_date = task_reschedule.start_date @@ -1319,7 +1319,7 @@ class TaskInstance(Base, LoggingMixin): description="requeueable deps", ) if not self.are_dependencies_met(dep_context=dep_context, session=session, verbose=True): - self.state = State.NONE + self.state = None self.log.warning( "Rescheduling due to concurrency limits reached " "at task runtime. Attempt %s of " @@ -1339,10 +1339,10 @@ class TaskInstance(Base, LoggingMixin): self._try_number += 1 if not test_mode: - session.add(Log(State.RUNNING, self)) + session.add(Log(TaskInstanceState.RUNNING, self)) - self.state = State.RUNNING - self.emit_state_change_metric(State.RUNNING) + self.state = TaskInstanceState.RUNNING + self.emit_state_change_metric(TaskInstanceState.RUNNING) self.external_executor_id = external_executor_id self.end_date = None if not test_mode: @@ -1398,7 +1398,7 @@ class TaskInstance(Base, LoggingMixin): return # switch on state and deduce which metric to send - if new_state == State.RUNNING: + if new_state == TaskInstanceState.RUNNING: metric_name = "queued_duration" if self.queued_dttm is None: # this should not really happen except in tests or rare cases, @@ -1410,7 +1410,7 @@ class TaskInstance(Base, LoggingMixin): ) return timing = (timezone.utcnow() - self.queued_dttm).total_seconds() - elif new_state == State.QUEUED: + elif new_state == TaskInstanceState.QUEUED: metric_name = "scheduled_duration" if self.start_date is None: # same comment as above @@ -1497,7 +1497,7 @@ class TaskInstance(Base, LoggingMixin): self._execute_task_with_callbacks(context, test_mode) if not test_mode: self.refresh_from_db(lock_for_update=True, session=session) - self.state = State.SUCCESS + self.state = TaskInstanceState.SUCCESS except TaskDeferred as defer: # The task has signalled it wants to defer execution based on # a trigger. @@ -1521,7 +1521,7 @@ class TaskInstance(Base, LoggingMixin): self.log.info(e) if not test_mode: self.refresh_from_db(lock_for_update=True, session=session) - self.state = State.SKIPPED + self.state = TaskInstanceState.SKIPPED except AirflowRescheduleException as reschedule_exception: self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session) session.commit() @@ -1744,7 +1744,7 @@ class TaskInstance(Base, LoggingMixin): # Then, update ourselves so it matches the deferral request # Keep an eye on the logic in `check_and_change_state_before_execution()` # depending on self.next_method semantics - self.state = State.DEFERRED + self.state = TaskInstanceState.DEFERRED self.trigger_id = trigger_row.id self.next_method = defer.method_name self.next_kwargs = defer.kwargs or {} @@ -1863,7 +1863,7 @@ class TaskInstance(Base, LoggingMixin): ) # set state - self.state = State.UP_FOR_RESCHEDULE + self.state = TaskInstanceState.UP_FOR_RESCHEDULE # Decrement try_number so subsequent runs will use the same try number and write # to same log file. @@ -1928,7 +1928,7 @@ class TaskInstance(Base, LoggingMixin): Stats.incr("ti_failures", tags=self.stats_tags) if not test_mode: - session.add(Log(State.FAILED, self)) + session.add(Log(TaskInstanceState.FAILED, self)) # Log failure duration session.add(TaskFail(ti=self)) @@ -1962,7 +1962,7 @@ class TaskInstance(Base, LoggingMixin): self.log.error("Unable to unmap task to determine if we need to send an alert email") if force_fail or not self.is_eligible_to_retry(): - self.state = State.FAILED + self.state = TaskInstanceState.FAILED email_for_state = operator.attrgetter("email_on_failure") callbacks = task.on_failure_callback if task else None callback_type = "on_failure" @@ -1971,10 +1971,10 @@ class TaskInstance(Base, LoggingMixin): tis = self.get_dagrun(session).get_task_instances() stop_all_tasks_in_dag(tis, session, self.task_id) else: - if self.state == State.QUEUED: + if self.state == TaskInstanceState.QUEUED: # We increase the try_number so as to fail the task if it fails to start after sometime self._try_number += 1 - self.state = State.UP_FOR_RETRY + self.state = TaskInstanceState.UP_FOR_RETRY email_for_state = operator.attrgetter("email_on_retry") callbacks = task.on_retry_callback if task else None callback_type = "on_retry" @@ -1995,7 +1995,7 @@ class TaskInstance(Base, LoggingMixin): def is_eligible_to_retry(self): """Is task instance is eligible for retry.""" - if self.state == State.RESTARTING: + if self.state == TaskInstanceState.RESTARTING: # If a task is cleared when running, it goes into RESTARTING state and is always # eligible for retry return True @@ -2336,8 +2336,8 @@ class TaskInstance(Base, LoggingMixin): 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>' ) - # This function is called after changing the state from State.RUNNING, - # so we need to subtract 1 from self.try_number here. + # This function is called after changing the state from RUNNING, so we + # need to subtract 1 from try_number here. current_try_number = self.try_number - 1 additional_context: dict[str, Any] = { "exception": exception, @@ -2563,7 +2563,7 @@ class TaskInstance(Base, LoggingMixin): num_running_task_instances_query = session.query(func.count()).filter( TaskInstance.dag_id == self.dag_id, TaskInstance.task_id == self.task_id, - TaskInstance.state == State.RUNNING, + TaskInstance.state == TaskInstanceState.RUNNING, ) if same_dagrun: num_running_task_instances_query = num_running_task_instances_query.filter( @@ -2885,7 +2885,7 @@ def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool: # State of the task instance. # Stores string version of the task state. -TaskInstanceStateType = Tuple[TaskInstanceKey, str] +TaskInstanceStateType = Tuple[TaskInstanceKey, TaskInstanceState] class SimpleTaskInstance: diff --git a/airflow/operators/subdag.py b/airflow/operators/subdag.py index de7afcb4ea..af784d5d32 100644 --- a/airflow/operators/subdag.py +++ b/airflow/operators/subdag.py @@ -36,7 +36,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.sensors.base import BaseSensorOperator from airflow.utils.context import Context from airflow.utils.session import NEW_SESSION, create_session, provide_session -from airflow.utils.state import State +from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.types import DagRunType @@ -137,17 +137,17 @@ class SubDagOperator(BaseSensorOperator): :param execution_date: Execution date to select task instances. """ with create_session() as session: - dag_run.state = State.RUNNING + dag_run.state = DagRunState.RUNNING session.merge(dag_run) failed_task_instances = ( session.query(TaskInstance) .filter(TaskInstance.dag_id == self.subdag.dag_id) .filter(TaskInstance.execution_date == execution_date) - .filter(TaskInstance.state.in_([State.FAILED, State.UPSTREAM_FAILED])) + .filter(TaskInstance.state.in_((TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED))) ) for task_instance in failed_task_instances: - task_instance.state = State.NONE + task_instance.state = None session.merge(task_instance) session.commit() @@ -164,7 +164,7 @@ class SubDagOperator(BaseSensorOperator): dag_run = self.subdag.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=execution_date, - state=State.RUNNING, + state=DagRunState.RUNNING, conf=self.conf, external_trigger=True, data_interval=data_interval, @@ -172,13 +172,13 @@ class SubDagOperator(BaseSensorOperator): self.log.info("Created DagRun: %s", dag_run.run_id) else: self.log.info("Found existing DagRun: %s", dag_run.run_id) - if dag_run.state == State.FAILED: + if dag_run.state == DagRunState.FAILED: self._reset_dag_run_and_task_instances(dag_run, execution_date) def poke(self, context: Context): execution_date = context["execution_date"] dag_run = self._get_dagrun(execution_date=execution_date) - return dag_run.state != State.RUNNING + return dag_run.state != DagRunState.RUNNING def post_execute(self, context, result=None): super().post_execute(context) @@ -186,7 +186,7 @@ class SubDagOperator(BaseSensorOperator): dag_run = self._get_dagrun(execution_date=execution_date) self.log.info("Execution finished. State is %s", dag_run.state) - if dag_run.state != State.SUCCESS: + if dag_run.state != DagRunState.SUCCESS: raise AirflowException(f"Expected state: SUCCESS. Actual state: {dag_run.state}") if self.propagate_skipped_state and self._check_skipped_states(context): @@ -196,9 +196,9 @@ class SubDagOperator(BaseSensorOperator): leaves_tis = self._get_leaves_tis(context["execution_date"]) if self.propagate_skipped_state == SkippedStatePropagationOptions.ANY_LEAF: - return any(ti.state == State.SKIPPED for ti in leaves_tis) + return any(ti.state == TaskInstanceState.SKIPPED for ti in leaves_tis) if self.propagate_skipped_state == SkippedStatePropagationOptions.ALL_LEAVES: - return all(ti.state == State.SKIPPED for ti in leaves_tis) + return all(ti.state == TaskInstanceState.SKIPPED for ti in leaves_tis) raise AirflowException( f"Unimplemented SkippedStatePropagationOptions {self.propagate_skipped_state} used." ) diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py index f3560bffc7..a3a6bf7c7c 100644 --- a/airflow/operators/trigger_dagrun.py +++ b/airflow/operators/trigger_dagrun.py @@ -36,7 +36,7 @@ from airflow.utils import timezone from airflow.utils.context import Context from airflow.utils.helpers import build_airflow_url_with_query from airflow.utils.session import provide_session -from airflow.utils.state import State +from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType XCOM_EXECUTION_DATE_ISO = "trigger_execution_date_iso" @@ -116,8 +116,8 @@ class TriggerDagRunOperator(BaseOperator): self.reset_dag_run = reset_dag_run self.wait_for_completion = wait_for_completion self.poke_interval = poke_interval - self.allowed_states = allowed_states or [State.SUCCESS] - self.failed_states = failed_states or [State.FAILED] + self.allowed_states = allowed_states or [DagRunState.SUCCESS] + self.failed_states = failed_states or [DagRunState.FAILED] self._defer = deferrable if execution_date is not None and not isinstance(execution_date, (str, datetime.datetime)): diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index 959ebe5131..158032a2cc 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -37,7 +37,7 @@ from airflow.utils.file import correct_maybe_zipped from airflow.utils.helpers import build_airflow_url_with_query from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import tuple_in_condition -from airflow.utils.state import State +from airflow.utils.state import State, TaskInstanceState if TYPE_CHECKING: from sqlalchemy.orm import Query, Session @@ -76,23 +76,25 @@ class ExternalTaskSensor(BaseSensorOperator): without also having to clear the sensor). By default, the ExternalTaskSensor will not skip if the external task skips. - To change this, simply set ``skipped_states=[State.SKIPPED]``. Note that if - you are monitoring multiple tasks, and one enters error state and the other - enters a skipped state, then the external task will react to whichever one - it sees first. If both happen together, then the failed state takes priority. + To change this, simply set ``skipped_states=[TaskInstanceState.SKIPPED]``. + Note that if you are monitoring multiple tasks, and one enters error state + and the other enters a skipped state, then the external task will react to + whichever one it sees first. If both happen together, then the failed state + takes priority. It is possible to alter the default behavior by setting states which - cause the sensor to fail, e.g. by setting ``allowed_states=[State.FAILED]`` - and ``failed_states=[State.SUCCESS]`` you will flip the behaviour to get a - sensor which goes green when the external task *fails* and immediately goes - red if the external task *succeeds*! + cause the sensor to fail, e.g. by setting + ``allowed_states=[TaskInstanceState.FAILED]`` and + ``failed_states=[TaskInstanceState.SUCCESS]``, you will flip the behaviour + to get a sensor which goes green when the external task *fails* and + immediately goes red if the external task *succeeds*! Note that ``soft_fail`` is respected when examining the failed_states. Thus if the external task enters a failed state and ``soft_fail == True`` the sensor will _skip_ rather than fail. As a result, setting ``soft_fail=True`` - and ``failed_states=[State.SKIPPED]`` will result in the sensor skipping if - the external task skips. However, this is a contrived example - consider - using ``skipped_states`` if you would like this behaviour. Using + and ``failed_states=[TaskInstanceState.SKIPPED]`` will result in the sensor + skipping if the external task skips. However, this is a contrived example; + consider using ``skipped_states`` if you would like this behaviour. Using ``skipped_states`` allows the sensor to skip if the target fails, but still enter failed state on timeout. Using ``soft_fail == True`` as above will cause the sensor to skip if the target fails, but also if it times out. @@ -146,7 +148,7 @@ class ExternalTaskSensor(BaseSensorOperator): **kwargs, ): super().__init__(**kwargs) - self.allowed_states = list(allowed_states) if allowed_states else [State.SUCCESS] + self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS] self.skipped_states = list(skipped_states) if skipped_states else [] self.failed_states = list(failed_states) if failed_states else [] diff --git a/airflow/sentry.py b/airflow/sentry.py index 8742552e7e..fbc1715eed 100644 --- a/airflow/sentry.py +++ b/airflow/sentry.py @@ -25,7 +25,7 @@ from typing import TYPE_CHECKING from airflow.configuration import conf from airflow.executors.executor_loader import ExecutorLoader from airflow.utils.session import find_session_idx, provide_session -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -143,7 +143,7 @@ if conf.getboolean("sentry", "sentry_on", fallback=False): return dr = task_instance.get_dagrun(session) task_instances = dr.get_task_instances( - state={State.SUCCESS, State.FAILED}, + state={TaskInstanceState.SUCCESS, TaskInstanceState.FAILED}, session=session, ) diff --git a/airflow/ti_deps/dependencies_states.py b/airflow/ti_deps/dependencies_states.py index 543ce3528a..fd25d62f6d 100644 --- a/airflow/ti_deps/dependencies_states.py +++ b/airflow/ti_deps/dependencies_states.py @@ -16,38 +16,38 @@ # under the License. from __future__ import annotations -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState EXECUTION_STATES = { - State.RUNNING, - State.QUEUED, + TaskInstanceState.RUNNING, + TaskInstanceState.QUEUED, } # In order to be able to get queued a task must have one of these states SCHEDULEABLE_STATES = { - State.NONE, - State.UP_FOR_RETRY, - State.UP_FOR_RESCHEDULE, + None, + TaskInstanceState.UP_FOR_RETRY, + TaskInstanceState.UP_FOR_RESCHEDULE, } RUNNABLE_STATES = { # For cases like unit tests and run manually - State.NONE, - State.UP_FOR_RETRY, - State.UP_FOR_RESCHEDULE, + None, + TaskInstanceState.UP_FOR_RETRY, + TaskInstanceState.UP_FOR_RESCHEDULE, # For normal scheduler/backfill cases - State.QUEUED, + TaskInstanceState.QUEUED, } QUEUEABLE_STATES = { - State.SCHEDULED, + TaskInstanceState.SCHEDULED, } BACKFILL_QUEUEABLE_STATES = { # For cases like unit tests and run manually - State.NONE, - State.UP_FOR_RESCHEDULE, - State.UP_FOR_RETRY, + None, + TaskInstanceState.UP_FOR_RESCHEDULE, + TaskInstanceState.UP_FOR_RETRY, # For normal backfill cases - State.SCHEDULED, + TaskInstanceState.SCHEDULED, } diff --git a/airflow/ti_deps/deps/dagrun_exists_dep.py b/airflow/ti_deps/deps/dagrun_exists_dep.py index 781ab0ebaf..0a364628c7 100644 --- a/airflow/ti_deps/deps/dagrun_exists_dep.py +++ b/airflow/ti_deps/deps/dagrun_exists_dep.py @@ -19,7 +19,7 @@ from __future__ import annotations from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.session import provide_session -from airflow.utils.state import State +from airflow.utils.state import DagRunState class DagrunRunningDep(BaseTIDep): @@ -31,7 +31,7 @@ class DagrunRunningDep(BaseTIDep): @provide_session def _get_dep_statuses(self, ti, session, dep_context): dr = ti.get_dagrun(session) - if dr.state != State.RUNNING: + if dr.state != DagRunState.RUNNING: yield self._failing_status( reason=f"Task instance's dagrun was not in the 'running' state but in the state '{dr.state}'." ) diff --git a/airflow/ti_deps/deps/not_in_retry_period_dep.py b/airflow/ti_deps/deps/not_in_retry_period_dep.py index b3b5d4ec56..90954f29f2 100644 --- a/airflow/ti_deps/deps/not_in_retry_period_dep.py +++ b/airflow/ti_deps/deps/not_in_retry_period_dep.py @@ -20,7 +20,7 @@ from __future__ import annotations from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils import timezone from airflow.utils.session import provide_session -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState class NotInRetryPeriodDep(BaseTIDep): @@ -38,7 +38,7 @@ class NotInRetryPeriodDep(BaseTIDep): ) return - if ti.state != State.UP_FOR_RETRY: + if ti.state != TaskInstanceState.UP_FOR_RETRY: yield self._passing_status(reason="The task instance was not marked for retrying.") return diff --git a/airflow/ti_deps/deps/not_previously_skipped_dep.py b/airflow/ti_deps/deps/not_previously_skipped_dep.py index fdc8274d90..855f04af53 100644 --- a/airflow/ti_deps/deps/not_previously_skipped_dep.py +++ b/airflow/ti_deps/deps/not_previously_skipped_dep.py @@ -40,7 +40,7 @@ class NotPreviouslySkippedDep(BaseTIDep): XCOM_SKIPMIXIN_SKIPPED, SkipMixin, ) - from airflow.utils.state import State + from airflow.utils.state import TaskInstanceState upstream = ti.task.get_direct_relatives(upstream=True) @@ -87,7 +87,7 @@ class NotPreviouslySkippedDep(BaseTIDep): reason=("Task should be skipped but the the past depends are not met") ) return - ti.set_state(State.SKIPPED, session) + ti.set_state(TaskInstanceState.SKIPPED, session) yield self._failing_status( reason=f"Skipping because of previous XCom result from parent task {parent.task_id}" ) diff --git a/airflow/ti_deps/deps/prev_dagrun_dep.py b/airflow/ti_deps/deps/prev_dagrun_dep.py index b3165eb7f2..62acdbca33 100644 --- a/airflow/ti_deps/deps/prev_dagrun_dep.py +++ b/airflow/ti_deps/deps/prev_dagrun_dep.py @@ -22,7 +22,7 @@ from sqlalchemy import func from airflow.models.taskinstance import PAST_DEPENDS_MET, TaskInstance as TI from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.session import provide_session -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState class PrevDagrunDep(BaseTIDep): @@ -107,7 +107,7 @@ class PrevDagrunDep(BaseTIDep): ) return - if previous_ti.state not in {State.SKIPPED, State.SUCCESS}: + if previous_ti.state not in {TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS}: yield self._failing_status( reason=( f"depends_on_past is true for this task, but the previous task instance {previous_ti} " diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py b/airflow/ti_deps/deps/ready_to_reschedule.py index 66aa5c5613..8394907081 100644 --- a/airflow/ti_deps/deps/ready_to_reschedule.py +++ b/airflow/ti_deps/deps/ready_to_reschedule.py @@ -22,7 +22,7 @@ from airflow.models.taskreschedule import TaskReschedule from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils import timezone from airflow.utils.session import provide_session -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState class ReadyToRescheduleDep(BaseTIDep): @@ -31,7 +31,7 @@ class ReadyToRescheduleDep(BaseTIDep): NAME = "Ready To Reschedule" IGNORABLE = True IS_TASK_DEP = True - RESCHEDULEABLE_STATES = {State.UP_FOR_RESCHEDULE, State.NONE} + RESCHEDULEABLE_STATES = {None, TaskInstanceState.UP_FOR_RESCHEDULE} @provide_session def _get_dep_statuses(self, ti, session, dep_context): diff --git a/airflow/ti_deps/deps/task_not_running_dep.py b/airflow/ti_deps/deps/task_not_running_dep.py index fd76873466..7299a4f3e2 100644 --- a/airflow/ti_deps/deps/task_not_running_dep.py +++ b/airflow/ti_deps/deps/task_not_running_dep.py @@ -20,7 +20,7 @@ from __future__ import annotations from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.session import provide_session -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState class TaskNotRunningDep(BaseTIDep): @@ -37,7 +37,7 @@ class TaskNotRunningDep(BaseTIDep): @provide_session def _get_dep_statuses(self, ti, session, dep_context=None): - if ti.state != State.RUNNING: + if ti.state != TaskInstanceState.RUNNING: yield self._passing_status(reason="Task is not in running state.") return diff --git a/airflow/utils/dot_renderer.py b/airflow/utils/dot_renderer.py index 3f35d1b21d..d9c2acce80 100644 --- a/airflow/utils/dot_renderer.py +++ b/airflow/utils/dot_renderer.py @@ -58,7 +58,7 @@ def _draw_task( ) -> None: """Draw a single task on the given parent_graph.""" if states_by_task_id: - state = states_by_task_id.get(task.task_id, State.NONE) + state = states_by_task_id.get(task.task_id, None) color = State.color_fg(state) fill_color = State.color(state) else: diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 0c98d25503..9184a20420 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -341,7 +341,10 @@ class FileTaskHandler(logging.Handler): ) log_pos = len(logs) messages = "".join([f"*** {x}\n" for x in messages_list]) - end_of_log = ti.try_number != try_number or ti.state not in [State.RUNNING, State.DEFERRED] + end_of_log = ti.try_number != try_number or ti.state not in ( + TaskInstanceState.RUNNING, + TaskInstanceState.DEFERRED, + ) if metadata and "log_pos" in metadata: previous_chars = metadata["log_pos"] logs = logs[previous_chars:] # Cut off previously passed log test as new tail diff --git a/airflow/utils/log/log_reader.py b/airflow/utils/log/log_reader.py index a4589ebca0..e8c3b9897a 100644 --- a/airflow/utils/log/log_reader.py +++ b/airflow/utils/log/log_reader.py @@ -28,7 +28,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.utils.helpers import render_log_filename from airflow.utils.log.logging_mixin import ExternalLoggingMixin from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState class TaskLogReader: @@ -81,19 +81,26 @@ class TaskLogReader: metadata.pop("max_offset", None) metadata.pop("offset", None) metadata.pop("log_pos", None) + while True: logs, metadata = self.read_log_chunks(ti, current_try_number, metadata) for host, log in logs[0]: yield "\n".join([host or "", log]) + "\n" - if "end_of_log" not in metadata or ( - not metadata["end_of_log"] and ti.state not in [State.RUNNING, State.DEFERRED] - ): - if not logs[0]: - # we did not receive any logs in this loop - # sleeping to conserve resources / limit requests on external services - time.sleep(self.STREAM_LOOP_SLEEP_SECONDS) + try: + end_of_log = bool(metadata["end_of_log"]) + except KeyError: + continue_read = True else: + continue_read = not end_of_log and ti.state not in ( + TaskInstanceState.RUNNING, + TaskInstanceState.DEFERRED, + ) + if not continue_read: break + if not logs[0]: + # we did not receive any logs in this loop + # sleeping to conserve resources / limit requests on external services + time.sleep(self.STREAM_LOOP_SLEEP_SECONDS) @cached_property def log_handler(self): diff --git a/airflow/utils/state.py b/airflow/utils/state.py index f4a8dc1a0a..b6297dfb71 100644 --- a/airflow/utils/state.py +++ b/airflow/utils/state.py @@ -98,14 +98,9 @@ class State: finished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.SUCCESS, DagRunState.FAILED]) unfinished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.QUEUED, DagRunState.RUNNING]) - task_states: tuple[TaskInstanceState | None, ...] = (None,) + tuple(TaskInstanceState) + task_states: tuple[TaskInstanceState | None, ...] = (None, *TaskInstanceState) - dag_states: tuple[DagRunState, ...] = ( - DagRunState.QUEUED, - DagRunState.SUCCESS, - DagRunState.RUNNING, - DagRunState.FAILED, - ) + dag_states: tuple[DagRunState, ...] = tuple(DagRunState) state_color: dict[TaskInstanceState | None, str] = { None: "lightblue", diff --git a/airflow/www/utils.py b/airflow/www/utils.py index 25fc1a28f9..58836beefe 100644 --- a/airflow/www/utils.py +++ b/airflow/www/utils.py @@ -85,8 +85,10 @@ def get_instance_with_map(task_instance, session): return get_mapped_summary(task_instance, mapped_instances) -def get_try_count(try_number: int, state: State): - return try_number + 1 if state in [State.DEFERRED, State.UP_FOR_RESCHEDULE] else try_number +def get_try_count(try_number: int, state: TaskInstanceState) -> int: + if state in (TaskInstanceState.DEFERRED, TaskInstanceState.UP_FOR_RESCHEDULE): + return try_number + 1 + return try_number priority: list[None | TaskInstanceState] = [ @@ -426,7 +428,7 @@ def task_instance_link(attr): def state_token(state): - """Returns a formatted string with HTML for a given State.""" + """Returns a formatted string with HTML for a given state.""" color = State.color(state) fg_color = State.color_fg(state) return Markup( @@ -438,7 +440,7 @@ def state_token(state): def state_f(attr): - """Gets 'state' & returns a formatted string with HTML for a given State.""" + """Gets 'state' & returns a formatted string with HTML for a given state.""" state = attr.get("state") return state_token(state) diff --git a/airflow/www/views.py b/airflow/www/views.py index f50df518fe..3da11f5681 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -754,7 +754,7 @@ class Airflow(AirflowBaseView): # find DAGs which have a RUNNING DagRun running_dags = dags_query.join(DagRun, DagModel.dag_id == DagRun.dag_id).filter( - DagRun.state == State.RUNNING + DagRun.state == DagRunState.RUNNING ) # find DAGs for which the latest DagRun is FAILED @@ -765,7 +765,7 @@ class Airflow(AirflowBaseView): ) subq_failed = ( session.query(DagRun.dag_id, func.max(DagRun.start_date).label("start_date")) - .filter(DagRun.state == State.FAILED) + .filter(DagRun.state == DagRunState.FAILED) .group_by(DagRun.dag_id) .subquery() ) @@ -1101,7 +1101,7 @@ class Airflow(AirflowBaseView): running_dag_run_query_result = ( session.query(DagRun.dag_id, DagRun.run_id) .join(DagModel, DagModel.dag_id == DagRun.dag_id) - .filter(DagRun.state == State.RUNNING, DagModel.is_active) + .filter(DagRun.state == DagRunState.RUNNING, DagModel.is_active) ) running_dag_run_query_result = running_dag_run_query_result.filter(DagRun.dag_id.in_(filter_dag_ids)) @@ -1125,7 +1125,7 @@ class Airflow(AirflowBaseView): last_dag_run = ( session.query(DagRun.dag_id, sqla.func.max(DagRun.execution_date).label("execution_date")) .join(DagModel, DagModel.dag_id == DagRun.dag_id) - .filter(DagRun.state != State.RUNNING, DagModel.is_active) + .filter(DagRun.state != DagRunState.RUNNING, DagModel.is_active) .group_by(DagRun.dag_id) ) @@ -1820,7 +1820,7 @@ class Airflow(AirflowBaseView): "Airflow administrator for assistance.".format( "- This task instance already ran and had it's state changed manually " "(e.g. cleared in the UI)<br>" - if ti and ti.state == State.NONE + if ti and ti.state is None else "" ), ) @@ -2119,7 +2119,7 @@ class Airflow(AirflowBaseView): run_type=DagRunType.MANUAL, execution_date=execution_date, data_interval=dag.timetable.infer_manual_data_interval(run_after=execution_date), - state=State.QUEUED, + state=DagRunState.QUEUED, conf=run_conf, external_trigger=True, dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id), @@ -5337,14 +5337,14 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView): @action_logging def action_set_queued(self, drs: list[DagRun]): """Set state to queued.""" - return self._set_dag_runs_to_active_state(drs, State.QUEUED) + return self._set_dag_runs_to_active_state(drs, DagRunState.QUEUED) @action("set_running", "Set state to 'running'", "", single=False) @action_has_dag_edit_access @action_logging def action_set_running(self, drs: list[DagRun]): """Set state to running.""" - return self._set_dag_runs_to_active_state(drs, State.RUNNING) + return self._set_dag_runs_to_active_state(drs, DagRunState.RUNNING) @provide_session def _set_dag_runs_to_active_state(self, drs: list[DagRun], state: str, session: Session = NEW_SESSION): @@ -5353,7 +5353,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView): count = 0 for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for dagrun in drs)): count += 1 - if state == State.RUNNING: + if state == DagRunState.RUNNING: dr.start_date = timezone.utcnow() dr.state = state session.commit() @@ -5790,7 +5790,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView): @action_logging def action_set_running(self, tis): """Set state to 'running'.""" - self.set_task_instance_state(tis, State.RUNNING) + self.set_task_instance_state(tis, TaskInstanceState.RUNNING) self.update_redirect() return redirect(self.get_redirect()) @@ -5799,7 +5799,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView): @action_logging def action_set_failed(self, tis): """Set state to 'failed'.""" - self.set_task_instance_state(tis, State.FAILED) + self.set_task_instance_state(tis, TaskInstanceState.FAILED) self.update_redirect() return redirect(self.get_redirect()) @@ -5808,7 +5808,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView): @action_logging def action_set_success(self, tis): """Set state to 'success'.""" - self.set_task_instance_state(tis, State.SUCCESS) + self.set_task_instance_state(tis, TaskInstanceState.SUCCESS) self.update_redirect() return redirect(self.get_redirect()) @@ -5817,7 +5817,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView): @action_logging def action_set_retry(self, tis): """Set state to 'up_for_retry'.""" - self.set_task_instance_state(tis, State.UP_FOR_RETRY) + self.set_task_instance_state(tis, TaskInstanceState.UP_FOR_RETRY) self.update_redirect() return redirect(self.get_redirect())