This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 947a6d3367 Fix MappedOperator property types (#37870)
947a6d3367 is described below

commit 947a6d336784b0ee0e72e48d540dbec3f7eca095
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Mar 4 14:10:18 2024 +0800

    Fix MappedOperator property types (#37870)
---
 airflow/decorators/base.py                  |  1 -
 airflow/models/baseoperator.py              |  3 +-
 airflow/models/mappedoperator.py            | 53 +++++++++++++++++------------
 airflow/serialization/serialized_objects.py |  1 -
 4 files changed, 33 insertions(+), 25 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 93c403e0bb..51ebbce29c 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -459,7 +459,6 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla
             expand_input=EXPAND_INPUT_EMPTY,  # Don't use this; mapped values 
go to op_kwargs_expand_input.
             partial_kwargs=partial_kwargs,
             task_id=task_id,
-            map_index_template=partial_kwargs.pop("map_index_template", None),
             params=partial_params,
             deps=MappedOperator.deps_for(self.operator_class),
             operator_extra_links=self.operator_class.operator_extra_links,
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 18d596bc4a..c563b0e63f 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -196,7 +196,8 @@ class _PartialDescriptor:
         return self.class_method.__get__(cls, cls)
 
 
-_PARTIAL_DEFAULTS = {
+_PARTIAL_DEFAULTS: dict[str, Any] = {
+    "map_index_template": None,
     "owner": DEFAULT_OWNER,
     "trigger_rule": DEFAULT_TRIGGER_RULE,
     "depends_on_past": False,
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 1e18249a22..b2e85bbb7a 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -58,6 +58,7 @@ from airflow.utils.xcom import XCOM_RETURN_KEY
 
 if TYPE_CHECKING:
     import datetime
+    from typing import List
 
     import jinja2  # Slow import.
     import pendulum
@@ -83,6 +84,8 @@ if TYPE_CHECKING:
     from airflow.utils.task_group import TaskGroup
     from airflow.utils.trigger_rule import TriggerRule
 
+    TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, 
List[TaskStateChangeCallback]]
+
 ValidationSource = Union[Literal["expand"], Literal["partial"]]
 
 
@@ -211,7 +214,6 @@ class OperatorPartial:
             expand_input=expand_input,
             partial_kwargs=partial_kwargs,
             task_id=task_id,
-            map_index_template=partial_kwargs.pop("map_index_template", None),
             params=self.params,
             deps=MappedOperator.deps_for(self.operator_class),
             operator_extra_links=self.operator_class.operator_extra_links,
@@ -281,7 +283,6 @@ class MappedOperator(AbstractOperator):
     end_date: pendulum.DateTime | None
     upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
     downstream_task_ids: set[str] = attr.ib(factory=set, init=False)
-    map_index_template: str | None
 
     _disallow_kwargs_override: bool
     """Whether execution fails if ``expand_input`` has duplicates to 
``partial_kwargs``.
@@ -392,6 +393,14 @@ class MappedOperator(AbstractOperator):
     def email(self) -> None | str | Iterable[str]:
         return self.partial_kwargs.get("email")
 
+    @property
+    def map_index_template(self) -> None | str:
+        return self.partial_kwargs.get("map_index_template")
+
+    @map_index_template.setter
+    def map_index_template(self, value: str | None) -> None:
+        self.partial_kwargs["map_index_template"] = value
+
     @property
     def trigger_rule(self) -> TriggerRule:
         return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)
@@ -453,11 +462,11 @@ class MappedOperator(AbstractOperator):
         self.partial_kwargs["wait_for_downstream"] = value
 
     @property
-    def retries(self) -> int | None:
+    def retries(self) -> int:
         return self.partial_kwargs.get("retries", DEFAULT_RETRIES)
 
     @retries.setter
-    def retries(self, value: int | None) -> None:
+    def retries(self, value: int) -> None:
         self.partial_kwargs["retries"] = value
 
     @property
@@ -465,7 +474,7 @@ class MappedOperator(AbstractOperator):
         return self.partial_kwargs.get("queue", DEFAULT_QUEUE)
 
     @queue.setter
-    def queue(self, value: str | None) -> None:
+    def queue(self, value: str) -> None:
         self.partial_kwargs["queue"] = value
 
     @property
@@ -473,15 +482,15 @@ class MappedOperator(AbstractOperator):
         return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME)
 
     @pool.setter
-    def pool(self, value: str | None) -> None:
+    def pool(self, value: str) -> None:
         self.partial_kwargs["pool"] = value
 
     @property
-    def pool_slots(self) -> str | None:
+    def pool_slots(self) -> int:
         return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)
 
     @pool_slots.setter
-    def pool_slots(self, value: str | None) -> None:
+    def pool_slots(self, value: int) -> None:
         self.partial_kwargs["pool_slots"] = value
 
     @property
@@ -505,7 +514,7 @@ class MappedOperator(AbstractOperator):
         return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)
 
     @retry_delay.setter
-    def retry_delay(self, value: datetime.timedelta | None) -> None:
+    def retry_delay(self, value: datetime.timedelta) -> None:
         self.partial_kwargs["retry_delay"] = value
 
     @property
@@ -513,7 +522,7 @@ class MappedOperator(AbstractOperator):
         return bool(self.partial_kwargs.get("retry_exponential_backoff"))
 
     @retry_exponential_backoff.setter
-    def retry_exponential_backoff(self, value: bool | None) -> None:
+    def retry_exponential_backoff(self, value: bool) -> None:
         self.partial_kwargs["retry_exponential_backoff"] = value
 
     @property
@@ -521,7 +530,7 @@ class MappedOperator(AbstractOperator):
         return self.partial_kwargs.get("priority_weight", 
DEFAULT_PRIORITY_WEIGHT)
 
     @priority_weight.setter
-    def priority_weight(self, value: int | None) -> None:
+    def priority_weight(self, value: int) -> None:
         self.partial_kwargs["priority_weight"] = value
 
     @property
@@ -529,7 +538,7 @@ class MappedOperator(AbstractOperator):
         return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
 
     @weight_rule.setter
-    def weight_rule(self, value: str | None) -> None:
+    def weight_rule(self, value: str) -> None:
         self.partial_kwargs["weight_rule"] = value
 
     @property
@@ -561,43 +570,43 @@ class MappedOperator(AbstractOperator):
         return self.partial_kwargs.get("resources")
 
     @property
-    def on_execute_callback(self) -> None | TaskStateChangeCallback | 
list[TaskStateChangeCallback]:
+    def on_execute_callback(self) -> TaskStateChangeCallbackAttrType:
         return self.partial_kwargs.get("on_execute_callback")
 
     @on_execute_callback.setter
-    def on_execute_callback(self, value: TaskStateChangeCallback | None) -> 
None:
+    def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> 
None:
         self.partial_kwargs["on_execute_callback"] = value
 
     @property
-    def on_failure_callback(self) -> None | TaskStateChangeCallback | 
list[TaskStateChangeCallback]:
+    def on_failure_callback(self) -> TaskStateChangeCallbackAttrType:
         return self.partial_kwargs.get("on_failure_callback")
 
     @on_failure_callback.setter
-    def on_failure_callback(self, value: TaskStateChangeCallback | None) -> 
None:
+    def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> 
None:
         self.partial_kwargs["on_failure_callback"] = value
 
     @property
-    def on_retry_callback(self) -> None | TaskStateChangeCallback | 
list[TaskStateChangeCallback]:
+    def on_retry_callback(self) -> TaskStateChangeCallbackAttrType:
         return self.partial_kwargs.get("on_retry_callback")
 
     @on_retry_callback.setter
-    def on_retry_callback(self, value: TaskStateChangeCallback | None) -> None:
+    def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> 
None:
         self.partial_kwargs["on_retry_callback"] = value
 
     @property
-    def on_success_callback(self) -> None | TaskStateChangeCallback | 
list[TaskStateChangeCallback]:
+    def on_success_callback(self) -> TaskStateChangeCallbackAttrType:
         return self.partial_kwargs.get("on_success_callback")
 
     @on_success_callback.setter
-    def on_success_callback(self, value: TaskStateChangeCallback | None) -> 
None:
+    def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> 
None:
         self.partial_kwargs["on_success_callback"] = value
 
     @property
-    def on_skipped_callback(self) -> None | TaskStateChangeCallback | 
list[TaskStateChangeCallback]:
+    def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType:
         return self.partial_kwargs.get("on_skipped_callback")
 
     @on_skipped_callback.setter
-    def on_skipped_callback(self, value: TaskStateChangeCallback | None) -> 
None:
+    def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> 
None:
         self.partial_kwargs["on_skipped_callback"] = value
 
     @property
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 82c4a6bccc..f2d4aed890 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -1130,7 +1130,6 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
                 task_group=None,
                 start_date=None,
                 end_date=None,
-                map_index_template=None,
                 
disallow_kwargs_override=encoded_op["_disallow_kwargs_override"],
                 expand_input_attr=encoded_op["_expand_input_attr"],
             )

Reply via email to