This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-4-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 7e9fd34cd07ec0f26c7a72589e327f48389771ed Author: Tzu-ping Chung <uranu...@gmail.com> AuthorDate: Wed Sep 14 18:01:11 2022 +0800 Handle list when serializing expand_kwargs (#26369) (cherry picked from commit b816a6b243d16da87ca00e443619c75e9f6f5816) --- airflow/serialization/serialized_objects.py | 45 +++++++++++++++++++-- tests/serialization/test_dag_serialization.py | 57 ++++++++++++++++++++++++++- 2 files changed, 98 insertions(+), 4 deletions(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index fb298cc79e..969b6014db 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -17,6 +17,7 @@ """Serialized DAG and BaseOperator""" from __future__ import annotations +import collections.abc import datetime import enum import logging @@ -24,7 +25,7 @@ import warnings import weakref from dataclasses import dataclass from inspect import Parameter, signature -from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Type +from typing import TYPE_CHECKING, Any, Collection, Iterable, Mapping, NamedTuple, Type, Union import cattr import lazy_object_proxy @@ -207,6 +208,26 @@ class _XComRef(NamedTuple): return deserialize_xcom_arg(self.data, dag) +# These two should be kept in sync. Note that these are intentionally not using +# the type declarations in expandinput.py so we always remember to update +# serialization logic when adding new ExpandInput variants. If you add things to +# the unions, be sure to update _ExpandInputRef to match. +_ExpandInputOriginalValue = Union[ + # For .expand(**kwargs). + Mapping[str, Any], + # For expand_kwargs(arg). + XComArg, + Collection[Union[XComArg, Mapping[str, Any]]], +] +_ExpandInputSerializedValue = Union[ + # For .expand(**kwargs). + Mapping[str, Any], + # For expand_kwargs(arg). + _XComRef, + Collection[Union[_XComRef, Mapping[str, Any]]], +] + + class _ExpandInputRef(NamedTuple): """Used to store info needed to create a mapped operator's expand input. @@ -215,13 +236,29 @@ class _ExpandInputRef(NamedTuple): """ key: str - value: _XComRef | dict[str, Any] + value: _ExpandInputSerializedValue + + @classmethod + def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None: + """Validate we've covered all ``ExpandInput.value`` types. + + This function does not actually do anything, but is called during + serialization so Mypy will *statically* check we have handled all + possible ExpandInput cases. + """ def deref(self, dag: DAG) -> ExpandInput: + """De-reference into a concrete ExpandInput object. + + If you add more cases here, be sure to update _ExpandInputOriginalValue + and _ExpandInputSerializedValue to match the logic. + """ if isinstance(self.value, _XComRef): value: Any = self.value.deref(dag) - else: + elif isinstance(self.value, collections.abc.Mapping): value = {k: v.deref(dag) if isinstance(v, _XComRef) else v for k, v in self.value.items()} + else: + value = [v.deref(dag) if isinstance(v, _XComRef) else v for v in self.value] return create_expand_input(self.key, value) @@ -663,6 +700,8 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): serialized_op = cls._serialize_node(op, include_deps=op.deps != MappedOperator.deps_for(BaseOperator)) # Handle expand_input and op_kwargs_expand_input. expansion_kwargs = op._get_specified_expand_input() + if TYPE_CHECKING: # Let Mypy check the input type for us! + _ExpandInputRef.validate_expand_input_value(expansion_kwargs.value) serialized_op[op._expand_input_attr] = { "type": get_map_type_key(expansion_kwargs), "value": cls.serialize(expansion_kwargs.value), diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index bd171fd50d..aa409183e8 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -1961,7 +1961,62 @@ def test_operator_expand_xcomarg_serde(): @pytest.mark.parametrize("strict", [True, False]) -def test_operator_expand_kwargs_serde(strict): +def test_operator_expand_kwargs_literal_serde(strict): + from airflow.models.xcom_arg import PlainXComArg, XComArg + from airflow.serialization.serialized_objects import _XComRef + + with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag: + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id='task_2').expand_kwargs( + [{"a": "x"}, {"a": XComArg(task1)}], + strict=strict, + ) + + serialized = SerializedBaseOperator.serialize(mapped) + assert serialized == { + '_is_empty': False, + '_is_mapped': True, + '_task_module': 'tests.test_utils.mock_operators', + '_task_type': 'MockOperator', + 'downstream_task_ids': [], + 'expand_input': { + "type": "list-of-dicts", + "value": [ + {"__type": "dict", "__var": {"a": "x"}}, + { + "__type": "dict", + "__var": {"a": {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}}}, + }, + ], + }, + 'partial_kwargs': {}, + 'task_id': 'task_2', + 'template_fields': ['arg1', 'arg2'], + 'template_ext': [], + 'template_fields_renderers': {}, + 'operator_extra_links': [], + 'ui_color': '#fff', + 'ui_fgcolor': '#000', + "_disallow_kwargs_override": strict, + '_expand_input_attr': 'expand_input', + } + + op = SerializedBaseOperator.deserialize_operator(serialized) + assert op.deps is MappedOperator.deps_for(BaseOperator) + assert op._disallow_kwargs_override == strict + + # The XComArg can't be deserialized before the DAG is. + expand_value = op.expand_input.value + assert expand_value == [{"a": "x"}, {"a": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}] + + serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + resolved_expand_value = serialized_dag.task_dict['task_2'].expand_input.value + resolved_expand_value == [{"a": "x"}, {"a": PlainXComArg(serialized_dag.task_dict['op1'])}] + + +@pytest.mark.parametrize("strict", [True, False]) +def test_operator_expand_kwargs_xcomarg_serde(strict): from airflow.models.xcom_arg import PlainXComArg, XComArg from airflow.serialization.serialized_objects import _XComRef