This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 84bd0f0bef0 fix mypy errors in providers/standard/ (#57266)
84bd0f0bef0 is described below

commit 84bd0f0bef0a131d097a08de864719e2c6a4cf00
Author: Anusha Kovi <[email protected]>
AuthorDate: Mon Oct 27 18:45:47 2025 +0530

    fix mypy errors in providers/standard/ (#57266)
---
 .../providers/standard/operators/latest_only.py    |  4 +--
 .../providers/standard/sensors/external_task.py    |  2 +-
 .../providers/standard/triggers/external_task.py   |  2 +-
 .../providers/standard/utils/sensor_helper.py      | 24 ++++++++++----
 .../tests/unit/standard/operators/test_hitl.py     | 38 +++++++++++++---------
 5 files changed, 44 insertions(+), 26 deletions(-)

diff --git 
a/providers/standard/src/airflow/providers/standard/operators/latest_only.py 
b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
index d2263ff66bc..9cf75573dcf 100644
--- a/providers/standard/src/airflow/providers/standard/operators/latest_only.py
+++ b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
@@ -88,9 +88,9 @@ class LatestOnlyOperator(BaseBranchOperator):
     def _get_compare_dates(self, dag_run: DagRun) -> tuple[DateTime, DateTime] 
| None:
         dagrun_date: DateTime
         if AIRFLOW_V_3_0_PLUS:
-            dagrun_date = dag_run.logical_date or dag_run.run_after
+            dagrun_date = dag_run.logical_date or dag_run.run_after  # type: 
ignore[assignment]
         else:
-            dagrun_date = dag_run.logical_date
+            dagrun_date = dag_run.logical_date  # type: ignore[assignment]
 
         from airflow.timetables.base import DataInterval, TimeRestriction
 
diff --git 
a/providers/standard/src/airflow/providers/standard/sensors/external_task.py 
b/providers/standard/src/airflow/providers/standard/sensors/external_task.py
index ee809caffc1..024c14616c9 100644
--- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py
+++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py
@@ -488,7 +488,7 @@ class ExternalTaskSensor(BaseSensorOperator):
 
         self._has_checked_existence = True
 
-    def get_count(self, dttm_filter, session, states) -> int:
+    def get_count(self, dttm_filter: list[datetime.datetime], session: 
Session, states: list[str]) -> int:
         """
         Get the count of records against dttm filter and states.
 
diff --git 
a/providers/standard/src/airflow/providers/standard/triggers/external_task.py 
b/providers/standard/src/airflow/providers/standard/triggers/external_task.py
index 0085bb33b86..7072820d499 100644
--- 
a/providers/standard/src/airflow/providers/standard/triggers/external_task.py
+++ 
b/providers/standard/src/airflow/providers/standard/triggers/external_task.py
@@ -283,4 +283,4 @@ class DagStateTrigger(BaseTrigger):
                 )
                 .scalar()
             )
-            return typing.cast("int", count)
+            return typing.cast("int", count or 0)
diff --git 
a/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py 
b/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py
index 09a1765ad59..7b8229ae926 100644
--- a/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py
+++ b/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py
@@ -27,7 +27,7 @@ from airflow.utils.session import NEW_SESSION, provide_session
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
-    from sqlalchemy.sql import Executable
+    from sqlalchemy.sql import Select
 
 
 @provide_session
@@ -59,6 +59,7 @@ def _get_count(
             session.scalar(
                 _count_stmt(TI, states, dttm_filter, 
external_dag_id).where(TI.task_id.in_(external_task_ids))
             )
+            or 0
         ) / len(external_task_ids)
     elif external_task_group_id:
         external_task_group_task_ids = _get_external_task_group_task_ids(
@@ -68,20 +69,25 @@ def _get_count(
             count = 0
         else:
             count = (
-                session.scalar(
-                    _count_stmt(TI, states, dttm_filter, 
external_dag_id).where(
-                        tuple_(TI.task_id, 
TI.map_index).in_(external_task_group_task_ids)
+                (
+                    session.scalar(
+                        _count_stmt(TI, states, dttm_filter, 
external_dag_id).where(
+                            tuple_(TI.task_id, 
TI.map_index).in_(external_task_group_task_ids)
+                        )
                     )
+                    or 0
                 )
                 / len(external_task_group_task_ids)
                 * len(dttm_filter)
             )
     else:
-        count = session.scalar(_count_stmt(DR, states, dttm_filter, 
external_dag_id))
+        count = session.scalar(_count_stmt(DR, states, dttm_filter, 
external_dag_id)) or 0
     return cast("int", count)
 
 
-def _count_stmt(model, states, dttm_filter, external_dag_id) -> Executable:
+def _count_stmt(
+    model: type[DagRun] | type[TaskInstance], states: list[str], dttm_filter: 
list[Any], external_dag_id: str
+) -> Select[tuple[int]]:
     """
     Get the count of records against dttm filter and states.
 
@@ -97,7 +103,9 @@ def _count_stmt(model, states, dttm_filter, external_dag_id) 
-> Executable:
     )
 
 
-def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, 
external_dag_id, session):
+def _get_external_task_group_task_ids(
+    dttm_filter: list[Any], external_task_group_id: str, external_dag_id: str, 
session: Session
+) -> list[tuple[str, int]]:
     """
     Get the count of records against dttm filter and states.
 
@@ -107,6 +115,8 @@ def _get_external_task_group_task_ids(dttm_filter, 
external_task_group_id, exter
     :param session: airflow session object
     """
     refreshed_dag_info = SerializedDagModel.get_dag(external_dag_id, 
session=session)
+    if not refreshed_dag_info:
+        return [(external_task_group_id, -1)]
     task_group = refreshed_dag_info.task_group_dict.get(external_task_group_id)
 
     if task_group:
