This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b459af3ee0 Replace State usages with strong-typed enums (#31735)
b459af3ee0 is described below

commit b459af3ee0f94cd246e3d401ca3eec18ffd85db0
Author: Tzu-ping Chung <uranu...@gmail.com>
AuthorDate: Sat Jun 17 04:48:57 2023 +0800

    Replace State usages with strong-typed enums (#31735)
    
    Only in the main Airflow code base. There are many more in tests that I might tackle some day.
    Additionally, there are some cases where TI state is used for "job" state. We may deal with this later by introducing a new type ExecutorState.
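
    A minimal sketch of the pattern this commit applies, for context (State, TaskInstanceState, and
    DagRunState are the real names from airflow.utils.state; the query lines and the `session` object
    are illustrative assumptions, not code taken from the diff below):

        from airflow.models import DagRun, TaskInstance
        from airflow.utils.state import DagRunState, State, TaskInstanceState

        # Before: the loosely-typed State namespace mixes task and DAG-run values
        running_tis = session.query(TaskInstance).filter(TaskInstance.state == State.RUNNING)

        # After: the strong-typed enums make explicit which kind of state is meant
        running_tis = session.query(TaskInstance).filter(TaskInstance.state == TaskInstanceState.RUNNING)
        running_runs = session.query(DagRun).filter(DagRun.state == DagRunState.RUNNING)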
---
 airflow/api/common/delete_dag.py                   |  4 +-
 airflow/api/common/mark_tasks.py                   | 30 +++++++++---
 .../endpoints/task_instance_endpoint.py            |  4 +-
 airflow/api_connexion/schemas/enum_schemas.py      |  4 +-
 airflow/dag_processing/processor.py                | 16 +++++--
 airflow/executors/base_executor.py                 | 10 ++--
 airflow/executors/celery_executor.py               |  8 ++--
 airflow/executors/celery_executor_utils.py         |  8 ++--
 airflow/executors/debug_executor.py                | 28 +++++------
 airflow/executors/kubernetes_executor.py           | 24 ++++++----
 airflow/executors/local_executor.py                | 16 +++----
 airflow/executors/sequential_executor.py           |  6 +--
 airflow/jobs/backfill_job_runner.py                |  6 +--
 airflow/jobs/job.py                                |  7 ++-
 airflow/jobs/local_task_job_runner.py              |  6 +--
 airflow/jobs/scheduler_job_runner.py               | 28 +++++------
 airflow/listeners/spec/taskinstance.py             |  6 +--
 airflow/models/dag.py                              | 18 ++++----
 airflow/models/dagrun.py                           | 42 ++++++++---------
 airflow/models/pool.py                             |  8 ++--
 airflow/models/skipmixin.py                        |  4 +-
 airflow/models/taskinstance.py                     | 54 +++++++++++-----------
 airflow/operators/subdag.py                        | 20 ++++----
 airflow/operators/trigger_dagrun.py                |  6 +--
 airflow/sensors/external_task.py                   | 28 +++++------
 airflow/sentry.py                                  |  4 +-
 airflow/ti_deps/dependencies_states.py             | 30 ++++++------
 airflow/ti_deps/deps/dagrun_exists_dep.py          |  4 +-
 airflow/ti_deps/deps/not_in_retry_period_dep.py    |  4 +-
 airflow/ti_deps/deps/not_previously_skipped_dep.py |  4 +-
 airflow/ti_deps/deps/prev_dagrun_dep.py            |  4 +-
 airflow/ti_deps/deps/ready_to_reschedule.py        |  4 +-
 airflow/ti_deps/deps/task_not_running_dep.py       |  4 +-
 airflow/utils/dot_renderer.py                      |  2 +-
 airflow/utils/log/file_task_handler.py             |  5 +-
 airflow/utils/log/log_reader.py                    | 23 +++++----
 airflow/utils/state.py                             |  9 +---
 airflow/www/utils.py                               | 10 ++--
 airflow/www/views.py                               | 26 +++++------
 39 files changed, 280 insertions(+), 244 deletions(-)

diff --git a/airflow/api/common/delete_dag.py b/airflow/api/common/delete_dag.py
index 45611729ea..1d879a667a 100644
--- a/airflow/api/common/delete_dag.py
+++ b/airflow/api/common/delete_dag.py
@@ -29,7 +29,7 @@ from airflow.models import DagModel, TaskFail
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.utils.db import get_sqla_model_classes
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 log = logging.getLogger(__name__)
 
@@ -50,7 +50,7 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session =
     running_tis = session.scalar(
         select(models.TaskInstance.state)
         .where(models.TaskInstance.dag_id == dag_id)
-        .where(models.TaskInstance.state == State.RUNNING)
+        .where(models.TaskInstance.state == TaskInstanceState.RUNNING)
         .limit(1)
     )
     if running_tis:
diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index b965237bdf..184251b515 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -155,7 +155,7 @@ def set_state(
         for task_instance in tis_altered:
             # The try_number was decremented when setting to up_for_reschedule and deferred.
             # Increment it back when changing the state again
-            if task_instance.state in [State.DEFERRED, State.UP_FOR_RESCHEDULE]:
+            if task_instance.state in (TaskInstanceState.DEFERRED, TaskInstanceState.UP_FOR_RESCHEDULE):
                 task_instance._try_number += 1
             task_instance.set_state(state, session=session)
         session.flush()
@@ -362,7 +362,7 @@ def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SA
         select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
     ).scalar_one()
     dag_run.state = state
-    if state == State.RUNNING:
+    if state == DagRunState.RUNNING:
         dag_run.start_date = timezone.utcnow()
         dag_run.end_date = None
     else:
@@ -415,7 +415,13 @@ def set_dag_run_state_to_success(
     # Mark all task instances of the dag run to success.
     for task in dag.tasks:
         task.dag = dag
-    return set_state(tasks=dag.tasks, run_id=run_id, state=State.SUCCESS, commit=commit, session=session)
+    return set_state(
+        tasks=dag.tasks,
+        run_id=run_id,
+        state=TaskInstanceState.SUCCESS,
+        commit=commit,
+        session=session,
+    )
 
 
 @provide_session
@@ -468,7 +474,9 @@ def set_dag_run_state_to_failed(
             TaskInstance.dag_id == dag.dag_id,
             TaskInstance.run_id == run_id,
             TaskInstance.task_id.in_(task_ids),
-            TaskInstance.state.in_([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]),
+            TaskInstance.state.in_(
+                (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED, TaskInstanceState.UP_FOR_RESCHEDULE),
+            ),
         )
     )
 
@@ -487,16 +495,24 @@ def set_dag_run_state_to_failed(
             TaskInstance.dag_id == dag.dag_id,
             TaskInstance.run_id == run_id,
             TaskInstance.state.not_in(State.finished),
-            TaskInstance.state.not_in([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]),
+            TaskInstance.state.not_in(
+                (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED, TaskInstanceState.UP_FOR_RESCHEDULE),
+            ),
         )
     )
 
     tis = [ti for ti in tis]
     if commit:
         for ti in tis:
-            ti.set_state(State.SKIPPED)
+            ti.set_state(TaskInstanceState.SKIPPED)
 
-    return tis + set_state(tasks=tasks, run_id=run_id, state=State.FAILED, commit=commit, session=session)
+    return tis + set_state(
+        tasks=tasks,
+        run_id=run_id,
+        state=TaskInstanceState.FAILED,
+        commit=commit,
+        session=session,
+    )
 
 
 def __set_dag_run_state_to_running_or_queued(
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index 533d97c858..7ccc6b8848 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -50,7 +50,7 @@ from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances
 from airflow.security import permissions
 from airflow.utils.airflow_flask_app import get_airflow_app
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import DagRunState, State
+from airflow.utils.state import DagRunState
 
 T = TypeVar("T")
 
@@ -264,7 +264,7 @@ def get_mapped_task_instances(
 def _convert_state(states: Iterable[str] | None) -> list[str | None] | None:
     if not states:
         return None
-    return [State.NONE if s == "none" else s for s in states]
+    return [None if s == "none" else s for s in states]
 
 
 def _apply_array_filter(query: Query, key: ClauseElement, values: Iterable[Any] | None) -> Query:
diff --git a/airflow/api_connexion/schemas/enum_schemas.py b/airflow/api_connexion/schemas/enum_schemas.py
index 981a3669b1..ba82010783 100644
--- a/airflow/api_connexion/schemas/enum_schemas.py
+++ b/airflow/api_connexion/schemas/enum_schemas.py
@@ -26,7 +26,7 @@ class DagStateField(fields.String):
 
     def __init__(self, **metadata):
         super().__init__(**metadata)
-        self.validators = [validate.OneOf(State.dag_states)] + list(self.validators)
+        self.validators = [validate.OneOf(State.dag_states), *self.validators]
 
 
 class TaskInstanceStateField(fields.String):
@@ -34,4 +34,4 @@ class TaskInstanceStateField(fields.String):
 
     def __init__(self, **metadata):
         super().__init__(**metadata)
-        self.validators = [validate.OneOf(State.task_states)] + list(self.validators)
+        self.validators = [validate.OneOf(State.task_states), *self.validators]
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index 9ee865bbf8..d742c5314a 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -56,7 +56,7 @@ from airflow.utils.file import iter_airflow_imports, might_contain_dag
 from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context
 from airflow.utils.mixins import MultiprocessingStartMethodMixin
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.models.operator import Operator
@@ -432,9 +432,11 @@ class DagFileProcessor(LoggingMixin):
         qry = (
             session.query(TI.task_id, func.max(DR.execution_date).label("max_ti"))
             .join(TI.dag_run)
-            .filter(TI.dag_id == dag.dag_id)
-            .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED))
-            .filter(TI.task_id.in_(dag.task_ids))
+            .filter(
+                TI.dag_id == dag.dag_id,
+                or_(TI.state == TaskInstanceState.SUCCESS, TI.state == TaskInstanceState.SKIPPED),
+                TI.task_id.in_(dag.task_ids),
+            )
             .group_by(TI.task_id)
             .subquery("sq")
         )
@@ -500,7 +502,11 @@ class DagFileProcessor(LoggingMixin):
             sla_dates: list[datetime] = [sla.execution_date for sla in slas]
             fetched_tis: list[TI] = (
                 session.query(TI)
-                .filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id)
+                .filter(
+                    TI.state != TaskInstanceState.SUCCESS,
+                    TI.execution_date.in_(sla_dates),
+                    TI.dag_id == dag.dag_id,
+                )
                 .all()
             )
             blocking_tis: list[TI] = []
diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 9599adabdf..6ae80dced1 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -31,7 +31,7 @@ from airflow.configuration import conf
 from airflow.exceptions import RemovedInAirflow3Warning
 from airflow.stats import Stats
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 PARALLELISM: int = conf.getint("core", "PARALLELISM")
 
@@ -54,7 +54,7 @@ if TYPE_CHECKING:
 
     # Event_buffer dict value type
     # Tuple of: state, info
-    EventBufferValueType = Tuple[Optional[str], Any]
+    EventBufferValueType = Tuple[Optional[TaskInstanceState], Any]
 
     # Task tuple to send to be executed
     TaskTuple = Tuple[TaskInstanceKey, CommandType, Optional[str], Optional[Any]]
@@ -298,7 +298,7 @@ class BaseExecutor(LoggingMixin):
             self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
             self.running.add(key)
 
-    def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
+    def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
         """
         Changes state of the task.
 
@@ -320,7 +320,7 @@ class BaseExecutor(LoggingMixin):
         :param info: Executor information for the task instance
         :param key: Unique key for the task instance
         """
-        self.change_state(key, State.FAILED, info)
+        self.change_state(key, TaskInstanceState.FAILED, info)
 
     def success(self, key: TaskInstanceKey, info=None) -> None:
         """
@@ -329,7 +329,7 @@ class BaseExecutor(LoggingMixin):
         :param info: Executor information for the task instance
         :param key: Unique key for the task instance
         """
-        self.change_state(key, State.SUCCESS, info)
+        self.change_state(key, TaskInstanceState.SUCCESS, info)
 
     def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]:
         """
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index de59804a04..c9f4ded309 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -38,7 +38,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowTaskTimeout
 from airflow.executors.base_executor import BaseExecutor
 from airflow.stats import Stats
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 log = logging.getLogger(__name__)
 
@@ -150,7 +150,7 @@ class CeleryExecutor(BaseExecutor):
             self.task_publish_retries.pop(key, None)
             if isinstance(result, ExceptionWithTraceback):
                 self.log.error(CELERY_SEND_ERR_MSG_HEADER + ": %s\n%s\n", result.exception, result.traceback)
-                self.event_buffer[key] = (State.FAILED, None)
+                self.event_buffer[key] = (TaskInstanceState.FAILED, None)
             elif result is not None:
                 result.backend = cached_celery_backend
                 self.running.add(key)
@@ -159,7 +159,7 @@ class CeleryExecutor(BaseExecutor):
                 # Store the Celery task_id in the event buffer. This will get "overwritten" if the task
                 # has another event, but that is fine, because the only other events are success/failed at
                 # which point we don't need the ID anymore anyway
-                self.event_buffer[key] = (State.QUEUED, result.task_id)
+                self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id)
 
                 # If the task runs _really quickly_ we may already have a result!
                 self.update_task_state(key, result.state, getattr(result, "info", None))
@@ -206,7 +206,7 @@ class CeleryExecutor(BaseExecutor):
             if state:
                 self.update_task_state(key, state, info)
 
-    def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
+    def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
         super().change_state(key, state, info)
         self.tasks.pop(key, None)
 
diff --git a/airflow/executors/celery_executor_utils.py b/airflow/executors/celery_executor_utils.py
index 2c8af4cf91..80460c3c8a 100644
--- a/airflow/executors/celery_executor_utils.py
+++ b/airflow/executors/celery_executor_utils.py
@@ -46,6 +46,7 @@ from airflow.stats import Stats
 from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
+from airflow.utils.state import TaskInstanceState
 from airflow.utils.timeout import timeout
 
 log = logging.getLogger(__name__)
@@ -192,9 +193,10 @@ def send_task_to_executor(
     return key, command, result
 
 
-def fetch_celery_task_state(async_result: AsyncResult) -> tuple[str, str | ExceptionWithTraceback, Any]:
-    """
-    Fetch and return the state of the given celery task.
+def fetch_celery_task_state(
+    async_result: AsyncResult,
+) -> tuple[str, TaskInstanceState | ExceptionWithTraceback, Any]:
+    """Fetch and return the state of the given celery task.
 
     The scope of this function is global so that it can be called by subprocesses in the pool.
 
diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py
index ca23b09a67..8a46d6cda0 100644
--- a/airflow/executors/debug_executor.py
+++ b/airflow/executors/debug_executor.py
@@ -29,7 +29,7 @@ import time
 from typing import TYPE_CHECKING, Any
 
 from airflow.executors.base_executor import BaseExecutor
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance
@@ -68,15 +68,15 @@ class DebugExecutor(BaseExecutor):
         while self.tasks_to_run:
             ti = self.tasks_to_run.pop(0)
             if self.fail_fast and not task_succeeded:
-                self.log.info("Setting %s to %s", ti.key, 
State.UPSTREAM_FAILED)
-                ti.set_state(State.UPSTREAM_FAILED)
-                self.change_state(ti.key, State.UPSTREAM_FAILED)
+                self.log.info("Setting %s to %s", ti.key, 
TaskInstanceState.UPSTREAM_FAILED)
+                ti.set_state(TaskInstanceState.UPSTREAM_FAILED)
+                self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED)
                 continue
 
             if self._terminated.is_set():
-                self.log.info("Executor is terminated! Stopping %s to %s", 
ti.key, State.FAILED)
-                ti.set_state(State.FAILED)
-                self.change_state(ti.key, State.FAILED)
+                self.log.info("Executor is terminated! Stopping %s to %s", 
ti.key, TaskInstanceState.FAILED)
+                ti.set_state(TaskInstanceState.FAILED)
+                self.change_state(ti.key, TaskInstanceState.FAILED)
                 continue
 
             task_succeeded = self._run_task(ti)
@@ -87,11 +87,11 @@ class DebugExecutor(BaseExecutor):
         try:
             params = self.tasks_params.pop(ti.key, {})
             ti.run(job_id=ti.job_id, **params)
-            self.change_state(key, State.SUCCESS)
+            self.change_state(key, TaskInstanceState.SUCCESS)
             return True
         except Exception as e:
-            ti.set_state(State.FAILED)
-            self.change_state(key, State.FAILED)
+            ti.set_state(TaskInstanceState.FAILED)
+            self.change_state(key, TaskInstanceState.FAILED)
             self.log.exception("Failed to execute task: %s.", str(e))
             return False
 
@@ -148,14 +148,14 @@ class DebugExecutor(BaseExecutor):
     def end(self) -> None:
         """Set states of queued tasks to UPSTREAM_FAILED marking them as not 
executed."""
         for ti in self.tasks_to_run:
-            self.log.info("Setting %s to %s", ti.key, State.UPSTREAM_FAILED)
-            ti.set_state(State.UPSTREAM_FAILED)
-            self.change_state(ti.key, State.UPSTREAM_FAILED)
+            self.log.info("Setting %s to %s", ti.key, 
TaskInstanceState.UPSTREAM_FAILED)
+            ti.set_state(TaskInstanceState.UPSTREAM_FAILED)
+            self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED)
 
     def terminate(self) -> None:
         self._terminated.set()
 
-    def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
+    def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
         self.log.debug("Popping %s from executor task queue.", key)
         self.running.remove(key)
         self.event_buffer[key] = state, info
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index 8707bf9699..092a84f470 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -53,7 +53,7 @@ from airflow.kubernetes.pod_generator import PodGenerator
 from airflow.utils.event_scheduler import EventScheduler
 from airflow.utils.log.logging_mixin import LoggingMixin, remove_escape_codes
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State, TaskInstanceState
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.executors.base_executor import CommandType
@@ -228,12 +228,16 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin):
             # since kube server have received request to delete pod set TI state failed
             if event["type"] == "DELETED" and pod.metadata.deletion_timestamp:
                 self.log.info("Event: Failed to start pod %s, annotations: 
%s", pod_name, annotations_string)
-                self.watcher_queue.put((pod_name, namespace, State.FAILED, 
annotations, resource_version))
+                self.watcher_queue.put(
+                    (pod_name, namespace, TaskInstanceState.FAILED, 
annotations, resource_version),
+                )
             else:
                 self.log.debug("Event: %s Pending, annotations: %s", pod_name, 
annotations_string)
         elif status == "Failed":
             self.log.error("Event: %s Failed, annotations: %s", pod_name, 
annotations_string)
-            self.watcher_queue.put((pod_name, namespace, State.FAILED, 
annotations, resource_version))
+            self.watcher_queue.put(
+                (pod_name, namespace, TaskInstanceState.FAILED, annotations, 
resource_version),
+            )
         elif status == "Succeeded":
             # We get multiple events once the pod hits a terminal state, and we only want to
             # send it along to the scheduler once.
@@ -261,7 +265,9 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin):
                     pod_name,
                     annotations_string,
                 )
-                self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
+                self.watcher_queue.put(
+                    (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version),
+                )
             else:
                 self.log.info("Event: %s is Running, annotations: %s", 
pod_name, annotations_string)
         else:
@@ -700,7 +706,7 @@ class KubernetesExecutor(BaseExecutor):
                     last_resource_version[namespace] = resource_version
                     self.log.info("Changing state of %s to %s", results, state)
                     try:
-                        self._change_state(key, state, pod_name, namespace)
+                        self._change_state(key, TaskInstanceState(state), pod_name, namespace)
                     except Exception as e:
                         self.log.exception(
                             "Exception: %s when attempting to change state of 
%s to %s, re-queueing.",
@@ -767,7 +773,7 @@ class KubernetesExecutor(BaseExecutor):
     def _change_state(
         self,
         key: TaskInstanceKey,
-        state: str | None,
+        state: TaskInstanceState | None,
         pod_name: str,
         namespace: str,
         session: Session = NEW_SESSION,
@@ -776,12 +782,12 @@ class KubernetesExecutor(BaseExecutor):
             assert self.kube_scheduler
         from airflow.models.taskinstance import TaskInstance
 
-        if state == State.RUNNING:
+        if state == TaskInstanceState.RUNNING:
             self.event_buffer[key] = state, None
             return
 
         if self.kube_config.delete_worker_pods:
-            if state != State.FAILED or self.kube_config.delete_worker_pods_on_failure:
+            if state != TaskInstanceState.FAILED or self.kube_config.delete_worker_pods_on_failure:
                 self.kube_scheduler.delete_pod(pod_name=pod_name, namespace=namespace)
                 self.log.info("Deleted pod: %s in namespace %s", str(key), str(namespace))
         else:
@@ -1011,7 +1017,7 @@ class KubernetesExecutor(BaseExecutor):
                         "Changing state of %s to %s : resource_version=%d", 
results, state, resource_version
                     )
                     try:
-                        self._change_state(key, state, pod_name, namespace)
+                        self._change_state(key, TaskInstanceState(state), pod_name, namespace)
                     except Exception as e:
                         self.log.exception(
                             "Ignoring exception: %s when attempting to change 
state of %s to %s.",
diff --git a/airflow/executors/local_executor.py 
b/airflow/executors/local_executor.py
index 715bdb42ae..550a6519e1 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -39,7 +39,7 @@ from airflow import settings
 from airflow.exceptions import AirflowException
 from airflow.executors.base_executor import PARALLELISM, BaseExecutor
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.executors.base_executor import CommandType
@@ -94,20 +94,20 @@ class LocalWorkerBase(Process, LoggingMixin):
         # Remove the command since the worker is done executing the task
         setproctitle("airflow worker -- LocalExecutor")
 
-    def _execute_work_in_subprocess(self, command: CommandType) -> str:
+    def _execute_work_in_subprocess(self, command: CommandType) -> TaskInstanceState:
         try:
             subprocess.check_call(command, close_fds=True)
-            return State.SUCCESS
+            return TaskInstanceState.SUCCESS
         except subprocess.CalledProcessError as e:
             self.log.error("Failed to execute task %s.", str(e))
-            return State.FAILED
+            return TaskInstanceState.FAILED
 
-    def _execute_work_in_fork(self, command: CommandType) -> str:
+    def _execute_work_in_fork(self, command: CommandType) -> TaskInstanceState:
         pid = os.fork()
         if pid:
             # In parent, wait for the child
             pid, ret = os.waitpid(pid, 0)
-            return State.SUCCESS if ret == 0 else State.FAILED
+            return TaskInstanceState.SUCCESS if ret == 0 else TaskInstanceState.FAILED
 
         from airflow.sentry import Sentry
 
@@ -130,10 +130,10 @@ class LocalWorkerBase(Process, LoggingMixin):
 
             args.func(args)
             ret = 0
-            return State.SUCCESS
+            return TaskInstanceState.SUCCESS
         except Exception as e:
             self.log.exception("Failed to execute task %s.", e)
-            return State.FAILED
+            return TaskInstanceState.FAILED
         finally:
             Sentry.flush()
             logging.shutdown()
diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py
index 28f88c6b87..2715edad6e 100644
--- a/airflow/executors/sequential_executor.py
+++ b/airflow/executors/sequential_executor.py
@@ -28,7 +28,7 @@ import subprocess
 from typing import TYPE_CHECKING, Any
 
 from airflow.executors.base_executor import BaseExecutor
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.executors.base_executor import CommandType
@@ -75,9 +75,9 @@ class SequentialExecutor(BaseExecutor):
 
             try:
                 subprocess.check_call(command, close_fds=True)
-                self.change_state(key, State.SUCCESS)
+                self.change_state(key, TaskInstanceState.SUCCESS)
             except subprocess.CalledProcessError as e:
-                self.change_state(key, State.FAILED)
+                self.change_state(key, TaskInstanceState.FAILED)
                 self.log.error("Failed to execute task %s.", str(e))
 
         self.commands_to_run = []
diff --git a/airflow/jobs/backfill_job_runner.py b/airflow/jobs/backfill_job_runner.py
index 40c31c4451..a0efcb17af 100644
--- a/airflow/jobs/backfill_job_runner.py
+++ b/airflow/jobs/backfill_job_runner.py
@@ -68,7 +68,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
 
     job_type = "BackfillJob"
 
-    STATES_COUNT_AS_RUNNING = (State.RUNNING, State.QUEUED)
+    STATES_COUNT_AS_RUNNING = (TaskInstanceState.RUNNING, TaskInstanceState.QUEUED)
 
     @attr.define
     class _DagRunTaskStatus:
@@ -219,7 +219,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
             # is changed externally, e.g. by clearing tasks from the ui. We need to cover
             # for that as otherwise those tasks would fall outside the scope of
             # the backfill suddenly.
-            elif ti.state == State.NONE:
+            elif ti.state is None:
                 self.log.warning(
                     "FIXME: task instance %s state was set to none externally 
or "
                     "reaching concurrency limits. Re-adding task to queue.",
@@ -1004,7 +1004,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
             ).all()
 
             for ti in reset_tis:
-                ti.state = State.NONE
+                ti.state = None
                 session.merge(ti)
 
             return result + reset_tis
diff --git a/airflow/jobs/job.py b/airflow/jobs/job.py
index fe75508f63..903cfb9aca 100644
--- a/airflow/jobs/job.py
+++ b/airflow/jobs/job.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 from functools import cached_property
 from time import sleep
-from typing import Callable, NoReturn
+from typing import TYPE_CHECKING, Callable, NoReturn
 
 from sqlalchemy import Column, Index, Integer, String, case, select
 from sqlalchemy.exc import OperationalError
@@ -42,6 +42,9 @@ from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
 from airflow.utils.state import State
 
+if TYPE_CHECKING:
+    from airflow.executors.base_executor import BaseExecutor
+
 
 def _resolve_dagrun_model():
     from airflow.models.dagrun import DagRun
@@ -117,7 +120,7 @@ class Job(Base, LoggingMixin):
         super().__init__(**kwargs)
 
     @cached_property
-    def executor(self):
+    def executor(self) -> BaseExecutor:
         return ExecutorLoader.get_default_executor()
 
     def is_alive(self, grace_multiplier=2.1):
diff --git a/airflow/jobs/local_task_job_runner.py b/airflow/jobs/local_task_job_runner.py
index 9f6a4b55e8..fd234e4150 100644
--- a/airflow/jobs/local_task_job_runner.py
+++ b/airflow/jobs/local_task_job_runner.py
@@ -35,7 +35,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
 from airflow.utils.platform import IS_WINDOWS
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 SIGSEGV_MESSAGE = """
 ******************************************* Received SIGSEGV *******************************************
@@ -243,7 +243,7 @@ class LocalTaskJobRunner(BaseJobRunner["Job | JobPydantic"], LoggingMixin):
         self.task_instance.refresh_from_db()
         ti = self.task_instance
 
-        if ti.state == State.RUNNING:
+        if ti.state == TaskInstanceState.RUNNING:
             fqdn = get_hostname()
             same_hostname = fqdn == ti.hostname
             if not same_hostname:
@@ -273,7 +273,7 @@ class LocalTaskJobRunner(BaseJobRunner["Job | 
JobPydantic"], LoggingMixin):
                 )
                 raise AirflowException("PID of job runner does not match")
         elif self.task_runner.return_code() is None and hasattr(self.task_runner, "process"):
-            if ti.state == State.SKIPPED:
+            if ti.state == TaskInstanceState.SKIPPED:
                 # A DagRun timeout will cause tasks to be externally marked as skipped.
                 dagrun = ti.get_dagrun(session=session)
                 execution_time = (dagrun.end_date or timezone.utcnow()) - dagrun.start_date
diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index a85a0925fb..16764b1980 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -610,7 +610,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
             )
 
             for ti in executable_tis:
-                ti.emit_state_change_metric(State.QUEUED)
+                ti.emit_state_change_metric(TaskInstanceState.QUEUED)
 
         for ti in executable_tis:
             make_transient(ti)
@@ -626,7 +626,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
         # actually enqueue them
         for ti in task_instances:
             if ti.dag_run.state in State.finished:
-                ti.set_state(State.NONE, session=session)
+                ti.set_state(None, session=session)
                 continue
             command = ti.command_as_list(
                 local=True,
@@ -682,13 +682,13 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
 
         # Report execution
         for ti_key, value in event_buffer.items():
-            state: str
+            state: TaskInstanceState | None
             state, _ = value
             # We create map (dag_id, task_id, execution_date) -> in-memory try_number
             ti_primary_key_to_try_number_map[ti_key.primary] = ti_key.try_number
 
             self.log.info("Received executor event with state %s for task instance %s", state, ti_key)
-            if state in (State.FAILED, State.SUCCESS, State.QUEUED):
+            if state in (TaskInstanceState.FAILED, TaskInstanceState.SUCCESS, TaskInstanceState.QUEUED):
                 tis_with_right_state.append(ti_key)
 
         # Return if no finished tasks
@@ -712,7 +712,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
             buffer_key = ti.key.with_try_number(try_number)
             state, info = event_buffer.pop(buffer_key)
 
-            if state == State.QUEUED:
+            if state == TaskInstanceState.QUEUED:
                 ti.external_executor_id = info
                 self.log.info("Setting external_id for %s to %s", ti, info)
                 continue
@@ -1535,7 +1535,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
 
         tasks_stuck_in_queued = session.scalars(
             select(TI).where(
-                TI.state == State.QUEUED,
+                TI.state == TaskInstanceState.QUEUED,
                 TI.queued_dttm < (timezone.utcnow() - timedelta(seconds=self._task_queued_timeout)),
                 TI.queued_by_job_id == self.job.id,
             )
@@ -1602,20 +1602,19 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
                        self.log.info("Marked %d SchedulerJob instances as failed", num_failed)
                        Stats.incr(self.__class__.__name__.lower() + "_end", num_failed)
 
-                    resettable_states = [TaskInstanceState.QUEUED, TaskInstanceState.RUNNING]
                     query = (
                         select(TI)
-                        .where(TI.state.in_(resettable_states))
+                        .where(TI.state.in_((TaskInstanceState.QUEUED, TaskInstanceState.RUNNING)))
                         # outerjoin is because we didn't use to have queued_by_job
                         # set, so we need to pick up anything pre upgrade. This (and the
                         # "or queued_by_job_id IS NONE") can go as soon as scheduler HA is
                         # released.
                         .outerjoin(TI.queued_by_job)
-                        .where(or_(TI.queued_by_job_id.is_(None), Job.state != State.RUNNING))
+                        .where(or_(TI.queued_by_job_id.is_(None), Job.state != TaskInstanceState.RUNNING))
                         .join(TI.dag_run)
                         .where(
                             DagRun.run_type != DagRunType.BACKFILL_JOB,
-                            DagRun.state == State.RUNNING,
+                            DagRun.state == DagRunState.RUNNING,
                         )
                         .options(load_only(TI.dag_id, TI.task_id, TI.run_id))
                     )
@@ -1630,7 +1629,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
                     reset_tis_message = []
                     for ti in to_reset:
                         reset_tis_message.append(repr(ti))
-                        ti.state = State.NONE
+                        ti.state = None
                         ti.queued_by_job_id = None
 
                     for ti in set(tis_to_reset_or_adopt) - set(to_reset):
@@ -1697,12 +1696,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
                     .join(Job, TI.job_id == Job.id)
                     .join(DM, TI.dag_id == DM.dag_id)
                     .where(TI.state == TaskInstanceState.RUNNING)
-                    .where(
-                        or_(
-                            Job.state != State.RUNNING,
-                            Job.latest_heartbeat < limit_dttm,
-                        )
-                    )
+                    .where(or_(Job.state != State.RUNNING, Job.latest_heartbeat < limit_dttm))
                     .where(Job.job_type == "LocalTaskJob")
                     .where(TI.queued_by_job_id == self.job.id)
                 )
diff --git a/airflow/listeners/spec/taskinstance.py b/airflow/listeners/spec/taskinstance.py
index 78de8a5f62..56b4cb7322 100644
--- a/airflow/listeners/spec/taskinstance.py
+++ b/airflow/listeners/spec/taskinstance.py
@@ -34,18 +34,18 @@ hookspec = HookspecMarker("airflow")
 def on_task_instance_running(
     previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session | None
 ):
-    """Called when task state changes to RUNNING. Previous_state can be 
State.NONE."""
+    """Called when task state changes to RUNNING. Previous state can be 
None."""
 
 
 @hookspec
 def on_task_instance_success(
     previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session | None
 ):
-    """Called when task state changes to SUCCESS. Previous_state can be 
State.NONE."""
+    """Called when task state changes to SUCCESS. Previous state can be 
None."""
 
 
 @hookspec
 def on_task_instance_failed(
     previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session | None
 ):
-    """Called when task state changes to FAIL. Previous_state can be 
State.NONE."""
+    """Called when task state changes to FAIL. Previous state can be None."""
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index b7361d6f3b..57b15e5f32 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -103,7 +103,7 @@ from airflow.utils.helpers import at_most_one, exactly_one, validate_key
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks
-from airflow.utils.state import DagRunState, State, TaskInstanceState
+from airflow.utils.state import DagRunState, TaskInstanceState
 from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
 
 if TYPE_CHECKING:
@@ -1281,7 +1281,7 @@ class DAG(LoggingMixin):
         TI = TaskInstance
         qry = session.query(func.count(TI.task_id)).filter(
             TI.dag_id == self.dag_id,
-            TI.state == State.RUNNING,
+            TI.state == TaskInstanceState.RUNNING,
         )
         return qry.scalar() >= self.max_active_tasks
 
@@ -1368,7 +1368,7 @@ class DAG(LoggingMixin):
 
         :return: List of execution dates
         """
-        runs = DagRun.find(dag_id=self.dag_id, state=State.RUNNING)
+        runs = DagRun.find(dag_id=self.dag_id, state=DagRunState.RUNNING)
 
         active_dates = []
         for run in runs:
@@ -1388,9 +1388,9 @@ class DAG(LoggingMixin):
         # .count() is inefficient
         query = session.query(func.count()).filter(DagRun.dag_id == self.dag_id)
         if only_running:
-            query = query.filter(DagRun.state == State.RUNNING)
+            query = query.filter(DagRun.state == DagRunState.RUNNING)
         else:
-            query = query.filter(DagRun.state.in_({State.RUNNING, State.QUEUED}))
+            query = query.filter(DagRun.state.in_({DagRunState.RUNNING, DagRunState.QUEUED}))
 
         if external_trigger is not None:
             query = query.filter(
@@ -2077,7 +2077,7 @@ class DAG(LoggingMixin):
     @provide_session
     def set_dag_runs_state(
         self,
-        state: str = State.RUNNING,
+        state: DagRunState = DagRunState.RUNNING,
         session: Session = NEW_SESSION,
         start_date: datetime | None = None,
         end_date: datetime | None = None,
@@ -2160,10 +2160,10 @@ class DAG(LoggingMixin):
 
         state = []
         if only_failed:
-            state += [State.FAILED, State.UPSTREAM_FAILED]
+            state += [TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED]
         if only_running:
             # Yes, having `+=` doesn't make sense, but this was the existing behaviour
-            state += [State.RUNNING]
+            state += [TaskInstanceState.RUNNING]
 
         tis = self._get_task_instances(
             task_ids=task_ids,
@@ -2694,7 +2694,7 @@ class DAG(LoggingMixin):
         # Instead of starting a scheduler, we run the minimal loop possible to check
         # for task readiness and dependency management. This is notably faster
         # than creating a BackfillJob and allows us to surface logs to the user
-        while dr.state == State.RUNNING:
+        while dr.state == DagRunState.RUNNING:
             schedulable_tis, _ = dr.update_state(session=session)
             try:
                 for ti in schedulable_tis:
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 45989be15a..bd94977599 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -111,7 +111,7 @@ class DagRun(Base, LoggingMixin):
     execution_date = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
     start_date = Column(UtcDateTime)
     end_date = Column(UtcDateTime)
-    _state = Column("state", String(50), default=State.QUEUED)
+    _state = Column("state", String(50), default=DagRunState.QUEUED)
     run_id = Column(StringID(), nullable=False)
     creating_job_id = Column(Integer)
     external_trigger = Column(Boolean, default=True)
@@ -218,7 +218,7 @@ class DagRun(Base, LoggingMixin):
         if state is not None:
             self.state = state
         if queued_at is NOTSET:
-            self.queued_at = timezone.utcnow() if state == State.QUEUED else None
+            self.queued_at = timezone.utcnow() if state == DagRunState.QUEUED else None
         else:
             self.queued_at = queued_at
         self.run_type = run_type
@@ -256,7 +256,7 @@ class DagRun(Base, LoggingMixin):
         if self._state != state:
             self._state = state
             self.end_date = timezone.utcnow() if self._state in State.finished else None
-            if state == State.QUEUED:
+            if state == DagRunState.QUEUED:
                 self.queued_at = timezone.utcnow()
 
     @declared_attr
@@ -289,9 +289,9 @@ class DagRun(Base, LoggingMixin):
             # because SQLAlchemy doesn't accept a set here.
             query = query.filter(cls.dag_id.in_(set(dag_ids)))
         if only_running:
-            query = query.filter(cls.state == State.RUNNING)
+            query = query.filter(cls.state == DagRunState.RUNNING)
         else:
-            query = query.filter(cls.state.in_([State.RUNNING, State.QUEUED]))
+            query = query.filter(cls.state.in_((DagRunState.RUNNING, DagRunState.QUEUED)))
         query = query.group_by(cls.dag_id)
         return {dag_id: count for dag_id, count in query.all()}
 
@@ -323,7 +323,7 @@ class DagRun(Base, LoggingMixin):
             .join(DagModel, DagModel.dag_id == cls.dag_id)
             .filter(DagModel.is_paused == false(), DagModel.is_active == true())
         )
-        if state == State.QUEUED:
+        if state == DagRunState.QUEUED:
             # For dag runs in the queued state, we check if they have reached the max_active_runs limit
             # and if so we drop them
             running_drs = (
@@ -462,7 +462,7 @@ class DagRun(Base, LoggingMixin):
                 tis = tis.filter(TI.state == state)
             else:
                 # this is required to deal with NULL values
-                if State.NONE in state:
+                if None in state:
                     if all(x is None for x in state):
                         tis = tis.filter(TI.state.is_(None))
                     else:
@@ -734,9 +734,9 @@ class DagRun(Base, LoggingMixin):
                 try:
                     ti.task = dag.get_task(ti.task_id)
                 except TaskNotFound:
-                    if ti.state != State.REMOVED:
+                    if ti.state != TaskInstanceState.REMOVED:
                         self.log.error("Failed to get task for ti %s. Marking 
it as removed.", ti)
-                        ti.state = State.REMOVED
+                        ti.state = TaskInstanceState.REMOVED
                         session.flush()
                 else:
                     yield ti
@@ -945,7 +945,7 @@ class DagRun(Base, LoggingMixin):
             self.log.warning("Failed to record first_task_scheduling_delay 
metric:", exc_info=True)
 
     def _emit_duration_stats_for_finished_state(self):
-        if self.state == State.RUNNING:
+        if self.state == DagRunState.RUNNING:
             return
         if self.start_date is None:
             self.log.warning("Failed to record duration of %s: start_date is 
not set.", self)
@@ -1017,22 +1017,22 @@ class DagRun(Base, LoggingMixin):
             try:
                 task = dag.get_task(ti.task_id)
 
-                should_restore_task = (task is not None) and ti.state == State.REMOVED
+                should_restore_task = (task is not None) and ti.state == TaskInstanceState.REMOVED
                 if should_restore_task:
                     self.log.info("Restoring task '%s' which was previously 
removed from DAG '%s'", ti, dag)
                     Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 
tags=self.stats_tags)
                     # Same metric with tagging
                     Stats.incr("task_restored_to_dag", 
tags={**self.stats_tags, "dag_id": dag.dag_id})
-                    ti.state = State.NONE
+                    ti.state = None
             except AirflowException:
-                if ti.state == State.REMOVED:
+                if ti.state == TaskInstanceState.REMOVED:
                     pass  # ti has already been removed, just ignore it
-                elif self.state != State.RUNNING and not dag.partial:
+                elif self.state != DagRunState.RUNNING and not dag.partial:
                     self.log.warning("Failed to get task '%s' for dag '%s'. 
Marking it as removed.", ti, dag)
                     Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 
tags=self.stats_tags)
                     # Same metric with tagging
                     Stats.incr("task_removed_from_dag", 
tags={**self.stats_tags, "dag_id": dag.dag_id})
-                    ti.state = State.REMOVED
+                    ti.state = TaskInstanceState.REMOVED
                 continue
 
             try:
@@ -1049,7 +1049,7 @@ class DagRun(Base, LoggingMixin):
                         self.log.debug(
                             "Removing the unmapped TI '%s' as the mapping 
can't be resolved yet", ti
                         )
-                        ti.state = State.REMOVED
+                        ti.state = TaskInstanceState.REMOVED
                     continue
                 # Upstreams finished, check there aren't any extras
                 if ti.map_index >= total_length:
@@ -1058,7 +1058,7 @@ class DagRun(Base, LoggingMixin):
                         ti,
                         total_length,
                     )
-                    ti.state = State.REMOVED
+                    ti.state = TaskInstanceState.REMOVED
             else:
                 # Check if the number of mapped literals has changed, and we need to mark this TI as removed.
                 if ti.map_index >= num_mapped_tis:
@@ -1067,10 +1067,10 @@ class DagRun(Base, LoggingMixin):
                         ti,
                         num_mapped_tis,
                     )
-                    ti.state = State.REMOVED
+                    ti.state = TaskInstanceState.REMOVED
                 elif ti.map_index < 0:
                     self.log.debug("Removing the unmapped TI '%s' as the 
mapping can now be performed", ti)
-                    ti.state = State.REMOVED
+                    ti.state = TaskInstanceState.REMOVED
 
         return task_ids
 
@@ -1342,7 +1342,7 @@ class DagRun(Base, LoggingMixin):
                         TI.run_id == self.run_id,
                         tuple_in_condition((TI.task_id, TI.map_index), schedulable_ti_ids_chunk),
                     )
-                    .update({TI.state: State.SCHEDULED}, synchronize_session=False)
+                    .update({TI.state: TaskInstanceState.SCHEDULED}, synchronize_session=False)
                 )
 
         # Tasks using EmptyOperator should not be executed, mark them as success
@@ -1358,7 +1358,7 @@ class DagRun(Base, LoggingMixin):
                     )
                     .update(
                         {
-                            TI.state: State.SUCCESS,
+                            TI.state: TaskInstanceState.SUCCESS,
                             TI.start_date: timezone.utcnow(),
                             TI.end_date: timezone.utcnow(),
                             TI.duration: 0,
diff --git a/airflow/models/pool.py b/airflow/models/pool.py
index d1766d4a0a..cfad662691 100644
--- a/airflow/models/pool.py
+++ b/airflow/models/pool.py
@@ -28,7 +28,7 @@ from airflow.ti_deps.dependencies_states import EXECUTION_STATES
 from airflow.typing_compat import TypedDict
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import nowait, with_row_locks
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 
 class PoolStats(TypedDict):
@@ -245,7 +245,7 @@ class Pool(Base):
         return int(
             session.query(func.sum(TaskInstance.pool_slots))
             .filter(TaskInstance.pool == self.pool)
-            .filter(TaskInstance.state == State.RUNNING)
+            .filter(TaskInstance.state == TaskInstanceState.RUNNING)
             .scalar()
             or 0
         )
@@ -263,7 +263,7 @@ class Pool(Base):
         return int(
             session.query(func.sum(TaskInstance.pool_slots))
             .filter(TaskInstance.pool == self.pool)
-            .filter(TaskInstance.state == State.QUEUED)
+            .filter(TaskInstance.state == TaskInstanceState.QUEUED)
             .scalar()
             or 0
         )
@@ -281,7 +281,7 @@ class Pool(Base):
         return int(
             session.query(func.sum(TaskInstance.pool_slots))
             .filter(TaskInstance.pool == self.pool)
-            .filter(TaskInstance.state == State.SCHEDULED)
+            .filter(TaskInstance.state == TaskInstanceState.SCHEDULED)
             .scalar()
             or 0
         )
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index d75a4a0e4d..73a733bd03 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -26,7 +26,7 @@ from airflow.serialization.pydantic.dag_run import DagRunPydantic
 from airflow.utils import timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from pendulum import DateTime
@@ -72,7 +72,7 @@ class SkipMixin(LoggingMixin):
             TaskInstance.task_id.in_(d.task_id for d in tasks),
         ).update(
             {
-                TaskInstance.state: State.SKIPPED,
+                TaskInstance.state: TaskInstanceState.SKIPPED,
                 TaskInstance.start_date: now,
                 TaskInstance.end_date: now,
             },
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 28dc168ec3..829c630786 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -591,7 +591,7 @@ class TaskInstance(Base, LoggingMixin):
         database, in all other cases this will be incremented.
         """
         # This is designed so that task logs end up in the right file.
-        if self.state == State.RUNNING:
+        if self.state == TaskInstanceState.RUNNING:
             return self._try_number
         return self._try_number + 1
 
@@ -804,7 +804,7 @@ class TaskInstance(Base, LoggingMixin):
         :param session: SQLAlchemy ORM Session
         """
         self.log.error("Recording the task instance as FAILED")
-        self.state = State.FAILED
+        self.state = TaskInstanceState.FAILED
         session.merge(self)
         session.commit()
 
@@ -931,7 +931,7 @@ class TaskInstance(Base, LoggingMixin):
         self.log.debug("Setting task state for %s to %s", self, state)
         self.state = state
         self.start_date = self.start_date or current_time
-        if self.state in State.finished or self.state == State.UP_FOR_RETRY:
+        if self.state in State.finished or self.state == TaskInstanceState.UP_FOR_RETRY:
             self.end_date = self.end_date or current_time
             self.duration = (self.end_date - self.start_date).total_seconds()
         session.merge(self)
@@ -944,7 +944,7 @@ class TaskInstance(Base, LoggingMixin):
         has elapsed.
         """
         # is the task still in the retry waiting period?
-        return self.state == State.UP_FOR_RETRY and not self.ready_for_retry()
+        return self.state == TaskInstanceState.UP_FOR_RETRY and not self.ready_for_retry()
 
     @provide_session
     def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
@@ -967,7 +967,7 @@ class TaskInstance(Base, LoggingMixin):
             TaskInstance.dag_id == self.dag_id,
             TaskInstance.task_id.in_(task.downstream_task_ids),
             TaskInstance.run_id == self.run_id,
-            TaskInstance.state.in_([State.SKIPPED, State.SUCCESS]),
+            TaskInstance.state.in_((TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)),
         )
         count = ti[0][0]
         return count == len(task.downstream_task_ids)
@@ -1204,7 +1204,7 @@ class TaskInstance(Base, LoggingMixin):
         Checks on whether the task instance is in the right state and timeframe
         to be retried.
         """
-        return self.state == State.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow()
+        return self.state == TaskInstanceState.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow()
 
     @provide_session
     def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
@@ -1273,7 +1273,7 @@ class TaskInstance(Base, LoggingMixin):
         self.hostname = get_hostname()
         self.pid = None
 
-        if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS:
+        if not ignore_all_deps and not ignore_ti_state and self.state == TaskInstanceState.SUCCESS:
             Stats.incr("previously_succeeded", tags=self.stats_tags)
 
         if not mark_success:
@@ -1301,7 +1301,7 @@ class TaskInstance(Base, LoggingMixin):
             # start date that is recorded in task_reschedule table
             # If the task continues after being deferred (next_method is set), use the original start_date
             self.start_date = self.start_date if self.next_method else timezone.utcnow()
-            if self.state == State.UP_FOR_RESCHEDULE:
+            if self.state == TaskInstanceState.UP_FOR_RESCHEDULE:
                 task_reschedule: TR = TR.query_for_task_instance(self, session=session).first()
                 if task_reschedule:
                     self.start_date = task_reschedule.start_date
@@ -1319,7 +1319,7 @@ class TaskInstance(Base, LoggingMixin):
                 description="requeueable deps",
             )
             if not self.are_dependencies_met(dep_context=dep_context, session=session, verbose=True):
-                self.state = State.NONE
+                self.state = None
                 self.log.warning(
                     "Rescheduling due to concurrency limits reached "
                     "at task runtime. Attempt %s of "
@@ -1339,10 +1339,10 @@ class TaskInstance(Base, LoggingMixin):
         self._try_number += 1
 
         if not test_mode:
-            session.add(Log(State.RUNNING, self))
+            session.add(Log(TaskInstanceState.RUNNING, self))
 
-        self.state = State.RUNNING
-        self.emit_state_change_metric(State.RUNNING)
+        self.state = TaskInstanceState.RUNNING
+        self.emit_state_change_metric(TaskInstanceState.RUNNING)
         self.external_executor_id = external_executor_id
         self.end_date = None
         if not test_mode:
@@ -1398,7 +1398,7 @@ class TaskInstance(Base, LoggingMixin):
             return
 
         # switch on state and deduce which metric to send
-        if new_state == State.RUNNING:
+        if new_state == TaskInstanceState.RUNNING:
             metric_name = "queued_duration"
             if self.queued_dttm is None:
                 # this should not really happen except in tests or rare cases,
@@ -1410,7 +1410,7 @@ class TaskInstance(Base, LoggingMixin):
                 )
                 return
             timing = (timezone.utcnow() - self.queued_dttm).total_seconds()
-        elif new_state == State.QUEUED:
+        elif new_state == TaskInstanceState.QUEUED:
             metric_name = "scheduled_duration"
             if self.start_date is None:
                 # same comment as above
@@ -1497,7 +1497,7 @@ class TaskInstance(Base, LoggingMixin):
                 self._execute_task_with_callbacks(context, test_mode)
             if not test_mode:
                 self.refresh_from_db(lock_for_update=True, session=session)
-            self.state = State.SUCCESS
+            self.state = TaskInstanceState.SUCCESS
         except TaskDeferred as defer:
             # The task has signalled it wants to defer execution based on
             # a trigger.
@@ -1521,7 +1521,7 @@ class TaskInstance(Base, LoggingMixin):
                 self.log.info(e)
             if not test_mode:
                 self.refresh_from_db(lock_for_update=True, session=session)
-            self.state = State.SKIPPED
+            self.state = TaskInstanceState.SKIPPED
         except AirflowRescheduleException as reschedule_exception:
             self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)
             session.commit()
@@ -1744,7 +1744,7 @@ class TaskInstance(Base, LoggingMixin):
         # Then, update ourselves so it matches the deferral request
         # Keep an eye on the logic in `check_and_change_state_before_execution()`
         # depending on self.next_method semantics
-        self.state = State.DEFERRED
+        self.state = TaskInstanceState.DEFERRED
         self.trigger_id = trigger_row.id
         self.next_method = defer.method_name
         self.next_kwargs = defer.kwargs or {}
@@ -1863,7 +1863,7 @@ class TaskInstance(Base, LoggingMixin):
         )
 
         # set state
-        self.state = State.UP_FOR_RESCHEDULE
+        self.state = TaskInstanceState.UP_FOR_RESCHEDULE
 
         # Decrement try_number so subsequent runs will use the same try number and write
         # to same log file.
@@ -1928,7 +1928,7 @@ class TaskInstance(Base, LoggingMixin):
         Stats.incr("ti_failures", tags=self.stats_tags)
 
         if not test_mode:
-            session.add(Log(State.FAILED, self))
+            session.add(Log(TaskInstanceState.FAILED, self))
 
             # Log failure duration
             session.add(TaskFail(ti=self))
@@ -1962,7 +1962,7 @@ class TaskInstance(Base, LoggingMixin):
             self.log.error("Unable to unmap task to determine if we need to 
send an alert email")
 
         if force_fail or not self.is_eligible_to_retry():
-            self.state = State.FAILED
+            self.state = TaskInstanceState.FAILED
             email_for_state = operator.attrgetter("email_on_failure")
             callbacks = task.on_failure_callback if task else None
             callback_type = "on_failure"
@@ -1971,10 +1971,10 @@ class TaskInstance(Base, LoggingMixin):
                 tis = self.get_dagrun(session).get_task_instances()
                 stop_all_tasks_in_dag(tis, session, self.task_id)
         else:
-            if self.state == State.QUEUED:
+            if self.state == TaskInstanceState.QUEUED:
                 # We increase the try_number so as to fail the task if it fails to start after sometime
                 self._try_number += 1
-            self.state = State.UP_FOR_RETRY
+            self.state = TaskInstanceState.UP_FOR_RETRY
             email_for_state = operator.attrgetter("email_on_retry")
             callbacks = task.on_retry_callback if task else None
             callback_type = "on_retry"
@@ -1995,7 +1995,7 @@ class TaskInstance(Base, LoggingMixin):
 
     def is_eligible_to_retry(self):
         """Is task instance is eligible for retry."""
-        if self.state == State.RESTARTING:
+        if self.state == TaskInstanceState.RESTARTING:
             # If a task is cleared when running, it goes into RESTARTING state and is always
             # eligible for retry
             return True
@@ -2336,8 +2336,8 @@ class TaskInstance(Base, LoggingMixin):
             'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
         )
 
-        # This function is called after changing the state from State.RUNNING,
-        # so we need to subtract 1 from self.try_number here.
+        # This function is called after changing the state from RUNNING, so we
+        # need to subtract 1 from try_number here.
         current_try_number = self.try_number - 1
         additional_context: dict[str, Any] = {
             "exception": exception,
@@ -2563,7 +2563,7 @@ class TaskInstance(Base, LoggingMixin):
         num_running_task_instances_query = session.query(func.count()).filter(
             TaskInstance.dag_id == self.dag_id,
             TaskInstance.task_id == self.task_id,
-            TaskInstance.state == State.RUNNING,
+            TaskInstance.state == TaskInstanceState.RUNNING,
         )
         if same_dagrun:
             num_running_task_instances_query = num_running_task_instances_query.filter(
@@ -2885,7 +2885,7 @@ def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
 
 # State of the task instance.
 # Stores string version of the task state.
-TaskInstanceStateType = Tuple[TaskInstanceKey, str]
+TaskInstanceStateType = Tuple[TaskInstanceKey, TaskInstanceState]
 
 
 class SimpleTaskInstance:
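
These taskinstance.py substitutions are drop-in replacements: TaskInstanceState is a
str-backed enum whose values are the same strings that State exposed, so comparisons
against values loaded from the database keep working. A minimal sketch (the helper
function below is hypothetical, not part of this commit):

    from __future__ import annotations

    from airflow.utils.state import TaskInstanceState

    def is_running(raw_state: str | None) -> bool:
        # str-enum members compare equal to their plain-string values
        return raw_state == TaskInstanceState.RUNNING

    assert is_running("running")
    assert not is_running(None)  # "no status" is modelled as plain None
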
diff --git a/airflow/operators/subdag.py b/airflow/operators/subdag.py
index de7afcb4ea..af784d5d32 100644
--- a/airflow/operators/subdag.py
+++ b/airflow/operators/subdag.py
@@ -36,7 +36,7 @@ from airflow.models.taskinstance import TaskInstance
 from airflow.sensors.base import BaseSensorOperator
 from airflow.utils.context import Context
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, TaskInstanceState
 from airflow.utils.types import DagRunType
 
 
@@ -137,17 +137,17 @@ class SubDagOperator(BaseSensorOperator):
         :param execution_date: Execution date to select task instances.
         """
         with create_session() as session:
-            dag_run.state = State.RUNNING
+            dag_run.state = DagRunState.RUNNING
             session.merge(dag_run)
             failed_task_instances = (
                 session.query(TaskInstance)
                 .filter(TaskInstance.dag_id == self.subdag.dag_id)
                 .filter(TaskInstance.execution_date == execution_date)
-                .filter(TaskInstance.state.in_([State.FAILED, State.UPSTREAM_FAILED]))
+                .filter(TaskInstance.state.in_((TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED)))
             )
 
             for task_instance in failed_task_instances:
-                task_instance.state = State.NONE
+                task_instance.state = None
                 session.merge(task_instance)
             session.commit()
 
@@ -164,7 +164,7 @@ class SubDagOperator(BaseSensorOperator):
             dag_run = self.subdag.create_dagrun(
                 run_type=DagRunType.SCHEDULED,
                 execution_date=execution_date,
-                state=State.RUNNING,
+                state=DagRunState.RUNNING,
                 conf=self.conf,
                 external_trigger=True,
                 data_interval=data_interval,
@@ -172,13 +172,13 @@ class SubDagOperator(BaseSensorOperator):
             self.log.info("Created DagRun: %s", dag_run.run_id)
         else:
             self.log.info("Found existing DagRun: %s", dag_run.run_id)
-            if dag_run.state == State.FAILED:
+            if dag_run.state == DagRunState.FAILED:
                 self._reset_dag_run_and_task_instances(dag_run, execution_date)
 
     def poke(self, context: Context):
         execution_date = context["execution_date"]
         dag_run = self._get_dagrun(execution_date=execution_date)
-        return dag_run.state != State.RUNNING
+        return dag_run.state != DagRunState.RUNNING
 
     def post_execute(self, context, result=None):
         super().post_execute(context)
@@ -186,7 +186,7 @@ class SubDagOperator(BaseSensorOperator):
         dag_run = self._get_dagrun(execution_date=execution_date)
         self.log.info("Execution finished. State is %s", dag_run.state)
 
-        if dag_run.state != State.SUCCESS:
+        if dag_run.state != DagRunState.SUCCESS:
             raise AirflowException(f"Expected state: SUCCESS. Actual state: 
{dag_run.state}")
 
         if self.propagate_skipped_state and self._check_skipped_states(context):
@@ -196,9 +196,9 @@ class SubDagOperator(BaseSensorOperator):
         leaves_tis = self._get_leaves_tis(context["execution_date"])
 
         if self.propagate_skipped_state == SkippedStatePropagationOptions.ANY_LEAF:
-            return any(ti.state == State.SKIPPED for ti in leaves_tis)
+            return any(ti.state == TaskInstanceState.SKIPPED for ti in leaves_tis)
         if self.propagate_skipped_state == SkippedStatePropagationOptions.ALL_LEAVES:
-            return all(ti.state == State.SKIPPED for ti in leaves_tis)
+            return all(ti.state == TaskInstanceState.SKIPPED for ti in leaves_tis)
         raise AirflowException(
             f"Unimplemented SkippedStatePropagationOptions 
{self.propagate_skipped_state} used."
         )
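
The ``State.NONE`` -> ``None`` swaps above (and the matching ones in taskinstance.py)
do not change behaviour: ``State.NONE`` has always been an alias for ``None``, so
assigning ``None`` directly is equivalent. A quick sanity-check sketch:

    from airflow.utils.state import State

    assert State.NONE is None
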
diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py
index f3560bffc7..a3a6bf7c7c 100644
--- a/airflow/operators/trigger_dagrun.py
+++ b/airflow/operators/trigger_dagrun.py
@@ -36,7 +36,7 @@ from airflow.utils import timezone
 from airflow.utils.context import Context
 from airflow.utils.helpers import build_airflow_url_with_query
 from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunType
 
 XCOM_EXECUTION_DATE_ISO = "trigger_execution_date_iso"
@@ -116,8 +116,8 @@ class TriggerDagRunOperator(BaseOperator):
         self.reset_dag_run = reset_dag_run
         self.wait_for_completion = wait_for_completion
         self.poke_interval = poke_interval
-        self.allowed_states = allowed_states or [State.SUCCESS]
-        self.failed_states = failed_states or [State.FAILED]
+        self.allowed_states = allowed_states or [DagRunState.SUCCESS]
+        self.failed_states = failed_states or [DagRunState.FAILED]
         self._defer = deferrable
 
         if execution_date is not None and not isinstance(execution_date, (str, datetime.datetime)):
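
The new defaults are equivalent to passing the enum members explicitly. A usage sketch
with hypothetical DAG and task ids:

    from airflow.operators.trigger_dagrun import TriggerDagRunOperator
    from airflow.utils.state import DagRunState

    trigger = TriggerDagRunOperator(
        task_id="trigger_downstream",          # hypothetical task id
        trigger_dag_id="downstream_dag",       # hypothetical DAG id
        wait_for_completion=True,
        allowed_states=[DagRunState.SUCCESS],  # same as the new default
        failed_states=[DagRunState.FAILED],    # same as the new default
    )
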
diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py
index 959ebe5131..158032a2cc 100644
--- a/airflow/sensors/external_task.py
+++ b/airflow/sensors/external_task.py
@@ -37,7 +37,7 @@ from airflow.utils.file import correct_maybe_zipped
 from airflow.utils.helpers import build_airflow_url_with_query
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import tuple_in_condition
-from airflow.utils.state import State
+from airflow.utils.state import State, TaskInstanceState
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Query, Session
@@ -76,23 +76,25 @@ class ExternalTaskSensor(BaseSensorOperator):
     without also having to clear the sensor).
 
     By default, the ExternalTaskSensor will not skip if the external task skips.
-    To change this, simply set ``skipped_states=[State.SKIPPED]``. Note that if
-    you are monitoring multiple tasks, and one enters error state and the other
-    enters a skipped state, then the external task will react to whichever one
-    it sees first. If both happen together, then the failed state takes priority.
+    To change this, simply set ``skipped_states=[TaskInstanceState.SKIPPED]``.
+    Note that if you are monitoring multiple tasks, and one enters error state
+    and the other enters a skipped state, then the external task will react to
+    whichever one it sees first. If both happen together, then the failed state
+    takes priority.
 
     It is possible to alter the default behavior by setting states which
-    cause the sensor to fail, e.g. by setting ``allowed_states=[State.FAILED]``
-    and ``failed_states=[State.SUCCESS]`` you will flip the behaviour to get a
-    sensor which goes green when the external task *fails* and immediately goes
-    red if the external task *succeeds*!
+    cause the sensor to fail, e.g. by setting
+    ``allowed_states=[TaskInstanceState.FAILED]`` and
+    ``failed_states=[TaskInstanceState.SUCCESS]``, you will flip the behaviour
+    to get a sensor which goes green when the external task *fails* and
+    immediately goes red if the external task *succeeds*!
 
     Note that ``soft_fail`` is respected when examining the failed_states. Thus
     if the external task enters a failed state and ``soft_fail == True`` the
     sensor will _skip_ rather than fail. As a result, setting ``soft_fail=True``
-    and ``failed_states=[State.SKIPPED]`` will result in the sensor skipping if
-    the external task skips. However, this is a contrived example - consider
-    using ``skipped_states`` if you would like this behaviour. Using
+    and ``failed_states=[TaskInstanceState.SKIPPED]`` will result in the sensor
+    skipping if the external task skips. However, this is a contrived example;
+    consider using ``skipped_states`` if you would like this behaviour. Using
     ``skipped_states`` allows the sensor to skip if the target fails, but still
     enter failed state on timeout. Using ``soft_fail == True`` as above will
     cause the sensor to skip if the target fails, but also if it times out.
@@ -146,7 +148,7 @@ class ExternalTaskSensor(BaseSensorOperator):
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.allowed_states = list(allowed_states) if allowed_states else [State.SUCCESS]
+        self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS]
         self.skipped_states = list(skipped_states) if skipped_states else []
         self.failed_states = list(failed_states) if failed_states else []
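
Mirroring the updated docstring, a sensor configured with the strong-typed states might
look like this (the DAG and task ids are hypothetical):

    from airflow.sensors.external_task import ExternalTaskSensor
    from airflow.utils.state import TaskInstanceState

    wait_for_upstream = ExternalTaskSensor(
        task_id="wait_for_upstream",
        external_dag_id="upstream_dag",
        external_task_id="produce_data",
        allowed_states=[TaskInstanceState.SUCCESS],
        skipped_states=[TaskInstanceState.SKIPPED],
        failed_states=[TaskInstanceState.FAILED],
    )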
 
diff --git a/airflow/sentry.py b/airflow/sentry.py
index 8742552e7e..fbc1715eed 100644
--- a/airflow/sentry.py
+++ b/airflow/sentry.py
@@ -25,7 +25,7 @@ from typing import TYPE_CHECKING
 from airflow.configuration import conf
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.utils.session import find_session_idx, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -143,7 +143,7 @@ if conf.getboolean("sentry", "sentry_on", fallback=False):
                 return
             dr = task_instance.get_dagrun(session)
             task_instances = dr.get_task_instances(
-                state={State.SUCCESS, State.FAILED},
+                state={TaskInstanceState.SUCCESS, TaskInstanceState.FAILED},
                 session=session,
             )
 
diff --git a/airflow/ti_deps/dependencies_states.py b/airflow/ti_deps/dependencies_states.py
index 543ce3528a..fd25d62f6d 100644
--- a/airflow/ti_deps/dependencies_states.py
+++ b/airflow/ti_deps/dependencies_states.py
@@ -16,38 +16,38 @@
 # under the License.
 from __future__ import annotations
 
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 EXECUTION_STATES = {
-    State.RUNNING,
-    State.QUEUED,
+    TaskInstanceState.RUNNING,
+    TaskInstanceState.QUEUED,
 }
 
 # In order to be able to get queued a task must have one of these states
 SCHEDULEABLE_STATES = {
-    State.NONE,
-    State.UP_FOR_RETRY,
-    State.UP_FOR_RESCHEDULE,
+    None,
+    TaskInstanceState.UP_FOR_RETRY,
+    TaskInstanceState.UP_FOR_RESCHEDULE,
 }
 
 RUNNABLE_STATES = {
     # For cases like unit tests and run manually
-    State.NONE,
-    State.UP_FOR_RETRY,
-    State.UP_FOR_RESCHEDULE,
+    None,
+    TaskInstanceState.UP_FOR_RETRY,
+    TaskInstanceState.UP_FOR_RESCHEDULE,
     # For normal scheduler/backfill cases
-    State.QUEUED,
+    TaskInstanceState.QUEUED,
 }
 
 QUEUEABLE_STATES = {
-    State.SCHEDULED,
+    TaskInstanceState.SCHEDULED,
 }
 
 BACKFILL_QUEUEABLE_STATES = {
     # For cases like unit tests and run manually
-    State.NONE,
-    State.UP_FOR_RESCHEDULE,
-    State.UP_FOR_RETRY,
+    None,
+    TaskInstanceState.UP_FOR_RESCHEDULE,
+    TaskInstanceState.UP_FOR_RETRY,
     # For normal backfill cases
-    State.SCHEDULED,
+    TaskInstanceState.SCHEDULED,
 }
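
Because ``None`` now stands in for the old ``State.NONE`` in these sets, membership
checks still treat a task instance with no state as runnable. A quick illustrative
check against the sets defined above:

    from airflow.ti_deps.dependencies_states import RUNNABLE_STATES
    from airflow.utils.state import TaskInstanceState

    assert None in RUNNABLE_STATES
    assert TaskInstanceState.UP_FOR_RETRY in RUNNABLE_STATES
    assert TaskInstanceState.SUCCESS not in RUNNABLE_STATES
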
diff --git a/airflow/ti_deps/deps/dagrun_exists_dep.py b/airflow/ti_deps/deps/dagrun_exists_dep.py
index 781ab0ebaf..0a364628c7 100644
--- a/airflow/ti_deps/deps/dagrun_exists_dep.py
+++ b/airflow/ti_deps/deps/dagrun_exists_dep.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState
 
 
 class DagrunRunningDep(BaseTIDep):
@@ -31,7 +31,7 @@ class DagrunRunningDep(BaseTIDep):
     @provide_session
     def _get_dep_statuses(self, ti, session, dep_context):
         dr = ti.get_dagrun(session)
-        if dr.state != State.RUNNING:
+        if dr.state != DagRunState.RUNNING:
             yield self._failing_status(
                 reason=f"Task instance's dagrun was not in the 'running' state 
but in the state '{dr.state}'."
             )
diff --git a/airflow/ti_deps/deps/not_in_retry_period_dep.py b/airflow/ti_deps/deps/not_in_retry_period_dep.py
index b3b5d4ec56..90954f29f2 100644
--- a/airflow/ti_deps/deps/not_in_retry_period_dep.py
+++ b/airflow/ti_deps/deps/not_in_retry_period_dep.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils import timezone
 from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 
 class NotInRetryPeriodDep(BaseTIDep):
@@ -38,7 +38,7 @@ class NotInRetryPeriodDep(BaseTIDep):
             )
             return
 
-        if ti.state != State.UP_FOR_RETRY:
+        if ti.state != TaskInstanceState.UP_FOR_RETRY:
             yield self._passing_status(reason="The task instance was not 
marked for retrying.")
             return
 
diff --git a/airflow/ti_deps/deps/not_previously_skipped_dep.py b/airflow/ti_deps/deps/not_previously_skipped_dep.py
index fdc8274d90..855f04af53 100644
--- a/airflow/ti_deps/deps/not_previously_skipped_dep.py
+++ b/airflow/ti_deps/deps/not_previously_skipped_dep.py
@@ -40,7 +40,7 @@ class NotPreviouslySkippedDep(BaseTIDep):
             XCOM_SKIPMIXIN_SKIPPED,
             SkipMixin,
         )
-        from airflow.utils.state import State
+        from airflow.utils.state import TaskInstanceState
 
         upstream = ti.task.get_direct_relatives(upstream=True)
 
@@ -87,7 +87,7 @@ class NotPreviouslySkippedDep(BaseTIDep):
                                 reason=("Task should be skipped but the the 
past depends are not met")
                             )
                             return
-                    ti.set_state(State.SKIPPED, session)
+                    ti.set_state(TaskInstanceState.SKIPPED, session)
                     yield self._failing_status(
                         reason=f"Skipping because of previous XCom result from 
parent task {parent.task_id}"
                     )
diff --git a/airflow/ti_deps/deps/prev_dagrun_dep.py b/airflow/ti_deps/deps/prev_dagrun_dep.py
index b3165eb7f2..62acdbca33 100644
--- a/airflow/ti_deps/deps/prev_dagrun_dep.py
+++ b/airflow/ti_deps/deps/prev_dagrun_dep.py
@@ -22,7 +22,7 @@ from sqlalchemy import func
 from airflow.models.taskinstance import PAST_DEPENDS_MET, TaskInstance as TI
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 
 class PrevDagrunDep(BaseTIDep):
@@ -107,7 +107,7 @@ class PrevDagrunDep(BaseTIDep):
             )
             return
 
-        if previous_ti.state not in {State.SKIPPED, State.SUCCESS}:
+        if previous_ti.state not in {TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS}:
             yield self._failing_status(
                 reason=(
                     f"depends_on_past is true for this task, but the previous 
task instance {previous_ti} "
diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py b/airflow/ti_deps/deps/ready_to_reschedule.py
index 66aa5c5613..8394907081 100644
--- a/airflow/ti_deps/deps/ready_to_reschedule.py
+++ b/airflow/ti_deps/deps/ready_to_reschedule.py
@@ -22,7 +22,7 @@ from airflow.models.taskreschedule import TaskReschedule
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils import timezone
 from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 
 class ReadyToRescheduleDep(BaseTIDep):
@@ -31,7 +31,7 @@ class ReadyToRescheduleDep(BaseTIDep):
     NAME = "Ready To Reschedule"
     IGNORABLE = True
     IS_TASK_DEP = True
-    RESCHEDULEABLE_STATES = {State.UP_FOR_RESCHEDULE, State.NONE}
+    RESCHEDULEABLE_STATES = {None, TaskInstanceState.UP_FOR_RESCHEDULE}
 
     @provide_session
     def _get_dep_statuses(self, ti, session, dep_context):
diff --git a/airflow/ti_deps/deps/task_not_running_dep.py b/airflow/ti_deps/deps/task_not_running_dep.py
index fd76873466..7299a4f3e2 100644
--- a/airflow/ti_deps/deps/task_not_running_dep.py
+++ b/airflow/ti_deps/deps/task_not_running_dep.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 
 class TaskNotRunningDep(BaseTIDep):
@@ -37,7 +37,7 @@ class TaskNotRunningDep(BaseTIDep):
 
     @provide_session
     def _get_dep_statuses(self, ti, session, dep_context=None):
-        if ti.state != State.RUNNING:
+        if ti.state != TaskInstanceState.RUNNING:
             yield self._passing_status(reason="Task is not in running state.")
             return
 
diff --git a/airflow/utils/dot_renderer.py b/airflow/utils/dot_renderer.py
index 3f35d1b21d..d9c2acce80 100644
--- a/airflow/utils/dot_renderer.py
+++ b/airflow/utils/dot_renderer.py
@@ -58,7 +58,7 @@ def _draw_task(
 ) -> None:
     """Draw a single task on the given parent_graph."""
     if states_by_task_id:
-        state = states_by_task_id.get(task.task_id, State.NONE)
+        state = states_by_task_id.get(task.task_id, None)
         color = State.color_fg(state)
         fill_color = State.color(state)
     else:
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index 0c98d25503..9184a20420 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -341,7 +341,10 @@ class FileTaskHandler(logging.Handler):
         )
         log_pos = len(logs)
         messages = "".join([f"*** {x}\n" for x in messages_list])
-        end_of_log = ti.try_number != try_number or ti.state not in [State.RUNNING, State.DEFERRED]
+        end_of_log = ti.try_number != try_number or ti.state not in (
+            TaskInstanceState.RUNNING,
+            TaskInstanceState.DEFERRED,
+        )
         if metadata and "log_pos" in metadata:
             previous_chars = metadata["log_pos"]
             logs = logs[previous_chars:]  # Cut off previously passed log test as new tail
diff --git a/airflow/utils/log/log_reader.py b/airflow/utils/log/log_reader.py
index a4589ebca0..e8c3b9897a 100644
--- a/airflow/utils/log/log_reader.py
+++ b/airflow/utils/log/log_reader.py
@@ -28,7 +28,7 @@ from airflow.models.taskinstance import TaskInstance
 from airflow.utils.helpers import render_log_filename
 from airflow.utils.log.logging_mixin import ExternalLoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 
 class TaskLogReader:
@@ -81,19 +81,26 @@ class TaskLogReader:
             metadata.pop("max_offset", None)
             metadata.pop("offset", None)
             metadata.pop("log_pos", None)
+
             while True:
                 logs, metadata = self.read_log_chunks(ti, current_try_number, metadata)
                 for host, log in logs[0]:
                     yield "\n".join([host or "", log]) + "\n"
-                if "end_of_log" not in metadata or (
-                    not metadata["end_of_log"] and ti.state not in 
[State.RUNNING, State.DEFERRED]
-                ):
-                    if not logs[0]:
-                        # we did not receive any logs in this loop
-                        # sleeping to conserve resources / limit requests on external services
-                        time.sleep(self.STREAM_LOOP_SLEEP_SECONDS)
+                try:
+                    end_of_log = bool(metadata["end_of_log"])
+                except KeyError:
+                    continue_read = True
                 else:
+                    continue_read = not end_of_log and ti.state not in (
+                        TaskInstanceState.RUNNING,
+                        TaskInstanceState.DEFERRED,
+                    )
+                if not continue_read:
                     break
+                if not logs[0]:
+                    # we did not receive any logs in this loop
+                    # sleeping to conserve resources / limit requests on external services
+                    time.sleep(self.STREAM_LOOP_SLEEP_SECONDS)
 
     @cached_property
     def log_handler(self):
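
The restructured loop is easier to follow when the stop/continue decision is read on
its own. The helper below is purely illustrative (it does not exist in the code base)
and simply restates the new condition:

    from airflow.utils.state import TaskInstanceState

    def _should_keep_streaming(metadata: dict, ti_state) -> bool:
        # Missing "end_of_log" means the handler gave no verdict yet: keep polling.
        try:
            end_of_log = bool(metadata["end_of_log"])
        except KeyError:
            return True
        # Otherwise keep polling only while the log is not finished and the task
        # is not in RUNNING or DEFERRED, matching the rewritten loop above.
        return not end_of_log and ti_state not in (
            TaskInstanceState.RUNNING,
            TaskInstanceState.DEFERRED,
        )
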
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index f4a8dc1a0a..b6297dfb71 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -98,14 +98,9 @@ class State:
     finished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.SUCCESS, DagRunState.FAILED])
     unfinished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.QUEUED, DagRunState.RUNNING])
 
-    task_states: tuple[TaskInstanceState | None, ...] = (None,) + tuple(TaskInstanceState)
+    task_states: tuple[TaskInstanceState | None, ...] = (None, *TaskInstanceState)
 
-    dag_states: tuple[DagRunState, ...] = (
-        DagRunState.QUEUED,
-        DagRunState.SUCCESS,
-        DagRunState.RUNNING,
-        DagRunState.FAILED,
-    )
+    dag_states: tuple[DagRunState, ...] = tuple(DagRunState)
 
     state_color: dict[TaskInstanceState | None, str] = {
         None: "lightblue",
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 25fc1a28f9..58836beefe 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -85,8 +85,10 @@ def get_instance_with_map(task_instance, session):
     return get_mapped_summary(task_instance, mapped_instances)
 
 
-def get_try_count(try_number: int, state: State):
-    return try_number + 1 if state in [State.DEFERRED, State.UP_FOR_RESCHEDULE] else try_number
+def get_try_count(try_number: int, state: TaskInstanceState) -> int:
+    if state in (TaskInstanceState.DEFERRED, TaskInstanceState.UP_FOR_RESCHEDULE):
+        return try_number + 1
+    return try_number
 
 
 priority: list[None | TaskInstanceState] = [
@@ -426,7 +428,7 @@ def task_instance_link(attr):
 
 
 def state_token(state):
-    """Returns a formatted string with HTML for a given State."""
+    """Returns a formatted string with HTML for a given state."""
     color = State.color(state)
     fg_color = State.color_fg(state)
     return Markup(
@@ -438,7 +440,7 @@ def state_token(state):
 
 
 def state_f(attr):
-    """Gets 'state' & returns a formatted string with HTML for a given 
State."""
+    """Gets 'state' & returns a formatted string with HTML for a given 
state."""
     state = attr.get("state")
     return state_token(state)
 
diff --git a/airflow/www/views.py b/airflow/www/views.py
index f50df518fe..3da11f5681 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -754,7 +754,7 @@ class Airflow(AirflowBaseView):
 
             # find DAGs which have a RUNNING DagRun
             running_dags = dags_query.join(DagRun, DagModel.dag_id == DagRun.dag_id).filter(
-                DagRun.state == State.RUNNING
+                DagRun.state == DagRunState.RUNNING
             )
 
             # find DAGs for which the latest DagRun is FAILED
@@ -765,7 +765,7 @@ class Airflow(AirflowBaseView):
             )
             subq_failed = (
                 session.query(DagRun.dag_id, func.max(DagRun.start_date).label("start_date"))
-                .filter(DagRun.state == State.FAILED)
+                .filter(DagRun.state == DagRunState.FAILED)
                 .group_by(DagRun.dag_id)
                 .subquery()
             )
@@ -1101,7 +1101,7 @@ class Airflow(AirflowBaseView):
         running_dag_run_query_result = (
             session.query(DagRun.dag_id, DagRun.run_id)
             .join(DagModel, DagModel.dag_id == DagRun.dag_id)
-            .filter(DagRun.state == State.RUNNING, DagModel.is_active)
+            .filter(DagRun.state == DagRunState.RUNNING, DagModel.is_active)
         )
 
         running_dag_run_query_result = running_dag_run_query_result.filter(DagRun.dag_id.in_(filter_dag_ids))
@@ -1125,7 +1125,7 @@ class Airflow(AirflowBaseView):
             last_dag_run = (
                 session.query(DagRun.dag_id, sqla.func.max(DagRun.execution_date).label("execution_date"))
                 .join(DagModel, DagModel.dag_id == DagRun.dag_id)
-                .filter(DagRun.state != State.RUNNING, DagModel.is_active)
+                .filter(DagRun.state != DagRunState.RUNNING, DagModel.is_active)
                 .group_by(DagRun.dag_id)
             )
 
@@ -1820,7 +1820,7 @@ class Airflow(AirflowBaseView):
                 "Airflow administrator for assistance.".format(
                     "- This task instance already ran and had it's state 
changed manually "
                     "(e.g. cleared in the UI)<br>"
-                    if ti and ti.state == State.NONE
+                    if ti and ti.state is None
                     else ""
                 ),
             )
@@ -2119,7 +2119,7 @@ class Airflow(AirflowBaseView):
                 run_type=DagRunType.MANUAL,
                 execution_date=execution_date,
                 data_interval=dag.timetable.infer_manual_data_interval(run_after=execution_date),
-                state=State.QUEUED,
+                state=DagRunState.QUEUED,
                 conf=run_conf,
                 external_trigger=True,
                 dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id),
@@ -5337,14 +5337,14 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
     @action_logging
     def action_set_queued(self, drs: list[DagRun]):
         """Set state to queued."""
-        return self._set_dag_runs_to_active_state(drs, State.QUEUED)
+        return self._set_dag_runs_to_active_state(drs, DagRunState.QUEUED)
 
     @action("set_running", "Set state to 'running'", "", single=False)
     @action_has_dag_edit_access
     @action_logging
     def action_set_running(self, drs: list[DagRun]):
         """Set state to running."""
-        return self._set_dag_runs_to_active_state(drs, State.RUNNING)
+        return self._set_dag_runs_to_active_state(drs, DagRunState.RUNNING)
 
     @provide_session
     def _set_dag_runs_to_active_state(self, drs: list[DagRun], state: str, session: Session = NEW_SESSION):
@@ -5353,7 +5353,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
             count = 0
             for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for dagrun in drs)):
                 count += 1
-                if state == State.RUNNING:
+                if state == DagRunState.RUNNING:
                     dr.start_date = timezone.utcnow()
                 dr.state = state
             session.commit()
@@ -5790,7 +5790,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
     @action_logging
     def action_set_running(self, tis):
         """Set state to 'running'."""
-        self.set_task_instance_state(tis, State.RUNNING)
+        self.set_task_instance_state(tis, TaskInstanceState.RUNNING)
         self.update_redirect()
         return redirect(self.get_redirect())
 
@@ -5799,7 +5799,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
     @action_logging
     def action_set_failed(self, tis):
         """Set state to 'failed'."""
-        self.set_task_instance_state(tis, State.FAILED)
+        self.set_task_instance_state(tis, TaskInstanceState.FAILED)
         self.update_redirect()
         return redirect(self.get_redirect())
 
@@ -5808,7 +5808,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
     @action_logging
     def action_set_success(self, tis):
         """Set state to 'success'."""
-        self.set_task_instance_state(tis, State.SUCCESS)
+        self.set_task_instance_state(tis, TaskInstanceState.SUCCESS)
         self.update_redirect()
         return redirect(self.get_redirect())
 
@@ -5817,7 +5817,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
     @action_logging
     def action_set_retry(self, tis):
         """Set state to 'up_for_retry'."""
-        self.set_task_instance_state(tis, State.UP_FOR_RETRY)
+        self.set_task_instance_state(tis, TaskInstanceState.UP_FOR_RETRY)
         self.update_redirect()
         return redirect(self.get_redirect())
 

