This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 947a6d3367 Fix MappedOperator property types (#37870)
947a6d3367 is described below
commit 947a6d336784b0ee0e72e48d540dbec3f7eca095
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Mar 4 14:10:18 2024 +0800
Fix MappedOperator property types (#37870)
---
airflow/decorators/base.py | 1 -
airflow/models/baseoperator.py | 3 +-
airflow/models/mappedoperator.py | 53 +++++++++++++++++------------
airflow/serialization/serialized_objects.py | 1 -
4 files changed, 33 insertions(+), 25 deletions(-)
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 93c403e0bb..51ebbce29c 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -459,7 +459,6 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla
expand_input=EXPAND_INPUT_EMPTY, # Don't use this; mapped values
go to op_kwargs_expand_input.
partial_kwargs=partial_kwargs,
task_id=task_id,
- map_index_template=partial_kwargs.pop("map_index_template", None),
params=partial_params,
deps=MappedOperator.deps_for(self.operator_class),
operator_extra_links=self.operator_class.operator_extra_links,
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 18d596bc4a..c563b0e63f 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -196,7 +196,8 @@ class _PartialDescriptor:
return self.class_method.__get__(cls, cls)
-_PARTIAL_DEFAULTS = {
+_PARTIAL_DEFAULTS: dict[str, Any] = {
+ "map_index_template": None,
"owner": DEFAULT_OWNER,
"trigger_rule": DEFAULT_TRIGGER_RULE,
"depends_on_past": False,
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 1e18249a22..b2e85bbb7a 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -58,6 +58,7 @@ from airflow.utils.xcom import XCOM_RETURN_KEY
if TYPE_CHECKING:
import datetime
+ from typing import List
import jinja2 # Slow import.
import pendulum
@@ -83,6 +84,8 @@ if TYPE_CHECKING:
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule
+ TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback,
List[TaskStateChangeCallback]]
+
ValidationSource = Union[Literal["expand"], Literal["partial"]]
@@ -211,7 +214,6 @@ class OperatorPartial:
expand_input=expand_input,
partial_kwargs=partial_kwargs,
task_id=task_id,
- map_index_template=partial_kwargs.pop("map_index_template", None),
params=self.params,
deps=MappedOperator.deps_for(self.operator_class),
operator_extra_links=self.operator_class.operator_extra_links,
@@ -281,7 +283,6 @@ class MappedOperator(AbstractOperator):
end_date: pendulum.DateTime | None
upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
downstream_task_ids: set[str] = attr.ib(factory=set, init=False)
- map_index_template: str | None
_disallow_kwargs_override: bool
"""Whether execution fails if ``expand_input`` has duplicates to
``partial_kwargs``.
@@ -392,6 +393,14 @@ class MappedOperator(AbstractOperator):
def email(self) -> None | str | Iterable[str]:
return self.partial_kwargs.get("email")
+ @property
+ def map_index_template(self) -> None | str:
+ return self.partial_kwargs.get("map_index_template")
+
+ @map_index_template.setter
+ def map_index_template(self, value: str | None) -> None:
+ self.partial_kwargs["map_index_template"] = value
+
@property
def trigger_rule(self) -> TriggerRule:
return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)
@@ -453,11 +462,11 @@ class MappedOperator(AbstractOperator):
self.partial_kwargs["wait_for_downstream"] = value
@property
- def retries(self) -> int | None:
+ def retries(self) -> int:
return self.partial_kwargs.get("retries", DEFAULT_RETRIES)
@retries.setter
- def retries(self, value: int | None) -> None:
+ def retries(self, value: int) -> None:
self.partial_kwargs["retries"] = value
@property
@@ -465,7 +474,7 @@ class MappedOperator(AbstractOperator):
return self.partial_kwargs.get("queue", DEFAULT_QUEUE)
@queue.setter
- def queue(self, value: str | None) -> None:
+ def queue(self, value: str) -> None:
self.partial_kwargs["queue"] = value
@property
@@ -473,15 +482,15 @@ class MappedOperator(AbstractOperator):
return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME)
@pool.setter
- def pool(self, value: str | None) -> None:
+ def pool(self, value: str) -> None:
self.partial_kwargs["pool"] = value
@property
- def pool_slots(self) -> str | None:
+ def pool_slots(self) -> int:
return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)
@pool_slots.setter
- def pool_slots(self, value: str | None) -> None:
+ def pool_slots(self, value: int) -> None:
self.partial_kwargs["pool_slots"] = value
@property
@@ -505,7 +514,7 @@ class MappedOperator(AbstractOperator):
return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)
@retry_delay.setter
- def retry_delay(self, value: datetime.timedelta | None) -> None:
+ def retry_delay(self, value: datetime.timedelta) -> None:
self.partial_kwargs["retry_delay"] = value
@property
@@ -513,7 +522,7 @@ class MappedOperator(AbstractOperator):
return bool(self.partial_kwargs.get("retry_exponential_backoff"))
@retry_exponential_backoff.setter
- def retry_exponential_backoff(self, value: bool | None) -> None:
+ def retry_exponential_backoff(self, value: bool) -> None:
self.partial_kwargs["retry_exponential_backoff"] = value
@property
@@ -521,7 +530,7 @@ class MappedOperator(AbstractOperator):
return self.partial_kwargs.get("priority_weight",
DEFAULT_PRIORITY_WEIGHT)
@priority_weight.setter
- def priority_weight(self, value: int | None) -> None:
+ def priority_weight(self, value: int) -> None:
self.partial_kwargs["priority_weight"] = value
@property
@@ -529,7 +538,7 @@ class MappedOperator(AbstractOperator):
return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
@weight_rule.setter
- def weight_rule(self, value: str | None) -> None:
+ def weight_rule(self, value: str) -> None:
self.partial_kwargs["weight_rule"] = value
@property
@@ -561,43 +570,43 @@ class MappedOperator(AbstractOperator):
return self.partial_kwargs.get("resources")
@property
- def on_execute_callback(self) -> None | TaskStateChangeCallback |
list[TaskStateChangeCallback]:
+ def on_execute_callback(self) -> TaskStateChangeCallbackAttrType:
return self.partial_kwargs.get("on_execute_callback")
@on_execute_callback.setter
- def on_execute_callback(self, value: TaskStateChangeCallback | None) ->
None:
+ def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) ->
None:
self.partial_kwargs["on_execute_callback"] = value
@property
- def on_failure_callback(self) -> None | TaskStateChangeCallback |
list[TaskStateChangeCallback]:
+ def on_failure_callback(self) -> TaskStateChangeCallbackAttrType:
return self.partial_kwargs.get("on_failure_callback")
@on_failure_callback.setter
- def on_failure_callback(self, value: TaskStateChangeCallback | None) ->
None:
+ def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) ->
None:
self.partial_kwargs["on_failure_callback"] = value
@property
- def on_retry_callback(self) -> None | TaskStateChangeCallback |
list[TaskStateChangeCallback]:
+ def on_retry_callback(self) -> TaskStateChangeCallbackAttrType:
return self.partial_kwargs.get("on_retry_callback")
@on_retry_callback.setter
- def on_retry_callback(self, value: TaskStateChangeCallback | None) -> None:
+ def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) ->
None:
self.partial_kwargs["on_retry_callback"] = value
@property
- def on_success_callback(self) -> None | TaskStateChangeCallback |
list[TaskStateChangeCallback]:
+ def on_success_callback(self) -> TaskStateChangeCallbackAttrType:
return self.partial_kwargs.get("on_success_callback")
@on_success_callback.setter
- def on_success_callback(self, value: TaskStateChangeCallback | None) ->
None:
+ def on_success_callback(self, value: TaskStateChangeCallbackAttrType) ->
None:
self.partial_kwargs["on_success_callback"] = value
@property
- def on_skipped_callback(self) -> None | TaskStateChangeCallback |
list[TaskStateChangeCallback]:
+ def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType:
return self.partial_kwargs.get("on_skipped_callback")
@on_skipped_callback.setter
- def on_skipped_callback(self, value: TaskStateChangeCallback | None) ->
None:
+ def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) ->
None:
self.partial_kwargs["on_skipped_callback"] = value
@property
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 82c4a6bccc..f2d4aed890 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -1130,7 +1130,6 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
task_group=None,
start_date=None,
end_date=None,
- map_index_template=None,
disallow_kwargs_override=encoded_op["_disallow_kwargs_override"],
expand_input_attr=encoded_op["_expand_input_attr"],
)