diff --git a/providers/standard/tests/unit/standard/operators/test_hitl.py 
b/providers/standard/tests/unit/standard/operators/test_hitl.py
index 05f4fc4236d..cead25507bd 100644
--- a/providers/standard/tests/unit/standard/operators/test_hitl.py
+++ b/providers/standard/tests/unit/standard/operators/test_hitl.py
@@ -81,7 +81,7 @@ def hitl_task_and_ti_for_generating_link(dag_maker: DagMaker) 
-> tuple[HITLOpera
 
 
 @pytest.fixture
-def get_context_from_model_ti(mock_supervisor_comms):
+def get_context_from_model_ti(mock_supervisor_comms: Any) -> Any:
     def _get_context(ti: TaskInstance) -> Context:
         from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
             DagRun as DRDataModel,
@@ -211,6 +211,7 @@ class TestHITLOperator:
         ti = dag_maker.run_ti(task.task_id, dr)
 
         hitl_detail_model = 
session.scalar(select(HITLDetail).where(HITLDetail.ti_id == ti.id))
+        assert hitl_detail_model is not None
         assert hitl_detail_model.ti_id == ti.id
         assert hitl_detail_model.subject == "This is subject"
         assert hitl_detail_model.options == ["1", "2", "3", "4", "5"]
@@ -229,6 +230,7 @@ class TestHITLOperator:
         registered_trigger = session.scalar(
             select(Trigger).where(Trigger.classpath == 
"airflow.providers.standard.triggers.hitl.HITLTrigger")
         )
+        assert registered_trigger is not None
         assert registered_trigger.kwargs == {
             "ti_id": ti.id,
             "options": ["1", "2", "3", "4", "5"],
@@ -247,7 +249,9 @@ class TestHITLOperator:
             (None, {}),
         ],
     )
