This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 84bd0f0bef0 fix mypy errors in providers/standard/ (#57266)
84bd0f0bef0 is described below
commit 84bd0f0bef0a131d097a08de864719e2c6a4cf00
Author: Anusha Kovi <[email protected]>
AuthorDate: Mon Oct 27 18:45:47 2025 +0530
fix mypy errors in providers/standard/ (#57266)
---
.../providers/standard/operators/latest_only.py | 4 +--
.../providers/standard/sensors/external_task.py | 2 +-
.../providers/standard/triggers/external_task.py | 2 +-
.../providers/standard/utils/sensor_helper.py | 24 ++++++++++----
.../tests/unit/standard/operators/test_hitl.py | 38 +++++++++++++---------
5 files changed, 44 insertions(+), 26 deletions(-)
diff --git a/providers/standard/src/airflow/providers/standard/operators/latest_only.py b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
index d2263ff66bc..9cf75573dcf 100644
--- a/providers/standard/src/airflow/providers/standard/operators/latest_only.py
+++ b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
@@ -88,9 +88,9 @@ class LatestOnlyOperator(BaseBranchOperator):
def _get_compare_dates(self, dag_run: DagRun) -> tuple[DateTime, DateTime]
| None:
dagrun_date: DateTime
if AIRFLOW_V_3_0_PLUS:
- dagrun_date = dag_run.logical_date or dag_run.run_after
+ dagrun_date = dag_run.logical_date or dag_run.run_after # type: ignore[assignment]
else:
- dagrun_date = dag_run.logical_date
+ dagrun_date = dag_run.logical_date # type: ignore[assignment]
from airflow.timetables.base import DataInterval, TimeRestriction
diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py
index ee809caffc1..024c14616c9 100644
--- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py
+++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py
@@ -488,7 +488,7 @@ class ExternalTaskSensor(BaseSensorOperator):
self._has_checked_existence = True
- def get_count(self, dttm_filter, session, states) -> int:
+ def get_count(self, dttm_filter: list[datetime.datetime], session: Session, states: list[str]) -> int:
"""
Get the count of records against dttm filter and states.
diff --git a/providers/standard/src/airflow/providers/standard/triggers/external_task.py b/providers/standard/src/airflow/providers/standard/triggers/external_task.py
index 0085bb33b86..7072820d499 100644
--- a/providers/standard/src/airflow/providers/standard/triggers/external_task.py
+++ b/providers/standard/src/airflow/providers/standard/triggers/external_task.py
@@ -283,4 +283,4 @@ class DagStateTrigger(BaseTrigger):
)
.scalar()
)
- return typing.cast("int", count)
+ return typing.cast("int", count or 0)
diff --git a/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py b/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py
index 09a1765ad59..7b8229ae926 100644
--- a/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py
+++ b/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py
@@ -27,7 +27,7 @@ from airflow.utils.session import NEW_SESSION, provide_session
if TYPE_CHECKING:
from sqlalchemy.orm import Session
- from sqlalchemy.sql import Executable
+ from sqlalchemy.sql import Select
@provide_session
@@ -59,6 +59,7 @@ def _get_count(
session.scalar(
_count_stmt(TI, states, dttm_filter, external_dag_id).where(TI.task_id.in_(external_task_ids))
)
+ or 0
) / len(external_task_ids)
elif external_task_group_id:
external_task_group_task_ids = _get_external_task_group_task_ids(
@@ -68,20 +69,25 @@ def _get_count(
count = 0
else:
count = (
- session.scalar(
- _count_stmt(TI, states, dttm_filter, external_dag_id).where(
- tuple_(TI.task_id, TI.map_index).in_(external_task_group_task_ids)
+ (
+ session.scalar(
+ _count_stmt(TI, states, dttm_filter, external_dag_id).where(
+ tuple_(TI.task_id, TI.map_index).in_(external_task_group_task_ids)
+ )
)
+ or 0
)
/ len(external_task_group_task_ids)
* len(dttm_filter)
)
else:
- count = session.scalar(_count_stmt(DR, states, dttm_filter, external_dag_id))
+ count = session.scalar(_count_stmt(DR, states, dttm_filter, external_dag_id)) or 0
return cast("int", count)
-def _count_stmt(model, states, dttm_filter, external_dag_id) -> Executable:
+def _count_stmt(
+ model: type[DagRun] | type[TaskInstance], states: list[str], dttm_filter: list[Any], external_dag_id: str
+) -> Select[tuple[int]]:
"""
Get the count of records against dttm filter and states.
@@ -97,7 +103,9 @@ def _count_stmt(model, states, dttm_filter, external_dag_id) -> Executable:
)
-def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, external_dag_id, session):
+def _get_external_task_group_task_ids(
+ dttm_filter: list[Any], external_task_group_id: str, external_dag_id: str, session: Session
+) -> list[tuple[str, int]]:
"""
Get the count of records against dttm filter and states.
@@ -107,6 +115,8 @@ def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, exter
:param session: airflow session object
"""
refreshed_dag_info = SerializedDagModel.get_dag(external_dag_id, session=session)
+ if not refreshed_dag_info:
+ return [(external_task_group_id, -1)]
task_group = refreshed_dag_info.task_group_dict.get(external_task_group_id)
if task_group:
diff --git a/providers/standard/tests/unit/standard/operators/test_hitl.py b/providers/standard/tests/unit/standard/operators/test_hitl.py
index 05f4fc4236d..cead25507bd 100644
--- a/providers/standard/tests/unit/standard/operators/test_hitl.py
+++ b/providers/standard/tests/unit/standard/operators/test_hitl.py
@@ -81,7 +81,7 @@ def hitl_task_and_ti_for_generating_link(dag_maker: DagMaker) -> tuple[HITLOpera
@pytest.fixture
-def get_context_from_model_ti(mock_supervisor_comms):
+def get_context_from_model_ti(mock_supervisor_comms: Any) -> Any:
def _get_context(ti: TaskInstance) -> Context:
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
DagRun as DRDataModel,
@@ -211,6 +211,7 @@ class TestHITLOperator:
ti = dag_maker.run_ti(task.task_id, dr)
hitl_detail_model = session.scalar(select(HITLDetail).where(HITLDetail.ti_id == ti.id))
+ assert hitl_detail_model is not None
assert hitl_detail_model.ti_id == ti.id
assert hitl_detail_model.subject == "This is subject"
assert hitl_detail_model.options == ["1", "2", "3", "4", "5"]
@@ -229,6 +230,7 @@ class TestHITLOperator:
registered_trigger = session.scalar(
select(Trigger).where(Trigger.classpath == "airflow.providers.standard.triggers.hitl.HITLTrigger")
)
+ assert registered_trigger is not None
assert registered_trigger.kwargs == {
"ti_id": ti.id,
"options": ["1", "2", "3", "4", "5"],
@@ -247,7 +249,9 @@ class TestHITLOperator:
(None, {}),
],
)
- def test_serialzed_params(self, input_params, expected_params: dict[str, Any]) -> None:
+ def test_serialzed_params(
+ self, input_params: ParamsDict | dict[str, Any] | None, expected_params: dict[str, Any]
+ ) -> None:
hitl_op = HITLOperator(
task_id="hitl_test",
subject="This is subject",
@@ -295,7 +299,7 @@ class TestHITLOperator:
def test_process_trigger_event_error(
self,
event: dict[str, Any],
- expected_exception,
+ expected_exception: type[Exception],
) -> None:
hitl_op = HITLOperator(
task_id="hitl_test",
@@ -422,7 +426,7 @@ class TestHITLOperator:
)
def test_generate_link_to_ui_with_invalid_input(
self,
- options: list[str] | None,
+ options: list[Any] | None,
params_input: dict[str, Any] | None,
expected_err_msg: str,
hitl_task_and_ti_for_generating_link: tuple[HITLOperator, TaskInstance],
@@ -486,7 +490,7 @@ class TestApprovalOperator:
}
def test_execute_complete_with_downstream_tasks(
- self, dag_maker: DagMaker, get_context_from_model_ti
+ self, dag_maker: DagMaker, get_context_from_model_ti: Any
) -> None:
with dag_maker("hitl_test_dag", serialized=True):
hitl_op = ApprovalOperator(
@@ -510,7 +514,7 @@ class TestApprovalOperator:
assert set(exc_info.value.tasks) == {"op1"}
def test_execute_complete_with_fail_on_reject_set_to_true(
- self, dag_maker: DagMaker, get_context_from_model_ti
+ self, dag_maker: DagMaker, get_context_from_model_ti: Any
) -> None:
with dag_maker("hitl_test_dag", serialized=True):
hitl_op = ApprovalOperator(task_id="hitl_test", subject="This is subject", fail_on_reject=True)
@@ -568,7 +572,7 @@ class TestHITLEntryOperator:
class TestHITLBranchOperator:
- def test_execute_complete(self, dag_maker: DagMaker, get_context_from_model_ti) -> None:
+ def test_execute_complete(self, dag_maker: DagMaker, get_context_from_model_ti: Any) -> None:
with dag_maker("hitl_test_dag", serialized=True):
branch_op = HITLBranchOperator(
task_id="make_choice",
@@ -593,7 +597,7 @@ class TestHITLBranchOperator:
assert set(exc_info.value.tasks) == set((f"branch_{i}", -1) for i in range(2, 6))
def test_execute_complete_with_multiple_branches(
- self, dag_maker: DagMaker, get_context_from_model_ti
+ self, dag_maker: DagMaker, get_context_from_model_ti: Any
) -> None:
with dag_maker("hitl_test_dag", serialized=True):
branch_op = HITLBranchOperator(
@@ -621,7 +625,9 @@ class TestHITLBranchOperator:
)
assert set(exc_info.value.tasks) == set((f"branch_{i}", -1) for i in range(4, 6))
- def test_mapping_applies_for_single_choice(self, dag_maker: DagMaker, get_context_from_model_ti) -> None:
+ def test_mapping_applies_for_single_choice(
+ self, dag_maker: DagMaker, get_context_from_model_ti: Any
+ ) -> None:
# ["Approve"]; map -> "publish"
with dag_maker("hitl_map_dag", serialized=True):
op = HITLBranchOperator(
@@ -648,7 +654,7 @@ class TestHITLBranchOperator:
# checks to see that the "archive" task was skipped
assert set(exc.value.tasks) == {("archive", -1)}
- def test_mapping_with_multiple_choices(self, dag_maker: DagMaker, get_context_from_model_ti) -> None:
+ def test_mapping_with_multiple_choices(self, dag_maker: DagMaker, get_context_from_model_ti: Any) -> None:
# multiple=True; mapping applied per option; no dedup implied
with dag_maker("hitl_map_dag", serialized=True):
op = HITLBranchOperator(
@@ -680,7 +686,9 @@ class TestHITLBranchOperator:
# publish + keep chosen → only "other" skipped
assert set(exc.value.tasks) == {("other", -1)}
- def test_fallback_to_option_when_not_mapped(self, dag_maker: DagMaker, get_context_from_model_ti) -> None:
+ def test_fallback_to_option_when_not_mapped(
+ self, dag_maker: DagMaker, get_context_from_model_ti: Any
+ ) -> None:
# No mapping: option must match downstream task_id
with dag_maker("hitl_map_dag", serialized=True):
op = HITLBranchOperator(
@@ -706,8 +714,8 @@ class TestHITLBranchOperator:
assert set(exc.value.tasks) == {("branch_1", -1)}
def test_error_if_mapped_branch_not_direct_downstream(
- self, dag_maker: DagMaker, get_context_from_model_ti
- ):
+ self, dag_maker: DagMaker, get_context_from_model_ti: Any
+ ) -> None:
# Don't add the mapped task downstream → expect a clean error
with dag_maker("hitl_map_dag", serialized=True):
op = HITLBranchOperator(
@@ -733,7 +741,7 @@ class TestHITLBranchOperator:
)
@pytest.mark.parametrize("bad", [123, ["publish"], {"x": "y"}, b"publish"])
- def test_options_mapping_non_string_value_raises(self, bad):
+ def test_options_mapping_non_string_value_raises(self, bad: Any) -> None:
with pytest.raises(ValueError, match=r"values must be strings \(task_ids\)"):
HITLBranchOperator(
task_id="choose",
@@ -742,7 +750,7 @@ class TestHITLBranchOperator:
options_mapping={"Approve": bad},
)
- def test_options_mapping_key_not_in_options_raises(self):
+ def test_options_mapping_key_not_in_options_raises(self) -> None:
with pytest.raises(ValueError, match="contains keys that are not in `options`"):
HITLBranchOperator(
task_id="choose",