ashb commented on code in PR #45627:
URL: https://github.com/apache/airflow/pull/45627#discussion_r1914641356


##########
task_sdk/src/airflow/sdk/definitions/mappedoperator.py:
##########
@@ -0,0 +1,898 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import collections.abc
+import contextlib
+import copy
+import warnings
+from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar, Union
+
+import attr
+import methodtools
+
+from airflow.exceptions import UnmappableOperator
+from airflow.models.abstractoperator import NotMapped
+from airflow.models.expandinput import (
+    DictOfListsExpandInput,
+    ListOfDictsExpandInput,
+    is_mappable,
+)
+from airflow.models.pool import Pool
+from airflow.sdk.definitions._internal.abstractoperator import (
+    DEFAULT_EXECUTOR,
+    DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
+    DEFAULT_OWNER,
+    DEFAULT_POOL_SLOTS,
+    DEFAULT_PRIORITY_WEIGHT,
+    DEFAULT_QUEUE,
+    DEFAULT_RETRIES,
+    DEFAULT_RETRY_DELAY,
+    DEFAULT_TRIGGER_RULE,
+    DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
+    DEFAULT_WEIGHT_RULE,
+    AbstractOperator,
+)
+from airflow.serialization.enums import DagAttributeTypes
+from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
+from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
+from airflow.triggers.base import StartTriggerArgs
+from airflow.typing_compat import Literal
+from airflow.utils.context import context_update_for_unmapped
+from airflow.utils.helpers import is_container, prevent_duplicates
+from airflow.utils.task_instance_session import get_current_task_instance_session
+from airflow.utils.types import NOTSET
+from airflow.utils.xcom import XCOM_RETURN_KEY
+
+if TYPE_CHECKING:
+    import datetime
+
+    import jinja2  # Slow import.
+    import pendulum
+    from sqlalchemy.orm.session import Session
+
+    from airflow.models.abstractoperator import (
+        TaskStateChangeCallback,
+    )
+    from airflow.models.baseoperatorlink import BaseOperatorLink
+    from airflow.models.expandinput import (
+        ExpandInput,
+        OperatorExpandArgument,
+        OperatorExpandKwargsArgument,
+    )
+    from airflow.models.param import ParamsDict
+    from airflow.models.xcom_arg import XComArg
+    from airflow.sdk.definitions.abstractoperator import Operator
+    from airflow.sdk.definitions.baseoperator import BaseOperator
+    from airflow.sdk.definitions.dag import DAG
+    from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
+    from airflow.utils.context import Context
+    from airflow.utils.operator_resources import Resources
+    from airflow.utils.task_group import TaskGroup
+    from airflow.utils.trigger_rule import TriggerRule
+
+    TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]]
+
+ValidationSource = Union[Literal["expand"], Literal["partial"]]
+
+
+def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
+    # use a dict so order of args is same as code order
+    unknown_args = value.copy()
+    for klass in op.mro():
+        init = klass.__init__  # type: ignore[misc]
+        try:
+            param_names = init._BaseOperatorMeta__param_names
+        except AttributeError:
+            continue
+        for name in param_names:
+            value = unknown_args.pop(name, NOTSET)
+            if func != "expand":
+                continue
+            if value is NOTSET:
+                continue
+            if is_mappable(value):
+                continue
+            type_name = type(value).__name__
+            error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
+            raise ValueError(error)
+        if not unknown_args:
+            return  # If we have no args left to check: stop looking at the MRO chain.
+
+    if len(unknown_args) == 1:
+        error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
+    else:
+        names = ", ".join(repr(n) for n in unknown_args)
+        error = f"unexpected keyword arguments {names}"
+    raise TypeError(f"{op.__name__}.{func}() got {error}")
+
+
+def ensure_xcomarg_return_value(arg: Any) -> None:
+    from airflow.sdk.definitions.xcom_arg import XComArg
+
+    if isinstance(arg, XComArg):
+        for operator, key in arg.iter_references():
+            if key != XCOM_RETURN_KEY:
+                raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
+    elif not is_container(arg):
+        return
+    elif isinstance(arg, collections.abc.Mapping):
+        for v in arg.values():
+            ensure_xcomarg_return_value(v)
+    elif isinstance(arg, collections.abc.Iterable):
+        for v in arg:
+            ensure_xcomarg_return_value(v)
+
+
+@attr.define(kw_only=True, repr=False)
+class OperatorPartial:
+    """
+    An "intermediate state" returned by ``BaseOperator.partial()``.
+
+    This only exists at DAG-parsing time; the only intended usage is for the
+    user to call ``.expand()`` on it at some point (usually in a method chain) to
+    create a ``MappedOperator`` to add into the DAG.
+    """
+
+    operator_class: type[BaseOperator]
+    kwargs: dict[str, Any]
+    params: ParamsDict | dict
+
+    _expand_called: bool = False  # Set when expand() is called to ease user debugging.
+
+    def __attrs_post_init__(self):
+        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)
+
+    def __repr__(self) -> str:
+        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
+        return f"{self.operator_class.__name__}.partial({args})"
+
+    def __del__(self):
+        if not self._expand_called:
+            try:
+                task_id = repr(self.kwargs["task_id"])
+            except KeyError:
+                task_id = f"at {hex(id(self))}"
+            warnings.warn(f"Task {task_id} was never mapped!", category=UserWarning, stacklevel=1)
+
+    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
+        if not mapped_kwargs:
+            raise TypeError("no arguments to expand against")
+        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
+        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
+        # Since the input is already checked at parse time, we can set strict
+        # to False to skip the checks on execution.
+        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)
+
+    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
+        from airflow.models.xcom_arg import XComArg
+
+        if isinstance(kwargs, collections.abc.Sequence):
+            for item in kwargs:
+                if not isinstance(item, (XComArg, collections.abc.Mapping)):
+                    raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
+        elif not isinstance(kwargs, XComArg):
+            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
+        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)
+
+    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
+        from airflow.operators.empty import EmptyOperator
+
+        self._expand_called = True
+        ensure_xcomarg_return_value(expand_input.value)
+
+        partial_kwargs = self.kwargs.copy()
+        task_id = partial_kwargs.pop("task_id")
+        dag = partial_kwargs.pop("dag")
+        task_group = partial_kwargs.pop("task_group")
+        start_date = partial_kwargs.pop("start_date", None)
+        end_date = partial_kwargs.pop("end_date", None)
+
+        try:
+            operator_name = self.operator_class.custom_operator_name  # type: ignore
+        except AttributeError:
+            operator_name = self.operator_class.__name__
+
+        op = MappedOperator(
+            operator_class=self.operator_class,
+            expand_input=expand_input,
+            partial_kwargs=partial_kwargs,
+            task_id=task_id,
+            params=self.params,
+            deps=MappedOperator.deps_for(self.operator_class),
+            operator_extra_links=self.operator_class.operator_extra_links,
+            template_ext=self.operator_class.template_ext,
+            template_fields=self.operator_class.template_fields,
+            template_fields_renderers=self.operator_class.template_fields_renderers,
+            ui_color=self.operator_class.ui_color,
+            ui_fgcolor=self.operator_class.ui_fgcolor,
+            is_empty=issubclass(self.operator_class, EmptyOperator),
+            task_module=self.operator_class.__module__,
+            task_type=self.operator_class.__name__,
+            operator_name=operator_name,
+            dag=dag,
+            task_group=task_group,
+            start_date=start_date,
+            end_date=end_date,
+            disallow_kwargs_override=strict,
+            # For classic operators, this points to expand_input because kwargs
+            # to BaseOperator.expand() contribute to operator arguments.
+            expand_input_attr="expand_input",
+            start_trigger_args=self.operator_class.start_trigger_args,
+            start_from_trigger=self.operator_class.start_from_trigger,
+        )
+        return op
+
+
+@attr.define(
+    kw_only=True,
+    # Disable custom __getstate__ and __setstate__ generation since it interacts
+    # badly with Airflow's DAG serialization and pickling. When a mapped task is
+    # deserialized, subclasses are coerced into MappedOperator, but when it goes
+    # through DAG pickling, all attributes defined in the subclasses are dropped
+    # by attrs's custom state management. Since attrs does not do anything too
+    # special here (the logic is only important for slots=True), we use Python's
+    # built-in implementation, which works (as proven by good old BaseOperator).
+    getstate_setstate=False,
+)
+class MappedOperator(AbstractOperator):
+    """Object representing a mapped operator in a DAG."""
+
+    # This attribute serves double purpose. For a "normal" operator instance
+    # loaded from DAG, this holds the underlying non-mapped operator class that
+    # can be used to create an unmapped operator for execution. For an operator
+    # recreated from a serialized DAG, however, this holds the serialized data
+    # that can be used to unmap this into a SerializedBaseOperator.
+    operator_class: type[BaseOperator] | dict[str, Any]
+
+    expand_input: ExpandInput
+    partial_kwargs: dict[str, Any]
+
+    # Needed for serialization.
+    task_id: str
+    params: ParamsDict | dict
+    deps: frozenset[BaseTIDep]
+    operator_extra_links: Collection[BaseOperatorLink]
+    template_ext: Sequence[str]
+    template_fields: Collection[str]
+    template_fields_renderers: dict[str, str]
+    ui_color: str
+    ui_fgcolor: str
+    _is_empty: bool
+    _task_module: str
+    _task_type: str
+    _operator_name: str
+    start_trigger_args: StartTriggerArgs | None
+    start_from_trigger: bool
+    _needs_expansion: bool = True
+
+    dag: DAG | None
+    task_group: TaskGroup | None
+    start_date: pendulum.DateTime | None
+    end_date: pendulum.DateTime | None
+    upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
+    downstream_task_ids: set[str] = attr.ib(factory=set, init=False)

Review Comment:
   I think I'll make this change here in this PR, since I'm already changing these, and git shows this as a new file anyway.



##########
task_sdk/src/airflow/sdk/definitions/mappedoperator.py:
##########
@@ -0,0 +1,898 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import collections.abc
+import contextlib
+import copy
+import warnings
+from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar, Union
+
+import attr
+import methodtools
+
+from airflow.exceptions import UnmappableOperator
+from airflow.models.abstractoperator import NotMapped
+from airflow.models.expandinput import (
+    DictOfListsExpandInput,
+    ListOfDictsExpandInput,
+    is_mappable,
+)
+from airflow.models.pool import Pool

Review Comment:
   Good catch



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to