-    def test_serialzed_params(self, input_params, expected_params: dict[str, 
Any]) -> None:
+    def test_serialzed_params(
+        self, input_params: ParamsDict | dict[str, Any] | None, 
expected_params: dict[str, Any]
+    ) -> None:
         hitl_op = HITLOperator(
             task_id="hitl_test",
             subject="This is subject",
@@ -295,7 +299,7 @@ class TestHITLOperator:
     def test_process_trigger_event_error(
         self,
         event: dict[str, Any],
-        expected_exception,
+        expected_exception: type[Exception],
     ) -> None:
         hitl_op = HITLOperator(
             task_id="hitl_test",
@@ -422,7 +426,7 @@ class TestHITLOperator:
     )
     def test_generate_link_to_ui_with_invalid_input(
         self,
-        options: list[str] | None,
+        options: list[Any] | None,
         params_input: dict[str, Any] | None,
         expected_err_msg: str,
         hitl_task_and_ti_for_generating_link: tuple[HITLOperator, 
TaskInstance],
@@ -486,7 +490,7 @@ class TestApprovalOperator:
         }
 
     def test_execute_complete_with_downstream_tasks(
-        self, dag_maker: DagMaker, get_context_from_model_ti
+        self, dag_maker: DagMaker, get_context_from_model_ti: Any
     ) -> None:
         with dag_maker("hitl_test_dag", serialized=True):
             hitl_op = ApprovalOperator(
@@ -510,7 +514,7 @@ class TestApprovalOperator:
         assert set(exc_info.value.tasks) == {"op1"}
 
     def test_execute_complete_with_fail_on_reject_set_to_true(
-        self, dag_maker: DagMaker, get_context_from_model_ti
+        self, dag_maker: DagMaker, get_context_from_model_ti: Any
     ) -> None:
         with dag_maker("hitl_test_dag", serialized=True):
             hitl_op = ApprovalOperator(task_id="hitl_test", subject="This is 
subject", fail_on_reject=True)
@@ -568,7 +572,7 @@ class TestHITLEntryOperator:
 
 
 class TestHITLBranchOperator:
-    def test_execute_complete(self, dag_maker: DagMaker, 
get_context_from_model_ti) -> None:
+    def test_execute_complete(self, dag_maker: DagMaker, 
get_context_from_model_ti: Any) -> None:
         with dag_maker("hitl_test_dag", serialized=True):
             branch_op = HITLBranchOperator(
                 task_id="make_choice",
@@ -593,7 +597,7 @@ class TestHITLBranchOperator:
         assert set(exc_info.value.tasks) == set((f"branch_{i}", -1) for i in 
range(2, 6))
 
     def test_execute_complete_with_multiple_branches(
-        self, dag_maker: DagMaker, get_context_from_model_ti
+        self, dag_maker: DagMaker, get_context_from_model_ti: Any
     ) -> None:
         with dag_maker("hitl_test_dag", serialized=True):
             branch_op = HITLBranchOperator(
@@ -621,7 +625,9 @@ class TestHITLBranchOperator:
             )
         assert set(exc_info.value.tasks) == set((f"branch_{i}", -1) for i in 
range(4, 6))
 
-    def test_mapping_applies_for_single_choice(self, dag_maker: DagMaker, 
get_context_from_model_ti) -> None:
+    def test_mapping_applies_for_single_choice(
+        self, dag_maker: DagMaker, get_context_from_model_ti: Any
+    ) -> None:
         # ["Approve"]; map -> "publish"
         with dag_maker("hitl_map_dag", serialized=True):
             op = HITLBranchOperator(
@@ -648,7 +654,7 @@ class TestHITLBranchOperator:
         # checks to see that the "archive" task was skipped
         assert set(exc.value.tasks) == {("archive", -1)}
 
-    def test_mapping_with_multiple_choices(self, dag_maker: DagMaker, 
get_context_from_model_ti) -> None:
+    def test_mapping_with_multiple_choices(self, dag_maker: DagMaker, 
get_context_from_model_ti: Any) -> None:
         # multiple=True; mapping applied per option; no dedup implied
         with dag_maker("hitl_map_dag", serialized=True):
             op = HITLBranchOperator(
@@ -680,7 +686,9 @@ class TestHITLBranchOperator:
         # publish + keep chosen → only "other" skipped
         assert set(exc.value.tasks) == {("other", -1)}
 
-    def test_fallback_to_option_when_not_mapped(self, dag_maker: DagMaker, 
get_context_from_model_ti) -> None:
+    def test_fallback_to_option_when_not_mapped(
+        self, dag_maker: DagMaker, get_context_from_model_ti: Any
+    ) -> None:
         # No mapping: option must match downstream task_id
         with dag_maker("hitl_map_dag", serialized=True):
             op = HITLBranchOperator(
@@ -706,8 +714,8 @@ class TestHITLBranchOperator:
         assert set(exc.value.tasks) == {("branch_1", -1)}
 
     def test_error_if_mapped_branch_not_direct_downstream(
-        self, dag_maker: DagMaker, get_context_from_model_ti
-    ):
+        self, dag_maker: DagMaker, get_context_from_model_ti: Any
+    ) -> None:
         # Don't add the mapped task downstream → expect a clean error
         with dag_maker("hitl_map_dag", serialized=True):
             op = HITLBranchOperator(
@@ -733,7 +741,7 @@ class TestHITLBranchOperator:
             )
 
     @pytest.mark.parametrize("bad", [123, ["publish"], {"x": "y"}, b"publish"])
-    def test_options_mapping_non_string_value_raises(self, bad):
+    def test_options_mapping_non_string_value_raises(self, bad: Any) -> None:
         with pytest.raises(ValueError, match=r"values must be strings 
\(task_ids\)"):
             HITLBranchOperator(
                 task_id="choose",
@@ -742,7 +750,7 @@ class TestHITLBranchOperator:
                 options_mapping={"Approve": bad},
             )
 
-    def test_options_mapping_key_not_in_options_raises(self):
+    def test_options_mapping_key_not_in_options_raises(self) -> None:
         with pytest.raises(ValueError, match="contains keys that are not in 
`options`"):
             HITLBranchOperator(
                 task_id="choose",

Reply via email to