This is an automated email from the ASF dual-hosted git repository. bolke pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new a041a2a3671 Enable Serde for Pydantic BaseModel and Subclasses (#51059) a041a2a3671 is described below commit a041a2a3671337bbf743ec671542cf35e1e6c42f Author: Kevin Yang <85313829+sjyangke...@users.noreply.github.com> AuthorDate: Thu Jun 26 07:45:59 2025 -0400 Enable Serde for Pydantic BaseModel and Subclasses (#51059) This adds serialization and deserialization support for arbitrary pydantic objects, while still maintaining security. --------- Co-authored-by: Tzu-ping Chung <uranu...@gmail.com> --- airflow-core/src/airflow/serialization/serde.py | 13 ++- .../airflow/serialization/serializers/bignum.py | 8 +- .../airflow/serialization/serializers/builtin.py | 12 +-- .../airflow/serialization/serializers/datetime.py | 12 +-- .../airflow/serialization/serializers/deltalake.py | 6 +- .../airflow/serialization/serializers/iceberg.py | 6 +- .../src/airflow/serialization/serializers/numpy.py | 10 +- .../airflow/serialization/serializers/pandas.py | 13 ++- .../airflow/serialization/serializers/pydantic.py | 79 ++++++++++++++ .../airflow/serialization/serializers/timezone.py | 6 +- airflow-core/src/airflow/serialization/typing.py | 32 ++++++ .../serialization/serializers/test_serializers.py | 115 +++++++++++++++++---- .../tests/unit/serialization/test_serde.py | 15 +++ 13 files changed, 271 insertions(+), 56 deletions(-) diff --git a/airflow-core/src/airflow/serialization/serde.py b/airflow-core/src/airflow/serialization/serde.py index 0268ad91206..0f1e948d8db 100644 --- a/airflow-core/src/airflow/serialization/serde.py +++ b/airflow-core/src/airflow/serialization/serde.py @@ -32,6 +32,7 @@ import attr import airflow.serialization.serializers from airflow.configuration import conf +from airflow.serialization.typing import is_pydantic_model from airflow.stats import Stats from airflow.utils.module_loading import import_string, iter_namespace, qualname @@ -52,6 +53,7 @@ OLD_TYPE = "__type" OLD_SOURCE = "__source" OLD_DATA = "__var" OLD_DICT = "dict" +PYDANTIC_MODEL_QUALNAME = "pydantic.main.BaseModel" DEFAULT_VERSION = 0 @@ -145,6 +147,12 @@ def serialize(o: object, depth: int = 0) -> U | None: qn = "builtins.tuple" classname = qn + if is_pydantic_model(o): + # to match the generic Pydantic serializer and deserializer in _serializers and _deserializers + qn = PYDANTIC_MODEL_QUALNAME + # the actual Pydantic model class to encode + classname = qualname(o) + # if there is a builtin serializer available use that if qn in _serializers: data, serialized_classname, version, is_serialized = _serializers[qn].serialize(o) @@ -256,7 +264,10 @@ def deserialize(o: T | None, full=True, type_hint: Any = None) -> object: # registered deserializer if classname in _deserializers: - return _deserializers[classname].deserialize(classname, version, deserialize(value)) + return _deserializers[classname].deserialize(cls, version, deserialize(value)) + if is_pydantic_model(cls): + if PYDANTIC_MODEL_QUALNAME in _deserializers: + return _deserializers[PYDANTIC_MODEL_QUALNAME].deserialize(cls, version, deserialize(value)) # class has deserialization function if hasattr(cls, "deserialize"): diff --git a/airflow-core/src/airflow/serialization/serializers/bignum.py b/airflow-core/src/airflow/serialization/serializers/bignum.py index 769e78491e9..5bb89cb386c 100644 --- a/airflow-core/src/airflow/serialization/serializers/bignum.py +++ b/airflow-core/src/airflow/serialization/serializers/bignum.py @@ -47,13 +47,13 @@ def serialize(o: object) -> tuple[U, str, int, bool]: return float(o), name, __version__, True -def deserialize(classname: str, version: int, data: object) -> decimal.Decimal: +def deserialize(cls: type, version: int, data: object) -> decimal.Decimal: from decimal import Decimal if version > __version__: - raise TypeError(f"serialized {version} of {classname} > {__version__}") + raise TypeError(f"serialized {version} of {qualname(cls)} > {__version__}") - if classname != qualname(Decimal): - raise TypeError(f"{classname} != {qualname(Decimal)}") + if cls is not Decimal: + raise TypeError(f"do not know how to deserialize {qualname(cls)}") return Decimal(str(data)) diff --git a/airflow-core/src/airflow/serialization/serializers/builtin.py b/airflow-core/src/airflow/serialization/serializers/builtin.py index b0ee8cb713d..076831a05da 100644 --- a/airflow-core/src/airflow/serialization/serializers/builtin.py +++ b/airflow-core/src/airflow/serialization/serializers/builtin.py @@ -35,20 +35,20 @@ def serialize(o: object) -> tuple[U, str, int, bool]: return list(cast("list", o)), qualname(o), __version__, True -def deserialize(classname: str, version: int, data: list) -> tuple | set | frozenset: +def deserialize(cls: type, version: int, data: list) -> tuple | set | frozenset: if version > __version__: - raise TypeError("serialized version is newer than class version") + raise TypeError(f"serialized version {version} is newer than class version {__version__}") - if classname == qualname(tuple): + if cls is tuple: return tuple(data) - if classname == qualname(set): + if cls is set: return set(data) - if classname == qualname(frozenset): + if cls is frozenset: return frozenset(data) - raise TypeError(f"do not know how to deserialize {classname}") + raise TypeError(f"do not know how to deserialize {qualname(cls)}") def stringify(classname: str, version: int, data: list) -> str: diff --git a/airflow-core/src/airflow/serialization/serializers/datetime.py b/airflow-core/src/airflow/serialization/serializers/datetime.py index 69058b8c02a..38009bbc468 100644 --- a/airflow-core/src/airflow/serialization/serializers/datetime.py +++ b/airflow-core/src/airflow/serialization/serializers/datetime.py @@ -59,7 +59,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]: return "", "", 0, False -def deserialize(classname: str, version: int, data: dict | str) -> datetime.date | datetime.timedelta: +def deserialize(cls: type, version: int, data: dict | str) -> datetime.date | datetime.timedelta: import datetime from pendulum import DateTime @@ -86,16 +86,16 @@ def deserialize(classname: str, version: int, data: dict | str) -> datetime.date else None ) - if classname == qualname(datetime.datetime) and isinstance(data, dict): + if cls is datetime.datetime and isinstance(data, dict): return datetime.datetime.fromtimestamp(float(data[TIMESTAMP]), tz=tz) - if classname == qualname(DateTime) and isinstance(data, dict): + if cls is DateTime and isinstance(data, dict): return DateTime.fromtimestamp(float(data[TIMESTAMP]), tz=tz) - if classname == qualname(datetime.timedelta) and isinstance(data, (str, float)): + if cls is datetime.timedelta and isinstance(data, (str, float)): return datetime.timedelta(seconds=float(data)) - if classname == qualname(datetime.date) and isinstance(data, str): + if cls is datetime.date and isinstance(data, str): return datetime.date.fromisoformat(data) - raise TypeError(f"unknown date/time format {classname}") + raise TypeError(f"unknown date/time format {qualname(cls)}") diff --git a/airflow-core/src/airflow/serialization/serializers/deltalake.py b/airflow-core/src/airflow/serialization/serializers/deltalake.py index 60456baf800..a79b2317881 100644 --- a/airflow-core/src/airflow/serialization/serializers/deltalake.py +++ b/airflow-core/src/airflow/serialization/serializers/deltalake.py @@ -55,7 +55,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]: return data, qualname(o), __version__, True -def deserialize(classname: str, version: int, data: dict): +def deserialize(cls: type, version: int, data: dict): from deltalake.table import DeltaTable from airflow.models.crypto import get_fernet @@ -63,7 +63,7 @@ def deserialize(classname: str, version: int, data: dict): if version > __version__: raise TypeError("serialized version is newer than class version") - if classname == qualname(DeltaTable): + if cls is DeltaTable: fernet = get_fernet() properties = {} for k, v in data["storage_options"].items(): @@ -76,4 +76,4 @@ def deserialize(classname: str, version: int, data: dict): return DeltaTable(data["table_uri"], version=data["version"], storage_options=storage_options) - raise TypeError(f"do not know how to deserialize {classname}") + raise TypeError(f"do not know how to deserialize {qualname(cls)}") diff --git a/airflow-core/src/airflow/serialization/serializers/iceberg.py b/airflow-core/src/airflow/serialization/serializers/iceberg.py index 3b03381fef3..018732c29fe 100644 --- a/airflow-core/src/airflow/serialization/serializers/iceberg.py +++ b/airflow-core/src/airflow/serialization/serializers/iceberg.py @@ -55,7 +55,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]: return data, qualname(o), __version__, True -def deserialize(classname: str, version: int, data: dict): +def deserialize(cls: type, version: int, data: dict): from pyiceberg.catalog import load_catalog from pyiceberg.table import Table @@ -64,7 +64,7 @@ def deserialize(classname: str, version: int, data: dict): if version > __version__: raise TypeError("serialized version is newer than class version") - if classname == qualname(Table): + if cls is Table: fernet = get_fernet() properties = {} for k, v in data["catalog_properties"].items(): @@ -73,4 +73,4 @@ def deserialize(classname: str, version: int, data: dict): catalog = load_catalog(data["identifier"][0], **properties) return catalog.load_table((data["identifier"][1], data["identifier"][2])) - raise TypeError(f"do not know how to deserialize {classname}") + raise TypeError(f"do not know how to deserialize {qualname(cls)}") diff --git a/airflow-core/src/airflow/serialization/serializers/numpy.py b/airflow-core/src/airflow/serialization/serializers/numpy.py index c31244c5878..35620692e45 100644 --- a/airflow-core/src/airflow/serialization/serializers/numpy.py +++ b/airflow-core/src/airflow/serialization/serializers/numpy.py @@ -80,11 +80,13 @@ def serialize(o: object) -> tuple[U, str, int, bool]: return "", "", 0, False -def deserialize(classname: str, version: int, data: str) -> Any: +def deserialize(cls: type, version: int, data: str) -> Any: if version > __version__: raise TypeError("serialized version is newer than class version") - if classname not in deserializers: - raise TypeError(f"unsupported {classname} found for numpy deserialization") + allowed_deserialize_classes = [import_string(classname) for classname in deserializers] - return import_string(classname)(data) + if cls not in allowed_deserialize_classes: + raise TypeError(f"unsupported {qualname(cls)} found for numpy deserialization") + + return cls(data) diff --git a/airflow-core/src/airflow/serialization/serializers/pandas.py b/airflow-core/src/airflow/serialization/serializers/pandas.py index d805e4b95c0..73f64ce86b4 100644 --- a/airflow-core/src/airflow/serialization/serializers/pandas.py +++ b/airflow-core/src/airflow/serialization/serializers/pandas.py @@ -53,17 +53,22 @@ def serialize(o: object) -> tuple[U, str, int, bool]: return buf.getvalue().hex().decode("utf-8"), qualname(o), __version__, True -def deserialize(classname: str, version: int, data: object) -> pd.DataFrame: +def deserialize(cls: type, version: int, data: object) -> pd.DataFrame: if version > __version__: - raise TypeError(f"serialized {version} of {classname} > {__version__}") + raise TypeError(f"serialized {version} of {qualname(cls)} > {__version__}") - from pyarrow import parquet as pq + import pandas as pd + + if cls is not pd.DataFrame: + raise TypeError(f"do not know how to deserialize {qualname(cls)}") if not isinstance(data, str): - raise TypeError(f"serialized {classname} has wrong data type {type(data)}") + raise TypeError(f"serialized {qualname(cls)} has wrong data type {type(data)}") from io import BytesIO + from pyarrow import parquet as pq + with BytesIO(bytes.fromhex(data)) as buf: df = pq.read_table(buf).to_pandas() diff --git a/airflow-core/src/airflow/serialization/serializers/pydantic.py b/airflow-core/src/airflow/serialization/serializers/pydantic.py new file mode 100644 index 00000000000..b1a483a152a --- /dev/null +++ b/airflow-core/src/airflow/serialization/serializers/pydantic.py @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +from airflow.serialization.typing import is_pydantic_model +from airflow.utils.module_loading import qualname + +if TYPE_CHECKING: + from pydantic import BaseModel + + from airflow.serialization.serde import U + +serializers = [ + "pydantic.main.BaseModel", +] +deserializers = serializers + +__version__ = 1 + + +def serialize(o: object) -> tuple[U, str, int, bool]: + """ + Serialize a Pydantic BaseModel instance into a dict of built-in types. + + Returns a tuple of: + - serialized data (as built-in types) + - fixed class name for registration (BaseModel) + - version number + - is_serialized flag (True if handled) + """ + if not is_pydantic_model(o): + return "", "", 0, False + + model = cast("BaseModel", o) # for mypy + data = model.model_dump() + + return data, qualname(o), __version__, True + + +def deserialize(cls: type, version: int, data: dict): + """ + Deserialize a Pydantic class. + + Pydantic models can be serialized into a Python dictionary via `pydantic.main.BaseModel.model_dump` + and the dictionary can be deserialized through `pydantic.main.BaseModel.model_validate`. This function + can deserialize arbitrary Pydantic models that are in `allowed_deserialization_classes`. + + :param cls: The actual model class + :param version: Serialization version (must not exceed __version__) + :param data: Dictionary with built-in types, typically from model_dump() + :return: An instance of the actual Pydantic model + """ + if version > __version__: + raise TypeError(f"Serialized version {version} is newer than the supported version {__version__}") + + if not is_pydantic_model(cls): + # no deserializer available + raise TypeError(f"No deserializer found for {qualname(cls)}") + + # Perform validation-based reconstruction + model = cast("BaseModel", cls) # for mypy + return model.model_validate(data) diff --git a/airflow-core/src/airflow/serialization/serializers/timezone.py b/airflow-core/src/airflow/serialization/serializers/timezone.py index 9f2ef7cef65..8bc067bc2ff 100644 --- a/airflow-core/src/airflow/serialization/serializers/timezone.py +++ b/airflow-core/src/airflow/serialization/serializers/timezone.py @@ -67,16 +67,16 @@ def serialize(o: object) -> tuple[U, str, int, bool]: return "", "", 0, False -def deserialize(classname: str, version: int, data: object) -> Any: +def deserialize(cls: type, version: int, data: object) -> Any: from airflow.utils.timezone import parse_timezone if not isinstance(data, (str, int)): raise TypeError(f"{data} is not of type int or str but of {type(data)}") if version > __version__: - raise TypeError(f"serialized {version} of {classname} > {__version__}") + raise TypeError(f"serialized {version} of {qualname(cls)} > {__version__}") - if classname == "backports.zoneinfo.ZoneInfo" and isinstance(data, str): + if qualname(cls) == "backports.zoneinfo.ZoneInfo" and isinstance(data, str): from zoneinfo import ZoneInfo return ZoneInfo(data) diff --git a/airflow-core/src/airflow/serialization/typing.py b/airflow-core/src/airflow/serialization/typing.py new file mode 100644 index 00000000000..a6169b23a78 --- /dev/null +++ b/airflow-core/src/airflow/serialization/typing.py @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any + + +def is_pydantic_model(cls: Any) -> bool: + """ + Return True if the class is a pydantic.main.BaseModel. + + Checking is done by attributes as it is significantly faster than + using isinstance. + """ + # __pydantic_fields__ is always present on Pydantic V2 models and is a dict[str, FieldInfo] + # __pydantic_validator__ is an internal validator object, always set after model build + return hasattr(cls, "__pydantic_fields__") and hasattr(cls, "__pydantic_validator__") diff --git a/airflow-core/tests/unit/serialization/serializers/test_serializers.py b/airflow-core/tests/unit/serialization/serializers/test_serializers.py index 53eb160ec1a..b3d1507f1e4 100644 --- a/airflow-core/tests/unit/serialization/serializers/test_serializers.py +++ b/airflow-core/tests/unit/serialization/serializers/test_serializers.py @@ -23,6 +23,7 @@ from unittest.mock import patch from zoneinfo import ZoneInfo import numpy as np +import pandas as pd import pendulum import pendulum.tz import pytest @@ -31,11 +32,11 @@ from kubernetes.client import models as k8s from packaging import version from pendulum import DateTime from pendulum.tz.timezone import FixedTimezone, Timezone +from pydantic import BaseModel, Field from airflow.sdk.definitions.param import Param, ParamsDict from airflow.serialization.serde import CLASSNAME, DATA, VERSION, _stringify, decode, deserialize, serialize from airflow.serialization.serializers import builtin -from airflow.utils.module_loading import qualname from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker @@ -60,6 +61,13 @@ class NoNameTZ(datetime.tzinfo): return datetime.timedelta(hours=2) +class FooBarModel(BaseModel): + """Pydantic BaseModel for testing Pydantic Serialization/Deserialization.""" + + banana: float = 1.1 + foo: str = Field() + + @skip_if_force_lowest_dependencies_marker class TestSerializers: def test_datetime(self): @@ -190,20 +198,26 @@ class TestSerializers: assert serialize(12345) == ("", "", 0, False) + def test_bignum_deserialize_decimal(self): + from airflow.serialization.serializers.bignum import deserialize + + res = deserialize(decimal.Decimal, 1, decimal.Decimal(12345)) + assert res == decimal.Decimal(12345) + @pytest.mark.parametrize( ("klass", "version", "payload", "msg"), [ ( - "decimal.Decimal", + decimal.Decimal, 999, "0", - r"serialized 999 of decimal\.Decimal", # newer version + r"serialized 999 of decimal\.Decimal > 1", # newer version ), ( - "wrong.ClassName", + str, 1, "0", - r"wrong\.ClassName != .*Decimal", # wrong classname + r"do not know how to deserialize builtins\.str", # wrong classname ), ], ) @@ -235,8 +249,8 @@ class TestSerializers: @pytest.mark.parametrize( ("klass", "ver", "value", "msg"), [ - ("numpy.int32", 999, 123, r"serialized version is newer"), - ("numpy.float32", 1, 123, r"unsupported numpy\.float32"), + (np.int32, 999, 123, r"serialized version is newer"), + (np.float32, 1, 123, r"unsupported numpy\.float32"), ], ) def test_numpy_deserialize_errors(self, klass, ver, value, msg): @@ -265,17 +279,23 @@ class TestSerializers: assert serialize(123) == ("", "", 0, False) @pytest.mark.parametrize( - ("version", "data", "msg"), + ("klass", "version", "data", "msg"), [ - (999, "", r"serialized 999 .* > 1"), # version too new - (1, 123, r"wrong data type .*<class 'int'>"), # bad payload type + (pd.DataFrame, 999, "", r"serialized 999 of pandas.core.frame.DataFrame > 1"), # version too new + ( + pd.DataFrame, + 1, + 123, + r"serialized pandas.core.frame.DataFrame has wrong data type .*<class 'int'>", + ), # bad payload type + (str, 1, "", r"do not know how to deserialize builtins.str"), # bad class ], ) - def test_pandas_deserialize_errors(self, version, data, msg): + def test_pandas_deserialize_errors(self, klass, version, data, msg): from airflow.serialization.serializers.pandas import deserialize with pytest.raises(TypeError, match=msg): - deserialize("pandas.core.frame.DataFrame", version, data) + deserialize(klass, version, data) def test_iceberg(self): pytest.importorskip("pyiceberg", minversion="2.0.0") @@ -367,6 +387,32 @@ class TestSerializers: assert serialize(pod) == ("", "", 0, False) assert serialize(123) == ("", "", 0, False) + def test_pydantic(self): + m = FooBarModel(banana=3.14, foo="hello") + e = serialize(m) + d = deserialize(e) + + assert m.banana == d.banana + assert m.foo == d.foo + + @pytest.mark.parametrize( + "klass, version, data, msg", + [ + ( + FooBarModel, + 999, + FooBarModel(banana=3.14, foo="hello"), + "Serialized version 999 is newer than the supported version 1", + ), + (str, 1, "", r"No deserializer found for builtins\.str"), + ], + ) + def test_pydantic_deserialize_errors(self, klass, version, data, msg): + from airflow.serialization.serializers.pydantic import deserialize + + with pytest.raises(TypeError, match=msg): + deserialize(klass, version, data) + @pytest.mark.skipif(not PENDULUM3, reason="Test case for pendulum~=3") @pytest.mark.parametrize( "ser_value, expected", @@ -528,15 +574,15 @@ class TestSerializers: def test_timezone_deserialize_zoneinfo(self): from airflow.serialization.serializers.timezone import deserialize - zi = deserialize("backports.zoneinfo.ZoneInfo", 1, "Asia/Taipei") + zi = deserialize(ZoneInfo, 1, "Asia/Taipei") assert isinstance(zi, ZoneInfo) assert zi.key == "Asia/Taipei" @pytest.mark.parametrize( "klass, version, data, msg", [ - ("pendulum.tz.timezone.FixedTimezone", 1, 1.23, "is not of type int or str"), - ("pendulum.tz.timezone.FixedTimezone", 999, "UTC", "serialized 999 .* > 1"), + (FixedTimezone, 1, 1.23, "is not of type int or str"), + (FixedTimezone, 999, "UTC", "serialized 999 .* > 1"), ], ) def test_timezone_deserialize_errors(self, klass, version, data, msg): @@ -570,14 +616,39 @@ class TestSerializers: load_dag_schema_dict() assert "Schema file schema.json does not exists" in str(ctx.value) - def test_builtin_deserialize_frozenset(self): - res = builtin.deserialize(qualname(frozenset), 1, [13, 14]) - assert isinstance(res, frozenset) - assert res == frozenset({13, 14}) + @pytest.mark.parametrize( + "klass, version, data", + [(tuple, 1, [11, 12]), (set, 1, [11, 12]), (frozenset, 1, [11, 12])], + ) + def test_builtin_deserialize(self, klass, version, data): + res = builtin.deserialize(klass, version, klass(data)) + assert isinstance(res, klass) + assert res == klass(data) - def test_builtin_deserialize_version_too_new(self): - with pytest.raises(TypeError, match="serialized version is newer than class version"): - builtin.deserialize(qualname(tuple), 999, [1, 2]) + @pytest.mark.parametrize( + "klass, version, data, msg", + [ + (tuple, 999, [11, 12], r"serialized version 999 is newer than class version 1"), + (set, 2, [11, 12], r"serialized version 2 is newer than class version 1"), + (frozenset, 13, [11, 12], r"serialized version 13 is newer than class version 1"), + ], + ) + def test_builtin_deserialize_version_too_new(self, klass, version, data, msg): + with pytest.raises(TypeError, match=msg): + builtin.deserialize(klass, version, data) + + @pytest.mark.parametrize( + "klass, version, data, msg", + [ + (str, 1, "11, 12", r"do not know how to deserialize builtins\.str"), + (int, 1, 11, r"do not know how to deserialize builtins\.int"), + (bool, 1, True, r"do not know how to deserialize builtins\.bool"), + (float, 1, 0.999, r"do not know how to deserialize builtins\.float"), + ], + ) + def test_builtin_deserialize_wrong_types(self, klass, version, data, msg): + with pytest.raises(TypeError, match=msg): + builtin.deserialize(klass, version, data) @pytest.mark.parametrize( "func, msg", diff --git a/airflow-core/tests/unit/serialization/test_serde.py b/airflow-core/tests/unit/serialization/test_serde.py index 6b716f581fc..560569689cf 100644 --- a/airflow-core/tests/unit/serialization/test_serde.py +++ b/airflow-core/tests/unit/serialization/test_serde.py @@ -259,6 +259,21 @@ class TestSerDe: assert f"{qualname(Z)} was not found in allow list" in str(ex.value) + @conf_vars( + { + ("core", "allowed_deserialization_classes"): "airflow.*", + } + ) + @pytest.mark.usefixtures("recalculate_patterns") + def test_allow_list_for_deserialize_pydantic_model(self): + # for Pydantic model to be deserialized, it must be in `allowed_deserialization_classes` + i = U(x=7, v=V(w=W(x=42), s=["hello", "world"], t=(1, 2, 3), c=99), u=("extra", 123)) + e = serialize(i) + with pytest.raises(ImportError) as ex: + deserialize(e) + + assert f"{qualname(U)} was not found in allow list" in str(ex.value) + @conf_vars( { ("core", "allowed_deserialization_classes"): "unit.airflow.*",