This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fa3be084a1 More strong typed state conversion (#32521)
fa3be084a1 is described below

commit fa3be084a1cd84242893a3367ac2c0c4d3a4f480
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Jul 12 05:54:36 2023 +0800

    More strong typed state conversion (#32521)
---
 airflow/api/common/delete_dag.py                   |  4 +-
 airflow/api/common/mark_tasks.py                   | 35 +++++++++++----
 .../endpoints/task_instance_endpoint.py            | 12 +++---
 airflow/dag_processing/processor.py                | 10 +++--
 airflow/executors/local_executor.py                | 16 +++----
 airflow/jobs/backfill_job_runner.py                |  6 +--
 airflow/jobs/local_task_job_runner.py              |  6 +--
 airflow/jobs/scheduler_job_runner.py               | 18 ++++----
 airflow/listeners/spec/taskinstance.py             | 12 +++---
 airflow/models/dag.py                              | 14 +++---
 airflow/models/dagrun.py                           | 42 +++++++++---------
 airflow/models/pool.py                             |  8 ++--
 airflow/models/skipmixin.py                        |  4 +-
 airflow/models/taskinstance.py                     | 50 +++++++++++-----------
 airflow/operators/subdag.py                        | 20 ++++-----
 airflow/sentry.py                                  | 21 +++++----
 airflow/ti_deps/deps/dagrun_exists_dep.py          |  4 +-
 airflow/ti_deps/deps/not_in_retry_period_dep.py    |  4 +-
 airflow/ti_deps/deps/not_previously_skipped_dep.py |  4 +-
 airflow/ti_deps/deps/prev_dagrun_dep.py            |  4 +-
 airflow/ti_deps/deps/ready_to_reschedule.py        |  4 +-
 airflow/ti_deps/deps/task_not_running_dep.py       |  4 +-
 airflow/utils/dot_renderer.py                      |  2 +-
 airflow/utils/log/file_task_handler.py             |  5 ++-
 airflow/utils/log/log_reader.py                    |  5 ++-
 airflow/utils/state.py                             |  6 +--
 airflow/www/utils.py                               |  4 +-
 airflow/www/views.py                               | 40 ++++++++++-------
 dev/perf/scheduler_dag_execution_timing.py         |  8 ++--
 29 files changed, 204 insertions(+), 168 deletions(-)

diff --git a/airflow/api/common/delete_dag.py b/airflow/api/common/delete_dag.py
index 45611729ea..1d879a667a 100644
--- a/airflow/api/common/delete_dag.py
+++ b/airflow/api/common/delete_dag.py
@@ -29,7 +29,7 @@ from airflow.models import DagModel, TaskFail
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.utils.db import get_sqla_model_classes
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 log = logging.getLogger(__name__)
 
@@ -50,7 +50,7 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, 
session: Session =
     running_tis = session.scalar(
         select(models.TaskInstance.state)
         .where(models.TaskInstance.dag_id == dag_id)
-        .where(models.TaskInstance.state == State.RUNNING)
+        .where(models.TaskInstance.state == TaskInstanceState.RUNNING)
         .limit(1)
     )
     if running_tis:
diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index b965237bdf..4d2df78e82 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -155,7 +155,7 @@ def set_state(
         for task_instance in tis_altered:
             # The try_number was decremented when setting to up_for_reschedule 
and deferred.
             # Increment it back when changing the state again
-            if task_instance.state in [State.DEFERRED, 
State.UP_FOR_RESCHEDULE]:
+            if task_instance.state in (TaskInstanceState.DEFERRED, 
TaskInstanceState.UP_FOR_RESCHEDULE):
                 task_instance._try_number += 1
             task_instance.set_state(state, session=session)
         session.flush()
@@ -362,7 +362,7 @@ def _set_dag_run_state(dag_id: str, run_id: str, state: 
DagRunState, session: SA
         select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
     ).scalar_one()
     dag_run.state = state
-    if state == State.RUNNING:
+    if state == DagRunState.RUNNING:
         dag_run.start_date = timezone.utcnow()
         dag_run.end_date = None
     else:
@@ -415,7 +415,13 @@ def set_dag_run_state_to_success(
     # Mark all task instances of the dag run to success.
     for task in dag.tasks:
         task.dag = dag
-    return set_state(tasks=dag.tasks, run_id=run_id, state=State.SUCCESS, 
commit=commit, session=session)
+    return set_state(
+        tasks=dag.tasks,
+        run_id=run_id,
+        state=TaskInstanceState.SUCCESS,
+        commit=commit,
+        session=session,
+    )
 
 
 @provide_session
@@ -461,6 +467,12 @@ def set_dag_run_state_to_failed(
     if commit:
         _set_dag_run_state(dag.dag_id, run_id, DagRunState.FAILED, session)
 
+    running_states = (
+        TaskInstanceState.RUNNING,
+        TaskInstanceState.DEFERRED,
+        TaskInstanceState.UP_FOR_RESCHEDULE,
+    )
+
     # Mark only RUNNING task instances.
     task_ids = [task.task_id for task in dag.tasks]
     tis = session.scalars(
@@ -468,7 +480,7 @@ def set_dag_run_state_to_failed(
             TaskInstance.dag_id == dag.dag_id,
             TaskInstance.run_id == run_id,
             TaskInstance.task_id.in_(task_ids),
-            TaskInstance.state.in_([State.RUNNING, State.DEFERRED, 
State.UP_FOR_RESCHEDULE]),
+            TaskInstance.state.in_(running_states),
         )
     )
 
@@ -487,16 +499,21 @@ def set_dag_run_state_to_failed(
             TaskInstance.dag_id == dag.dag_id,
             TaskInstance.run_id == run_id,
             TaskInstance.state.not_in(State.finished),
-            TaskInstance.state.not_in([State.RUNNING, State.DEFERRED, 
State.UP_FOR_RESCHEDULE]),
+            TaskInstance.state.not_in(running_states),
         )
-    )
+    ).all()
 
-    tis = [ti for ti in tis]
     if commit:
         for ti in tis:
-            ti.set_state(State.SKIPPED)
+            ti.set_state(TaskInstanceState.SKIPPED)
 
-    return tis + set_state(tasks=tasks, run_id=run_id, state=State.FAILED, 
commit=commit, session=session)
+    return tis + set_state(
+        tasks=tasks,
+        run_id=run_id,
+        state=TaskInstanceState.FAILED,
+        commit=commit,
+        session=session,
+    )
 
 
 def __set_dag_run_state_to_running_or_queued(
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py 
b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index 3028b0bb73..2496433a26 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -49,7 +49,7 @@ from airflow.models.taskinstance import TaskInstance as TI, 
clear_task_instances
 from airflow.security import permissions
 from airflow.utils.airflow_flask_app import get_airflow_app
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import DagRunState, State
+from airflow.utils.state import DagRunState, TaskInstanceState
 
 T = TypeVar("T")
 
@@ -187,7 +187,7 @@ def get_mapped_task_instances(
 ) -> APIResponse:
     """Get list of task instances."""
     # Because state can be 'none'
-    states = _convert_state(state)
+    states = _convert_ti_states(state)
 
     base_query = (
         select(TI)
@@ -264,10 +264,10 @@ def get_mapped_task_instances(
     )
 
 
-def _convert_state(states: Iterable[str] | None) -> list[str | None] | None:
+def _convert_ti_states(states: Iterable[str] | None) -> list[TaskInstanceState 
| None] | None:
     if not states:
         return None
-    return [State.NONE if s == "none" else s for s in states]
+    return [None if s == "none" else TaskInstanceState(s) for s in states]
 
 
 def _apply_array_filter(query: Select, key: ClauseElement, values: 
Iterable[Any] | None) -> Select:
@@ -329,7 +329,7 @@ def get_task_instances(
 ) -> APIResponse:
     """Get list of task instances."""
     # Because state can be 'none'
-    states = _convert_state(state)
+    states = _convert_ti_states(state)
 
     base_query = select(TI).join(TI.dag_run)
 
@@ -395,7 +395,7 @@ def get_task_instances_batch(session: Session = 
NEW_SESSION) -> APIResponse:
         data = task_instance_batch_form.load(body)
     except ValidationError as err:
         raise BadRequest(detail=str(err.messages))
-    states = _convert_state(data["state"])
+    states = _convert_ti_states(data["state"])
     base_query = select(TI).join(TI.dag_run)
 
     base_query = _apply_array_filter(base_query, key=TI.dag_id, 
values=data["dag_ids"])
diff --git a/airflow/dag_processing/processor.py 
b/airflow/dag_processing/processor.py
index 5c175571c0..369f676878 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -56,7 +56,7 @@ from airflow.utils.file import iter_airflow_imports, 
might_contain_dag
 from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, 
set_context
 from airflow.utils.mixins import MultiprocessingStartMethodMixin
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.models.operator import Operator
@@ -433,7 +433,7 @@ class DagFileProcessor(LoggingMixin):
             session.query(TI.task_id, 
func.max(DR.execution_date).label("max_ti"))
             .join(TI.dag_run)
             .filter(TI.dag_id == dag.dag_id)
-            .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED))
+            .filter(or_(TI.state == TaskInstanceState.SUCCESS, TI.state == 
TaskInstanceState.SKIPPED))
             .filter(TI.task_id.in_(dag.task_ids))
             .group_by(TI.task_id)
             .subquery("sq")
@@ -500,7 +500,11 @@ class DagFileProcessor(LoggingMixin):
             sla_dates: list[datetime] = [sla.execution_date for sla in slas]
             fetched_tis: list[TI] = (
                 session.query(TI)
-                .filter(TI.state != State.SUCCESS, 
TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id)
+                .filter(
+                    TI.dag_id == dag.dag_id,
+                    TI.execution_date.in_(sla_dates),
+                    TI.state != TaskInstanceState.SUCCESS,
+                )
                 .all()
             )
             blocking_tis: list[TI] = []
diff --git a/airflow/executors/local_executor.py 
b/airflow/executors/local_executor.py
index ca54a387c8..7f83f8c7a2 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -39,7 +39,7 @@ from airflow import settings
 from airflow.exceptions import AirflowException
 from airflow.executors.base_executor import PARALLELISM, BaseExecutor
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.executors.base_executor import CommandType
@@ -94,20 +94,20 @@ class LocalWorkerBase(Process, LoggingMixin):
         # Remove the command since the worker is done executing the task
         setproctitle("airflow worker -- LocalExecutor")
 
-    def _execute_work_in_subprocess(self, command: CommandType) -> str:
+    def _execute_work_in_subprocess(self, command: CommandType) -> 
TaskInstanceState:
         try:
             subprocess.check_call(command, close_fds=True)
-            return State.SUCCESS
+            return TaskInstanceState.SUCCESS
         except subprocess.CalledProcessError as e:
             self.log.error("Failed to execute task %s.", str(e))
-            return State.FAILED
+            return TaskInstanceState.FAILED
 
-    def _execute_work_in_fork(self, command: CommandType) -> str:
+    def _execute_work_in_fork(self, command: CommandType) -> TaskInstanceState:
         pid = os.fork()
         if pid:
             # In parent, wait for the child
             pid, ret = os.waitpid(pid, 0)
-            return State.SUCCESS if ret == 0 else State.FAILED
+            return TaskInstanceState.SUCCESS if ret == 0 else 
TaskInstanceState.FAILED
 
         from airflow.sentry import Sentry
 
@@ -130,10 +130,10 @@ class LocalWorkerBase(Process, LoggingMixin):
 
             args.func(args)
             ret = 0
-            return State.SUCCESS
+            return TaskInstanceState.SUCCESS
         except Exception as e:
             self.log.exception("Failed to execute task %s.", e)
-            return State.FAILED
+            return TaskInstanceState.FAILED
         finally:
             Sentry.flush()
             logging.shutdown()
diff --git a/airflow/jobs/backfill_job_runner.py 
b/airflow/jobs/backfill_job_runner.py
index 5b13490be7..35910b83ae 100644
--- a/airflow/jobs/backfill_job_runner.py
+++ b/airflow/jobs/backfill_job_runner.py
@@ -68,7 +68,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
 
     job_type = "BackfillJob"
 
-    STATES_COUNT_AS_RUNNING = (State.RUNNING, State.QUEUED)
+    STATES_COUNT_AS_RUNNING = (TaskInstanceState.RUNNING, 
TaskInstanceState.QUEUED)
 
     @attr.define
     class _DagRunTaskStatus:
@@ -219,7 +219,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
             # is changed externally, e.g. by clearing tasks from the ui. We 
need to cover
             # for that as otherwise those tasks would fall outside the scope of
             # the backfill suddenly.
-            elif ti.state == State.NONE:
+            elif ti.state is None:
                 self.log.warning(
                     "FIXME: task instance %s state was set to none externally 
or "
                     "reaching concurrency limits. Re-adding task to queue.",
@@ -1000,7 +1000,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
             ).all()
 
             for ti in reset_tis:
-                ti.state = State.NONE
+                ti.state = None
                 session.merge(ti)
 
             return result + reset_tis
diff --git a/airflow/jobs/local_task_job_runner.py 
b/airflow/jobs/local_task_job_runner.py
index 9f6a4b55e8..fd234e4150 100644
--- a/airflow/jobs/local_task_job_runner.py
+++ b/airflow/jobs/local_task_job_runner.py
@@ -35,7 +35,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
 from airflow.utils.platform import IS_WINDOWS
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 SIGSEGV_MESSAGE = """
 ******************************************* Received SIGSEGV 
*******************************************
@@ -243,7 +243,7 @@ class LocalTaskJobRunner(BaseJobRunner["Job | 
JobPydantic"], LoggingMixin):
         self.task_instance.refresh_from_db()
         ti = self.task_instance
 
-        if ti.state == State.RUNNING:
+        if ti.state == TaskInstanceState.RUNNING:
             fqdn = get_hostname()
             same_hostname = fqdn == ti.hostname
             if not same_hostname:
@@ -273,7 +273,7 @@ class LocalTaskJobRunner(BaseJobRunner["Job | 
JobPydantic"], LoggingMixin):
                 )
                 raise AirflowException("PID of job runner does not match")
         elif self.task_runner.return_code() is None and 
hasattr(self.task_runner, "process"):
-            if ti.state == State.SKIPPED:
+            if ti.state == TaskInstanceState.SKIPPED:
                 # A DagRun timeout will cause tasks to be externally marked as 
skipped.
                 dagrun = ti.get_dagrun(session=session)
                 execution_time = (dagrun.end_date or timezone.utcnow()) - 
dagrun.start_date
diff --git a/airflow/jobs/scheduler_job_runner.py 
b/airflow/jobs/scheduler_job_runner.py
index 8e9db65647..7b13ecc300 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -610,7 +610,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
             )
 
             for ti in executable_tis:
-                ti.emit_state_change_metric(State.QUEUED)
+                ti.emit_state_change_metric(TaskInstanceState.QUEUED)
 
         for ti in executable_tis:
             make_transient(ti)
@@ -626,7 +626,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
         # actually enqueue them
         for ti in task_instances:
             if ti.dag_run.state in State.finished:
-                ti.set_state(State.NONE, session=session)
+                ti.set_state(None, session=session)
                 continue
             command = ti.command_as_list(
                 local=True,
@@ -681,14 +681,12 @@ class SchedulerJobRunner(BaseJobRunner[Job], 
LoggingMixin):
         tis_with_right_state: list[TaskInstanceKey] = []
 
         # Report execution
-        for ti_key, value in event_buffer.items():
-            state: str
-            state, _ = value
+        for ti_key, (state, _) in event_buffer.items():
             # We create map (dag_id, task_id, execution_date) -> in-memory 
try_number
             ti_primary_key_to_try_number_map[ti_key.primary] = 
ti_key.try_number
 
             self.log.info("Received executor event with state %s for task 
instance %s", state, ti_key)
-            if state in (State.FAILED, State.SUCCESS, State.QUEUED):
+            if state in (TaskInstanceState.FAILED, TaskInstanceState.SUCCESS, 
TaskInstanceState.QUEUED):
                 tis_with_right_state.append(ti_key)
 
         # Return if no finished tasks
@@ -712,7 +710,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
             buffer_key = ti.key.with_try_number(try_number)
             state, info = event_buffer.pop(buffer_key)
 
-            if state == State.QUEUED:
+            if state == TaskInstanceState.QUEUED:
                 ti.external_executor_id = info
                 self.log.info("Setting external_id for %s to %s", ti, info)
                 continue
@@ -1532,7 +1530,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], 
LoggingMixin):
 
         tasks_stuck_in_queued = session.scalars(
             select(TI).where(
-                TI.state == State.QUEUED,
+                TI.state == TaskInstanceState.QUEUED,
                 TI.queued_dttm < (timezone.utcnow() - 
timedelta(seconds=self._task_queued_timeout)),
                 TI.queued_by_job_id == self.job.id,
             )
@@ -1611,7 +1609,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], 
LoggingMixin):
                         .join(TI.dag_run)
                         .where(
                             DagRun.run_type != DagRunType.BACKFILL_JOB,
-                            DagRun.state == State.RUNNING,
+                            DagRun.state == DagRunState.RUNNING,
                         )
                         .options(load_only(TI.dag_id, TI.task_id, TI.run_id))
                     )
@@ -1626,7 +1624,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], 
LoggingMixin):
                     reset_tis_message = []
                     for ti in to_reset:
                         reset_tis_message.append(repr(ti))
-                        ti.state = State.NONE
+                        ti.state = None
                         ti.queued_by_job_id = None
 
                     for ti in set(tis_to_reset_or_adopt) - set(to_reset):
diff --git a/airflow/listeners/spec/taskinstance.py 
b/airflow/listeners/spec/taskinstance.py
index 78de8a5f62..b87043a99d 100644
--- a/airflow/listeners/spec/taskinstance.py
+++ b/airflow/listeners/spec/taskinstance.py
@@ -32,20 +32,20 @@ hookspec = HookspecMarker("airflow")
 
 @hookspec
 def on_task_instance_running(
-    previous_state: TaskInstanceState, task_instance: TaskInstance, session: 
Session | None
+    previous_state: TaskInstanceState | None, task_instance: TaskInstance, 
session: Session | None
 ):
-    """Called when task state changes to RUNNING. Previous_state can be 
State.NONE."""
+    """Called when task state changes to RUNNING. previous_state can be 
None."""
 
 
 @hookspec
 def on_task_instance_success(
-    previous_state: TaskInstanceState, task_instance: TaskInstance, session: 
Session | None
+    previous_state: TaskInstanceState | None, task_instance: TaskInstance, 
session: Session | None
 ):
-    """Called when task state changes to SUCCESS. Previous_state can be 
State.NONE."""
+    """Called when task state changes to SUCCESS. previous_state can be 
None."""
 
 
 @hookspec
 def on_task_instance_failed(
-    previous_state: TaskInstanceState, task_instance: TaskInstance, session: 
Session | None
+    previous_state: TaskInstanceState | None, task_instance: TaskInstance, 
session: Session | None
 ):
-    """Called when task state changes to FAIL. Previous_state can be 
State.NONE."""
+    """Called when task state changes to FAIL. previous_state can be None."""
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 46e7f4608e..391be9e582 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -127,7 +127,7 @@ from airflow.utils.sqlalchemy import (
     tuple_in_condition,
     with_row_locks,
 )
-from airflow.utils.state import DagRunState, State, TaskInstanceState
+from airflow.utils.state import DagRunState, TaskInstanceState
 from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
 
 if TYPE_CHECKING:
@@ -1416,7 +1416,7 @@ class DAG(LoggingMixin):
 
         :return: List of execution dates
         """
-        runs = DagRun.find(dag_id=self.dag_id, state=State.RUNNING)
+        runs = DagRun.find(dag_id=self.dag_id, state=DagRunState.RUNNING)
 
         active_dates = []
         for run in runs:
@@ -2118,7 +2118,7 @@ class DAG(LoggingMixin):
     @provide_session
     def set_dag_runs_state(
         self,
-        state: str = State.RUNNING,
+        state: DagRunState = DagRunState.RUNNING,
         session: Session = NEW_SESSION,
         start_date: datetime | None = None,
         end_date: datetime | None = None,
@@ -2199,12 +2199,12 @@ class DAG(LoggingMixin):
                 stacklevel=2,
             )
 
-        state = []
+        state: list[TaskInstanceState] = []
         if only_failed:
-            state += [State.FAILED, State.UPSTREAM_FAILED]
+            state += [TaskInstanceState.FAILED, 
TaskInstanceState.UPSTREAM_FAILED]
         if only_running:
             # Yes, having `+=` doesn't make sense, but this was the existing 
behaviour
-            state += [State.RUNNING]
+            state += [TaskInstanceState.RUNNING]
 
         tis = self._get_task_instances(
             task_ids=task_ids,
@@ -2742,7 +2742,7 @@ class DAG(LoggingMixin):
         # Instead of starting a scheduler, we run the minimal loop possible to 
check
         # for task readiness and dependency management. This is notably faster
         # than creating a BackfillJob and allows us to surface logs to the user
-        while dr.state == State.RUNNING:
+        while dr.state == DagRunState.RUNNING:
             schedulable_tis, _ = dr.update_state(session=session)
             try:
                 for ti in schedulable_tis:
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index f0b29881a8..e3ed0bda00 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -115,7 +115,7 @@ class DagRun(Base, LoggingMixin):
     execution_date = Column(UtcDateTime, default=timezone.utcnow, 
nullable=False)
     start_date = Column(UtcDateTime)
     end_date = Column(UtcDateTime)
-    _state = Column("state", String(50), default=State.QUEUED)
+    _state = Column("state", String(50), default=DagRunState.QUEUED)
     run_id = Column(StringID(), nullable=False)
     creating_job_id = Column(Integer)
     external_trigger = Column(Boolean, default=True)
@@ -222,7 +222,7 @@ class DagRun(Base, LoggingMixin):
         if state is not None:
             self.state = state
         if queued_at is NOTSET:
-            self.queued_at = timezone.utcnow() if state == State.QUEUED else 
None
+            self.queued_at = timezone.utcnow() if state == DagRunState.QUEUED 
else None
         else:
             self.queued_at = queued_at
         self.run_type = run_type
@@ -265,13 +265,13 @@ class DagRun(Base, LoggingMixin):
     def get_state(self):
         return self._state
 
-    def set_state(self, state: DagRunState):
+    def set_state(self, state: DagRunState) -> None:
         if state not in State.dag_states:
             raise ValueError(f"invalid DagRun state: {state}")
         if self._state != state:
             self._state = state
             self.end_date = timezone.utcnow() if self._state in State.finished 
else None
-            if state == State.QUEUED:
+            if state == DagRunState.QUEUED:
                 self.queued_at = timezone.utcnow()
 
     @declared_attr
@@ -306,9 +306,9 @@ class DagRun(Base, LoggingMixin):
             # because SQLAlchemy doesn't accept a set here.
             query = query.where(cls.dag_id.in_(set(dag_ids)))
         if only_running:
-            query = query.where(cls.state == State.RUNNING)
+            query = query.where(cls.state == DagRunState.RUNNING)
         else:
-            query = query.where(cls.state.in_([State.RUNNING, State.QUEUED]))
+            query = query.where(cls.state.in_((DagRunState.RUNNING, 
DagRunState.QUEUED)))
         query = query.group_by(cls.dag_id)
         return {dag_id: count for dag_id, count in session.execute(query)}
 
@@ -340,7 +340,7 @@ class DagRun(Base, LoggingMixin):
             .join(DagModel, DagModel.dag_id == cls.dag_id)
             .where(DagModel.is_paused == false(), DagModel.is_active == true())
         )
-        if state == State.QUEUED:
+        if state == DagRunState.QUEUED:
             # For dag runs in the queued state, we check if they have reached 
the max_active_runs limit
             # and if so we drop them
             running_drs = (
@@ -477,11 +477,11 @@ class DagRun(Base, LoggingMixin):
                 tis = tis.where(TI.state == state)
             else:
                 # this is required to deal with NULL values
-                if State.NONE in state:
+                if None in state:
                     if all(x is None for x in state):
                         tis = tis.where(TI.state.is_(None))
                     else:
-                        not_none_state = [s for s in state if s]
+                        not_none_state = (s for s in state if s)
                         tis = tis.where(or_(TI.state.in_(not_none_state), 
TI.state.is_(None)))
                 else:
                     tis = tis.where(TI.state.in_(state))
@@ -746,9 +746,9 @@ class DagRun(Base, LoggingMixin):
                 try:
                     ti.task = dag.get_task(ti.task_id)
                 except TaskNotFound:
-                    if ti.state != State.REMOVED:
+                    if ti.state != TaskInstanceState.REMOVED:
                         self.log.error("Failed to get task for ti %s. Marking 
it as removed.", ti)
-                        ti.state = State.REMOVED
+                        ti.state = TaskInstanceState.REMOVED
                         session.flush()
                 else:
                     yield ti
@@ -957,7 +957,7 @@ class DagRun(Base, LoggingMixin):
             self.log.warning("Failed to record first_task_scheduling_delay 
metric:", exc_info=True)
 
     def _emit_duration_stats_for_finished_state(self):
-        if self.state == State.RUNNING:
+        if self.state == DagRunState.RUNNING:
             return
         if self.start_date is None:
             self.log.warning("Failed to record duration of %s: start_date is 
not set.", self)
@@ -1029,22 +1029,22 @@ class DagRun(Base, LoggingMixin):
             try:
                 task = dag.get_task(ti.task_id)
 
-                should_restore_task = (task is not None) and ti.state == 
State.REMOVED
+                should_restore_task = (task is not None) and ti.state == 
TaskInstanceState.REMOVED
                 if should_restore_task:
                     self.log.info("Restoring task '%s' which was previously 
removed from DAG '%s'", ti, dag)
                     Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 
tags=self.stats_tags)
                     # Same metric with tagging
                     Stats.incr("task_restored_to_dag", 
tags={**self.stats_tags, "dag_id": dag.dag_id})
-                    ti.state = State.NONE
+                    ti.state = None
             except AirflowException:
-                if ti.state == State.REMOVED:
+                if ti.state == TaskInstanceState.REMOVED:
                     pass  # ti has already been removed, just ignore it
-                elif self.state != State.RUNNING and not dag.partial:
+                elif self.state != DagRunState.RUNNING and not dag.partial:
                     self.log.warning("Failed to get task '%s' for dag '%s'. 
Marking it as removed.", ti, dag)
                     Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 
tags=self.stats_tags)
                     # Same metric with tagging
                     Stats.incr("task_removed_from_dag", 
tags={**self.stats_tags, "dag_id": dag.dag_id})
-                    ti.state = State.REMOVED
+                    ti.state = TaskInstanceState.REMOVED
                 continue
 
             try:
@@ -1061,7 +1061,7 @@ class DagRun(Base, LoggingMixin):
                         self.log.debug(
                             "Removing the unmapped TI '%s' as the mapping 
can't be resolved yet", ti
                         )
-                        ti.state = State.REMOVED
+                        ti.state = TaskInstanceState.REMOVED
                     continue
                 # Upstreams finished, check there aren't any extras
                 if ti.map_index >= total_length:
@@ -1070,7 +1070,7 @@ class DagRun(Base, LoggingMixin):
                         ti,
                         total_length,
                     )
-                    ti.state = State.REMOVED
+                    ti.state = TaskInstanceState.REMOVED
             else:
                 # Check if the number of mapped literals has changed, and we 
need to mark this TI as removed.
                 if ti.map_index >= num_mapped_tis:
@@ -1079,10 +1079,10 @@ class DagRun(Base, LoggingMixin):
                         ti,
                         num_mapped_tis,
                     )
-                    ti.state = State.REMOVED
+                    ti.state = TaskInstanceState.REMOVED
                 elif ti.map_index < 0:
                     self.log.debug("Removing the unmapped TI '%s' as the 
mapping can now be performed", ti)
-                    ti.state = State.REMOVED
+                    ti.state = TaskInstanceState.REMOVED
 
         return task_ids
 
diff --git a/airflow/models/pool.py b/airflow/models/pool.py
index 60f92506f6..83ec0368bd 100644
--- a/airflow/models/pool.py
+++ b/airflow/models/pool.py
@@ -28,7 +28,7 @@ from airflow.ti_deps.dependencies_states import 
EXECUTION_STATES
 from airflow.typing_compat import TypedDict
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import nowait, with_row_locks
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 
 class PoolStats(TypedDict):
@@ -247,7 +247,7 @@ class Pool(Base):
             session.scalar(
                 select(func.sum(TaskInstance.pool_slots))
                 .filter(TaskInstance.pool == self.pool)
-                .filter(TaskInstance.state == State.RUNNING)
+                .filter(TaskInstance.state == TaskInstanceState.RUNNING)
             )
             or 0
         )
@@ -266,7 +266,7 @@ class Pool(Base):
             session.scalar(
                 select(func.sum(TaskInstance.pool_slots))
                 .filter(TaskInstance.pool == self.pool)
-                .filter(TaskInstance.state == State.QUEUED)
+                .filter(TaskInstance.state == TaskInstanceState.QUEUED)
             )
             or 0
         )
@@ -285,7 +285,7 @@ class Pool(Base):
             session.scalar(
                 select(func.sum(TaskInstance.pool_slots))
                 .filter(TaskInstance.pool == self.pool)
-                .filter(TaskInstance.state == State.SCHEDULED)
+                .filter(TaskInstance.state == TaskInstanceState.SCHEDULED)
             )
             or 0
         )
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 849083e38b..10991cadc7 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -28,7 +28,7 @@ from airflow.utils import timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.sqlalchemy import tuple_in_condition
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from pendulum import DateTime
@@ -79,7 +79,7 @@ class SkipMixin(LoggingMixin):
 
             query.update(
                 {
-                    TaskInstance.state: State.SKIPPED,
+                    TaskInstance.state: TaskInstanceState.SKIPPED,
                     TaskInstance.start_date: now,
                     TaskInstance.end_date: now,
                 },
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 65eeacdfd3..ee1825c063 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -593,7 +593,7 @@ class TaskInstance(Base, LoggingMixin):
         database, in all other cases this will be incremented.
         """
         # This is designed so that task logs end up in the right file.
-        if self.state == State.RUNNING:
+        if self.state == TaskInstanceState.RUNNING:
             return self._try_number
         return self._try_number + 1
 
@@ -811,7 +811,7 @@ class TaskInstance(Base, LoggingMixin):
         :param session: SQLAlchemy ORM Session
         """
         self.log.error("Recording the task instance as FAILED")
-        self.state = State.FAILED
+        self.state = TaskInstanceState.FAILED
         session.merge(self)
         session.commit()
 
@@ -940,7 +940,7 @@ class TaskInstance(Base, LoggingMixin):
         self.log.debug("Setting task state for %s to %s", self, state)
         self.state = state
         self.start_date = self.start_date or current_time
-        if self.state in State.finished or self.state == State.UP_FOR_RETRY:
+        if self.state in State.finished or self.state == 
TaskInstanceState.UP_FOR_RETRY:
             self.end_date = self.end_date or current_time
             self.duration = (self.end_date - self.start_date).total_seconds()
         session.merge(self)
@@ -953,7 +953,7 @@ class TaskInstance(Base, LoggingMixin):
         has elapsed.
         """
         # is the task still in the retry waiting period?
-        return self.state == State.UP_FOR_RETRY and not self.ready_for_retry()
+        return self.state == TaskInstanceState.UP_FOR_RETRY and not 
self.ready_for_retry()
 
     @provide_session
     def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
@@ -976,7 +976,7 @@ class TaskInstance(Base, LoggingMixin):
             TaskInstance.dag_id == self.dag_id,
             TaskInstance.task_id.in_(task.downstream_task_ids),
             TaskInstance.run_id == self.run_id,
-            TaskInstance.state.in_([State.SKIPPED, State.SUCCESS]),
+            TaskInstance.state.in_((TaskInstanceState.SKIPPED, 
TaskInstanceState.SUCCESS)),
         )
         count = ti[0][0]
         return count == len(task.downstream_task_ids)
@@ -1213,7 +1213,7 @@ class TaskInstance(Base, LoggingMixin):
         Check whether the task instance is in the right state and timeframe
         to be retried.
         """
-        return self.state == State.UP_FOR_RETRY and self.next_retry_datetime() 
< timezone.utcnow()
+        return self.state == TaskInstanceState.UP_FOR_RETRY and 
self.next_retry_datetime() < timezone.utcnow()
 
     @provide_session
     def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
@@ -1282,7 +1282,7 @@ class TaskInstance(Base, LoggingMixin):
         self.hostname = get_hostname()
         self.pid = None
 
-        if not ignore_all_deps and not ignore_ti_state and self.state == 
State.SUCCESS:
+        if not ignore_all_deps and not ignore_ti_state and self.state == 
TaskInstanceState.SUCCESS:
             Stats.incr("previously_succeeded", tags=self.stats_tags)
 
         if not mark_success:
@@ -1310,7 +1310,7 @@ class TaskInstance(Base, LoggingMixin):
             # start date that is recorded in task_reschedule table
             # If the task continues after being deferred (next_method is set), 
use the original start_date
             self.start_date = self.start_date if self.next_method else 
timezone.utcnow()
-            if self.state == State.UP_FOR_RESCHEDULE:
+            if self.state == TaskInstanceState.UP_FOR_RESCHEDULE:
                 task_reschedule: TR = TR.query_for_task_instance(self, 
session=session).first()
                 if task_reschedule:
                     self.start_date = task_reschedule.start_date
@@ -1328,7 +1328,7 @@ class TaskInstance(Base, LoggingMixin):
                 description="requeueable deps",
             )
             if not self.are_dependencies_met(dep_context=dep_context, 
session=session, verbose=True):
-                self.state = State.NONE
+                self.state = None
                 self.log.warning(
                     "Rescheduling due to concurrency limits reached "
                     "at task runtime. Attempt %s of "
@@ -1350,8 +1350,8 @@ class TaskInstance(Base, LoggingMixin):
         if not test_mode:
             session.add(Log(State.RUNNING, self))
 
-        self.state = State.RUNNING
-        self.emit_state_change_metric(State.RUNNING)
+        self.state = TaskInstanceState.RUNNING
+        self.emit_state_change_metric(TaskInstanceState.RUNNING)
         self.external_executor_id = external_executor_id
         self.end_date = None
         if not test_mode:
@@ -1391,7 +1391,7 @@ class TaskInstance(Base, LoggingMixin):
             self._date_or_empty("end_date"),
         )
 
-    def emit_state_change_metric(self, new_state: TaskInstanceState):
+    def emit_state_change_metric(self, new_state: TaskInstanceState) -> None:
         """
         Sends a time metric representing how much time a given state 
transition took.
         The previous state and metric name is deduced from the state the task 
was put in.
@@ -1407,7 +1407,7 @@ class TaskInstance(Base, LoggingMixin):
             return
 
         # switch on state and deduce which metric to send
-        if new_state == State.RUNNING:
+        if new_state == TaskInstanceState.RUNNING:
             metric_name = "queued_duration"
             if self.queued_dttm is None:
                 # this should not really happen except in tests or rare cases,
@@ -1419,7 +1419,7 @@ class TaskInstance(Base, LoggingMixin):
                 )
                 return
             timing = (timezone.utcnow() - self.queued_dttm).total_seconds()
-        elif new_state == State.QUEUED:
+        elif new_state == TaskInstanceState.QUEUED:
             metric_name = "scheduled_duration"
             if self.start_date is None:
                 # same comment as above
@@ -1506,7 +1506,7 @@ class TaskInstance(Base, LoggingMixin):
                 self._execute_task_with_callbacks(context, test_mode)
             if not test_mode:
                 self.refresh_from_db(lock_for_update=True, session=session)
-            self.state = State.SUCCESS
+            self.state = TaskInstanceState.SUCCESS
         except TaskDeferred as defer:
             # The task has signalled it wants to defer execution based on
             # a trigger.
@@ -1530,7 +1530,7 @@ class TaskInstance(Base, LoggingMixin):
                 self.log.info(e)
             if not test_mode:
                 self.refresh_from_db(lock_for_update=True, session=session)
-            self.state = State.SKIPPED
+            self.state = TaskInstanceState.SKIPPED
         except AirflowRescheduleException as reschedule_exception:
             self._handle_reschedule(actual_start_date, reschedule_exception, 
test_mode, session=session)
             session.commit()
@@ -1753,7 +1753,7 @@ class TaskInstance(Base, LoggingMixin):
         # Then, update ourselves so it matches the deferral request
         # Keep an eye on the logic in 
`check_and_change_state_before_execution()`
         # depending on self.next_method semantics
-        self.state = State.DEFERRED
+        self.state = TaskInstanceState.DEFERRED
         self.trigger_id = trigger_row.id
         self.next_method = defer.method_name
         self.next_kwargs = defer.kwargs or {}
@@ -1872,7 +1872,7 @@ class TaskInstance(Base, LoggingMixin):
         )
 
         # set state
-        self.state = State.UP_FOR_RESCHEDULE
+        self.state = TaskInstanceState.UP_FOR_RESCHEDULE
 
         # Decrement try_number so subsequent runs will use the same try number 
and write
         # to same log file.
@@ -1971,7 +1971,7 @@ class TaskInstance(Base, LoggingMixin):
             self.log.error("Unable to unmap task to determine if we need to 
send an alert email")
 
         if force_fail or not self.is_eligible_to_retry():
-            self.state = State.FAILED
+            self.state = TaskInstanceState.FAILED
             email_for_state = operator.attrgetter("email_on_failure")
             callbacks = task.on_failure_callback if task else None
             callback_type = "on_failure"
@@ -1980,10 +1980,10 @@ class TaskInstance(Base, LoggingMixin):
                 tis = self.get_dagrun(session).get_task_instances()
                 stop_all_tasks_in_dag(tis, session, self.task_id)
         else:
-            if self.state == State.QUEUED:
+            if self.state == TaskInstanceState.QUEUED:
                 # We increase the try_number so as to fail the task if it 
fails to start after sometime
                 self._try_number += 1
-            self.state = State.UP_FOR_RETRY
+            self.state = TaskInstanceState.UP_FOR_RETRY
             email_for_state = operator.attrgetter("email_on_retry")
             callbacks = task.on_retry_callback if task else None
             callback_type = "on_retry"
@@ -2004,7 +2004,7 @@ class TaskInstance(Base, LoggingMixin):
 
     def is_eligible_to_retry(self):
         """Check whether the task instance is eligible for retry."""
-        if self.state == State.RESTARTING:
+        if self.state == TaskInstanceState.RESTARTING:
             # If a task is cleared when running, it goes into RESTARTING state 
and is always
             # eligible for retry
             return True
@@ -2345,7 +2345,7 @@ class TaskInstance(Base, LoggingMixin):
             'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
         )
 
-        # This function is called after changing the state from State.RUNNING,
+        # This function is called after changing the state from RUNNING,
         # so we need to subtract 1 from self.try_number here.
         current_try_number = self.try_number - 1
         additional_context: dict[str, Any] = {
@@ -2572,7 +2572,7 @@ class TaskInstance(Base, LoggingMixin):
         num_running_task_instances_query = session.query(func.count()).filter(
             TaskInstance.dag_id == self.dag_id,
             TaskInstance.task_id == self.task_id,
-            TaskInstance.state == State.RUNNING,
+            TaskInstance.state == TaskInstanceState.RUNNING,
         )
         if same_dagrun:
             num_running_task_instances_query = 
num_running_task_instances_query.filter(
@@ -2894,7 +2894,7 @@ def _is_further_mapped_inside(operator: Operator, 
container: TaskGroup) -> bool:
 
 # State of the task instance.
 # Stores string version of the task state.
-TaskInstanceStateType = Tuple[TaskInstanceKey, str]
+TaskInstanceStateType = Tuple[TaskInstanceKey, TaskInstanceState]
 
 
 class SimpleTaskInstance:
diff --git a/airflow/operators/subdag.py b/airflow/operators/subdag.py
index 30daf638bd..0f242c09f3 100644
--- a/airflow/operators/subdag.py
+++ b/airflow/operators/subdag.py
@@ -37,7 +37,7 @@ from airflow.models.taskinstance import TaskInstance
 from airflow.sensors.base import BaseSensorOperator
 from airflow.utils.context import Context
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, TaskInstanceState
 from airflow.utils.types import DagRunType
 
 
@@ -137,17 +137,17 @@ class SubDagOperator(BaseSensorOperator):
         :param execution_date: Execution date to select task instances.
         """
         with create_session() as session:
-            dag_run.state = State.RUNNING
+            dag_run.state = DagRunState.RUNNING
             session.merge(dag_run)
             failed_task_instances = (
                 session.query(TaskInstance)
                 .filter(TaskInstance.dag_id == self.subdag.dag_id)
                 .filter(TaskInstance.execution_date == execution_date)
-                .filter(TaskInstance.state.in_([State.FAILED, 
State.UPSTREAM_FAILED]))
+                .filter(TaskInstance.state.in_((TaskInstanceState.FAILED, 
TaskInstanceState.UPSTREAM_FAILED)))
             )
 
             for task_instance in failed_task_instances:
-                task_instance.state = State.NONE
+                task_instance.state = None
                 session.merge(task_instance)
             session.commit()
 
@@ -164,7 +164,7 @@ class SubDagOperator(BaseSensorOperator):
             dag_run = self.subdag.create_dagrun(
                 run_type=DagRunType.SCHEDULED,
                 execution_date=execution_date,
-                state=State.RUNNING,
+                state=DagRunState.RUNNING,
                 conf=self.conf,
                 external_trigger=True,
                 data_interval=data_interval,
@@ -172,13 +172,13 @@ class SubDagOperator(BaseSensorOperator):
             self.log.info("Created DagRun: %s", dag_run.run_id)
         else:
             self.log.info("Found existing DagRun: %s", dag_run.run_id)
-            if dag_run.state == State.FAILED:
+            if dag_run.state == DagRunState.FAILED:
                 self._reset_dag_run_and_task_instances(dag_run, execution_date)
 
     def poke(self, context: Context):
         execution_date = context["execution_date"]
         dag_run = self._get_dagrun(execution_date=execution_date)
-        return dag_run.state != State.RUNNING
+        return dag_run.state != DagRunState.RUNNING
 
     def post_execute(self, context, result=None):
         super().post_execute(context)
@@ -186,7 +186,7 @@ class SubDagOperator(BaseSensorOperator):
         dag_run = self._get_dagrun(execution_date=execution_date)
         self.log.info("Execution finished. State is %s", dag_run.state)
 
-        if dag_run.state != State.SUCCESS:
+        if dag_run.state != DagRunState.SUCCESS:
             raise AirflowException(f"Expected state: SUCCESS. Actual state: 
{dag_run.state}")
 
         if self.propagate_skipped_state and 
self._check_skipped_states(context):
@@ -196,9 +196,9 @@ class SubDagOperator(BaseSensorOperator):
         leaves_tis = self._get_leaves_tis(context["execution_date"])
 
         if self.propagate_skipped_state == 
SkippedStatePropagationOptions.ANY_LEAF:
-            return any(ti.state == State.SKIPPED for ti in leaves_tis)
+            return any(ti.state == TaskInstanceState.SKIPPED for ti in 
leaves_tis)
         if self.propagate_skipped_state == 
SkippedStatePropagationOptions.ALL_LEAVES:
-            return all(ti.state == State.SKIPPED for ti in leaves_tis)
+            return all(ti.state == TaskInstanceState.SKIPPED for ti in 
leaves_tis)
         raise AirflowException(
             f"Unimplemented SkippedStatePropagationOptions 
{self.propagate_skipped_state} used."
         )
diff --git a/airflow/sentry.py b/airflow/sentry.py
index 3e222405c4..443063af8a 100644
--- a/airflow/sentry.py
+++ b/airflow/sentry.py
@@ -25,27 +25,26 @@ from typing import TYPE_CHECKING
 from airflow.configuration import conf
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.utils.session import find_session_idx, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
+    from airflow.models.taskinstance import TaskInstance
+
 log = logging.getLogger(__name__)
 
 
 class DummySentry:
     """Blank class for Sentry."""
 
-    @classmethod
-    def add_tagging(cls, task_instance):
+    def add_tagging(self, task_instance):
         """Blank function for tagging."""
 
-    @classmethod
-    def add_breadcrumbs(cls, task_instance, session: Session | None = None):
+    def add_breadcrumbs(self, task_instance, session: Session | None = None):
         """Blank function for breadcrumbs."""
 
-    @classmethod
-    def enrich_errors(cls, run):
+    def enrich_errors(self, run):
         """Blank function for formatting a TaskInstance._run_raw_task."""
         return run
 
@@ -137,13 +136,17 @@ if conf.getboolean("sentry", "sentry_on", fallback=False):
                 scope.set_tag("operator", task.__class__.__name__)
 
         @provide_session
-        def add_breadcrumbs(self, task_instance, session=None):
+        def add_breadcrumbs(
+            self,
+            task_instance: TaskInstance,
+            session: Session | None = None,
+        ) -> None:
             """Function to add breadcrumbs inside of a task_instance."""
             if session is None:
                 return
             dr = task_instance.get_dagrun(session)
             task_instances = dr.get_task_instances(
-                state={State.SUCCESS, State.FAILED},
+                state={TaskInstanceState.SUCCESS, TaskInstanceState.FAILED},
                 session=session,
             )
 
diff --git a/airflow/ti_deps/deps/dagrun_exists_dep.py 
b/airflow/ti_deps/deps/dagrun_exists_dep.py
index 781ab0ebaf..0a364628c7 100644
--- a/airflow/ti_deps/deps/dagrun_exists_dep.py
+++ b/airflow/ti_deps/deps/dagrun_exists_dep.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState
 
 
 class DagrunRunningDep(BaseTIDep):
@@ -31,7 +31,7 @@ class DagrunRunningDep(BaseTIDep):
     @provide_session
     def _get_dep_statuses(self, ti, session, dep_context):
         dr = ti.get_dagrun(session)
-        if dr.state != State.RUNNING:
+        if dr.state != DagRunState.RUNNING:
             yield self._failing_status(
                 reason=f"Task instance's dagrun was not in the 'running' state 
but in the state '{dr.state}'."
             )
diff --git a/airflow/ti_deps/deps/not_in_retry_period_dep.py 
b/airflow/ti_deps/deps/not_in_retry_period_dep.py
index b3b5d4ec56..90954f29f2 100644
--- a/airflow/ti_deps/deps/not_in_retry_period_dep.py
+++ b/airflow/ti_deps/deps/not_in_retry_period_dep.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils import timezone
 from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 
 class NotInRetryPeriodDep(BaseTIDep):
@@ -38,7 +38,7 @@ class NotInRetryPeriodDep(BaseTIDep):
             )
             return
 
-        if ti.state != State.UP_FOR_RETRY:
+        if ti.state != TaskInstanceState.UP_FOR_RETRY:
             yield self._passing_status(reason="The task instance was not 
marked for retrying.")
             return
 
diff --git a/airflow/ti_deps/deps/not_previously_skipped_dep.py 
b/airflow/ti_deps/deps/not_previously_skipped_dep.py
index fdc8274d90..855f04af53 100644
--- a/airflow/ti_deps/deps/not_previously_skipped_dep.py
+++ b/airflow/ti_deps/deps/not_previously_skipped_dep.py
@@ -40,7 +40,7 @@ class NotPreviouslySkippedDep(BaseTIDep):
             XCOM_SKIPMIXIN_SKIPPED,
             SkipMixin,
         )
-        from airflow.utils.state import State
+        from airflow.utils.state import TaskInstanceState
 
         upstream = ti.task.get_direct_relatives(upstream=True)
 
@@ -87,7 +87,7 @@ class NotPreviouslySkippedDep(BaseTIDep):
                                 reason=("Task should be skipped but the the 
past depends are not met")
                             )
                             return
-                    ti.set_state(State.SKIPPED, session)
+                    ti.set_state(TaskInstanceState.SKIPPED, session)
                     yield self._failing_status(
                         reason=f"Skipping because of previous XCom result from 
parent task {parent.task_id}"
                     )
diff --git a/airflow/ti_deps/deps/prev_dagrun_dep.py 
b/airflow/ti_deps/deps/prev_dagrun_dep.py
index b3165eb7f2..62acdbca33 100644
--- a/airflow/ti_deps/deps/prev_dagrun_dep.py
+++ b/airflow/ti_deps/deps/prev_dagrun_dep.py
@@ -22,7 +22,7 @@ from sqlalchemy import func
 from airflow.models.taskinstance import PAST_DEPENDS_MET, TaskInstance as TI
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 
 class PrevDagrunDep(BaseTIDep):
@@ -107,7 +107,7 @@ class PrevDagrunDep(BaseTIDep):
             )
             return
 
-        if previous_ti.state not in {State.SKIPPED, State.SUCCESS}:
+        if previous_ti.state not in {TaskInstanceState.SKIPPED, 
TaskInstanceState.SUCCESS}:
             yield self._failing_status(
                 reason=(
                     f"depends_on_past is true for this task, but the previous 
task instance {previous_ti} "
diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py 
b/airflow/ti_deps/deps/ready_to_reschedule.py
index c20ef98a08..4fca6f5538 100644
--- a/airflow/ti_deps/deps/ready_to_reschedule.py
+++ b/airflow/ti_deps/deps/ready_to_reschedule.py
@@ -22,7 +22,7 @@ from airflow.models.taskreschedule import TaskReschedule
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils import timezone
 from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 
 class ReadyToRescheduleDep(BaseTIDep):
@@ -31,7 +31,7 @@ class ReadyToRescheduleDep(BaseTIDep):
     NAME = "Ready To Reschedule"
     IGNORABLE = True
     IS_TASK_DEP = True
-    RESCHEDULEABLE_STATES = {State.UP_FOR_RESCHEDULE, State.NONE}
+    RESCHEDULEABLE_STATES = {TaskInstanceState.UP_FOR_RESCHEDULE, None}
 
     @provide_session
     def _get_dep_statuses(self, ti, session, dep_context):
diff --git a/airflow/ti_deps/deps/task_not_running_dep.py 
b/airflow/ti_deps/deps/task_not_running_dep.py
index fd76873466..7299a4f3e2 100644
--- a/airflow/ti_deps/deps/task_not_running_dep.py
+++ b/airflow/ti_deps/deps/task_not_running_dep.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 
 class TaskNotRunningDep(BaseTIDep):
@@ -37,7 +37,7 @@ class TaskNotRunningDep(BaseTIDep):
 
     @provide_session
     def _get_dep_statuses(self, ti, session, dep_context=None):
-        if ti.state != State.RUNNING:
+        if ti.state != TaskInstanceState.RUNNING:
             yield self._passing_status(reason="Task is not in running state.")
             return
 
diff --git a/airflow/utils/dot_renderer.py b/airflow/utils/dot_renderer.py
index 3f35d1b21d..d3b329cb54 100644
--- a/airflow/utils/dot_renderer.py
+++ b/airflow/utils/dot_renderer.py
@@ -58,7 +58,7 @@ def _draw_task(
 ) -> None:
     """Draw a single task on the given parent_graph."""
     if states_by_task_id:
-        state = states_by_task_id.get(task.task_id, State.NONE)
+        state = states_by_task_id.get(task.task_id)
         color = State.color_fg(state)
         fill_color = State.color(state)
     else:
diff --git a/airflow/utils/log/file_task_handler.py 
b/airflow/utils/log/file_task_handler.py
index 0c98d25503..9184a20420 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -341,7 +341,10 @@ class FileTaskHandler(logging.Handler):
         )
         log_pos = len(logs)
         messages = "".join([f"*** {x}\n" for x in messages_list])
-        end_of_log = ti.try_number != try_number or ti.state not in 
[State.RUNNING, State.DEFERRED]
+        end_of_log = ti.try_number != try_number or ti.state not in (
+            TaskInstanceState.RUNNING,
+            TaskInstanceState.DEFERRED,
+        )
         if metadata and "log_pos" in metadata:
             previous_chars = metadata["log_pos"]
             logs = logs[previous_chars:]  # Cut off previously passed log test 
as new tail
diff --git a/airflow/utils/log/log_reader.py b/airflow/utils/log/log_reader.py
index a4589ebca0..d93f15bb1a 100644
--- a/airflow/utils/log/log_reader.py
+++ b/airflow/utils/log/log_reader.py
@@ -28,7 +28,7 @@ from airflow.models.taskinstance import TaskInstance
 from airflow.utils.helpers import render_log_filename
 from airflow.utils.log.logging_mixin import ExternalLoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 
 class TaskLogReader:
@@ -86,7 +86,8 @@ class TaskLogReader:
                 for host, log in logs[0]:
                     yield "\n".join([host or "", log]) + "\n"
                 if "end_of_log" not in metadata or (
-                    not metadata["end_of_log"] and ti.state not in 
[State.RUNNING, State.DEFERRED]
+                    not metadata["end_of_log"]
+                    and ti.state not in (TaskInstanceState.RUNNING, 
TaskInstanceState.DEFERRED)
                 ):
                     if not logs[0]:
                         # we did not receive any logs in this loop
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index f4a8dc1a0a..fc74732acc 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -21,8 +21,7 @@ from enum import Enum
 
 
 class TaskInstanceState(str, Enum):
-    """
-    Enum that represents all possible states that a Task Instance can be in.
+    """All possible states that a Task Instance can be in.
 
     Note that None is also allowed, so always use this in a type hint with 
Optional.
     """
@@ -53,8 +52,7 @@ class TaskInstanceState(str, Enum):
 
 
 class DagRunState(str, Enum):
-    """
-    Enum that represents all possible states that a DagRun can be in.
+    """All possible states that a DagRun can be in.
 
     These are "shared" with TaskInstanceState in some parts of the code,
     so please ensure that their values always match the ones with the
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 1ca6289122..bb48f81ccc 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -87,7 +87,9 @@ def get_instance_with_map(task_instance, session):
 
 
 def get_try_count(try_number: int, state: State):
-    return try_number + 1 if state in [State.DEFERRED, 
State.UP_FOR_RESCHEDULE] else try_number
+    if state in (TaskInstanceState.DEFERRED, 
TaskInstanceState.UP_FOR_RESCHEDULE):
+        return try_number + 1
+    return try_number
 
 
 priority: list[None | TaskInstanceState] = [
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 370e15ee00..e198e390ac 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -767,7 +767,7 @@ class Airflow(AirflowBaseView):
 
             # find DAGs which have a RUNNING DagRun
             running_dags = dags_query.join(DagRun, DagModel.dag_id == 
DagRun.dag_id).where(
-                DagRun.state == State.RUNNING
+                DagRun.state == DagRunState.RUNNING
             )
 
             # find DAGs for which the latest DagRun is FAILED
@@ -778,7 +778,7 @@ class Airflow(AirflowBaseView):
             )
             subq_failed = (
                 select(DagRun.dag_id, 
func.max(DagRun.start_date).label("start_date"))
-                .where(DagRun.state == State.FAILED)
+                .where(DagRun.state == DagRunState.FAILED)
                 .group_by(DagRun.dag_id)
                 .subquery()
             )
@@ -1127,7 +1127,7 @@ class Airflow(AirflowBaseView):
         running_dag_run_query_result = (
             select(DagRun.dag_id, DagRun.run_id)
             .join(DagModel, DagModel.dag_id == DagRun.dag_id)
-            .where(DagRun.state == State.RUNNING, DagModel.is_active)
+            .where(DagRun.state == DagRunState.RUNNING, DagModel.is_active)
         )
 
         running_dag_run_query_result = 
running_dag_run_query_result.where(DagRun.dag_id.in_(filter_dag_ids))
@@ -1151,7 +1151,7 @@ class Airflow(AirflowBaseView):
             last_dag_run = (
                 select(DagRun.dag_id, 
sqla.func.max(DagRun.execution_date).label("execution_date"))
                 .join(DagModel, DagModel.dag_id == DagRun.dag_id)
-                .where(DagRun.state != State.RUNNING, DagModel.is_active)
+                .where(DagRun.state != DagRunState.RUNNING, DagModel.is_active)
                 .group_by(DagRun.dag_id)
             )
 
@@ -1856,7 +1856,7 @@ class Airflow(AirflowBaseView):
                 "Airflow administrator for assistance.".format(
                     "- This task instance already ran and had it's state 
changed manually "
                     "(e.g. cleared in the UI)<br>"
-                    if ti and ti.state == State.NONE
+                    if ti and ti.state is None
                     else ""
                 ),
             )
@@ -2169,7 +2169,7 @@ class Airflow(AirflowBaseView):
                 run_type=DagRunType.MANUAL,
                 execution_date=execution_date,
                 
data_interval=dag.timetable.infer_manual_data_interval(run_after=execution_date),
-                state=State.QUEUED,
+                state=DagRunState.QUEUED,
                 conf=run_conf,
                 external_trigger=True,
                 dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id),
@@ -5426,23 +5426,28 @@ class 
DagRunModelView(AirflowPrivilegeVerifierModelView):
     @action_logging
     def action_set_queued(self, drs: list[DagRun]):
         """Set state to queued."""
-        return self._set_dag_runs_to_active_state(drs, State.QUEUED)
+        return self._set_dag_runs_to_active_state(drs, DagRunState.QUEUED)
 
     @action("set_running", "Set state to 'running'", "", single=False)
     @action_has_dag_edit_access
     @action_logging
     def action_set_running(self, drs: list[DagRun]):
         """Set state to running."""
-        return self._set_dag_runs_to_active_state(drs, State.RUNNING)
+        return self._set_dag_runs_to_active_state(drs, DagRunState.RUNNING)
 
     @provide_session
-    def _set_dag_runs_to_active_state(self, drs: list[DagRun], state: str, 
session: Session = NEW_SESSION):
+    def _set_dag_runs_to_active_state(
+        self,
+        drs: list[DagRun],
+        state: DagRunState,
+        session: Session = NEW_SESSION,
+    ):
         """This routine only supports the RUNNING and QUEUED states."""
         try:
             count = 0
             for dr in 
session.scalars(select(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in 
drs))):
                 count += 1
-                if state == State.RUNNING:
+                if state == DagRunState.RUNNING:
                     dr.start_date = timezone.utcnow()
                 dr.state = state
             session.commit()
@@ -5863,7 +5868,12 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
         return redirect(self.get_redirect())
 
     @provide_session
-    def set_task_instance_state(self, tis, target_state, session: Session = NEW_SESSION):
+    def set_task_instance_state(
+        self,
+        tis: Collection[TaskInstance],
+        target_state: TaskInstanceState,
+        session: Session = NEW_SESSION,
+    ) -> None:
         """Set task instance state."""
         try:
             count = len(tis)
@@ -5879,7 +5889,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
     @action_logging
     def action_set_running(self, tis):
         """Set state to 'running'."""
-        self.set_task_instance_state(tis, State.RUNNING)
+        self.set_task_instance_state(tis, TaskInstanceState.RUNNING)
         self.update_redirect()
         return redirect(self.get_redirect())
 
@@ -5888,7 +5898,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
     @action_logging
     def action_set_failed(self, tis):
         """Set state to 'failed'."""
-        self.set_task_instance_state(tis, State.FAILED)
+        self.set_task_instance_state(tis, TaskInstanceState.FAILED)
         self.update_redirect()
         return redirect(self.get_redirect())
 
@@ -5897,7 +5907,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
     @action_logging
     def action_set_success(self, tis):
         """Set state to 'success'."""
-        self.set_task_instance_state(tis, State.SUCCESS)
+        self.set_task_instance_state(tis, TaskInstanceState.SUCCESS)
         self.update_redirect()
         return redirect(self.get_redirect())
 
@@ -5906,7 +5916,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
     @action_logging
     def action_set_retry(self, tis):
         """Set state to 'up_for_retry'."""
-        self.set_task_instance_state(tis, State.UP_FOR_RETRY)
+        self.set_task_instance_state(tis, TaskInstanceState.UP_FOR_RETRY)
         self.update_redirect()
         return redirect(self.get_redirect())
 
diff --git a/dev/perf/scheduler_dag_execution_timing.py b/dev/perf/scheduler_dag_execution_timing.py
index 613a929e9e..db7f4f2e8d 100755
--- a/dev/perf/scheduler_dag_execution_timing.py
+++ b/dev/perf/scheduler_dag_execution_timing.py
@@ -63,7 +63,7 @@ class ShortCircuitExecutorMixin:
         Change the state of scheduler by waiting till the tasks is complete
         and then shut down the scheduler after the task is complete
         """
-        from airflow.utils.state import State
+        from airflow.utils.state import TaskInstanceState
 
         super().change_state(key, state, info=info)
 
@@ -83,7 +83,7 @@ class ShortCircuitExecutorMixin:
             run = list(airflow.models.DagRun.find(dag_id=dag_id, execution_date=execution_date))[0]
             self.dags_to_watch[dag_id].runs[execution_date] = run
 
-        if run and all(t.state == State.SUCCESS for t in run.get_task_instances()):
+        if run and all(t.state == TaskInstanceState.SUCCESS for t in run.get_task_instances()):
             self.dags_to_watch[dag_id].runs.pop(execution_date)
             self.dags_to_watch[dag_id].waiting_for -= 1
 
@@ -156,7 +156,7 @@ def create_dag_runs(dag, num_runs, session):
     Create  `num_runs` of dag runs for sub-sequent schedules
     """
     from airflow.utils import timezone
-    from airflow.utils.state import State
+    from airflow.utils.state import DagRunState
 
     try:
         from airflow.utils.types import DagRunType
@@ -175,7 +175,7 @@ def create_dag_runs(dag, num_runs, session):
             run_id=f"{id_prefix}{logical_date.isoformat()}",
             execution_date=logical_date,
             start_date=timezone.utcnow(),
-            state=State.RUNNING,
+            state=DagRunState.RUNNING,
             external_trigger=False,
             session=session,
         )

Reply via email to