This is an automated email from the ASF dual-hosted git repository.
bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a041a2a3671 Enable Serde for Pydantic BaseModel and Subclasses (#51059)
a041a2a3671 is described below
commit a041a2a3671337bbf743ec671542cf35e1e6c42f
Author: Kevin Yang <[email protected]>
AuthorDate: Thu Jun 26 07:45:59 2025 -0400
Enable Serde for Pydantic BaseModel and Subclasses (#51059)
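This adds serialization and deserialization support for arbitrary Pydantic
models (instances of `pydantic.BaseModel` subclasses) while still maintaining
security: a model is only deserialized if its class matches
`[core] allowed_deserialization_classes`. As part of this change, each
serializer module's `deserialize` function now receives the resolved class
(`cls: type`) instead of its qualified class-name string.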
---------
Co-authored-by: Tzu-ping Chung <[email protected]>
---
airflow-core/src/airflow/serialization/serde.py | 13 ++-
.../airflow/serialization/serializers/bignum.py | 8 +-
.../airflow/serialization/serializers/builtin.py | 12 +--
.../airflow/serialization/serializers/datetime.py | 12 +--
.../airflow/serialization/serializers/deltalake.py | 6 +-
.../airflow/serialization/serializers/iceberg.py | 6 +-
.../src/airflow/serialization/serializers/numpy.py | 10 +-
.../airflow/serialization/serializers/pandas.py | 13 ++-
.../airflow/serialization/serializers/pydantic.py | 79 ++++++++++++++
.../airflow/serialization/serializers/timezone.py | 6 +-
airflow-core/src/airflow/serialization/typing.py | 32 ++++++
.../serialization/serializers/test_serializers.py | 115 +++++++++++++++++----
.../tests/unit/serialization/test_serde.py | 15 +++
13 files changed, 271 insertions(+), 56 deletions(-)
diff --git a/airflow-core/src/airflow/serialization/serde.py
b/airflow-core/src/airflow/serialization/serde.py
index 0268ad91206..0f1e948d8db 100644
--- a/airflow-core/src/airflow/serialization/serde.py
+++ b/airflow-core/src/airflow/serialization/serde.py
@@ -32,6 +32,7 @@ import attr
import airflow.serialization.serializers
from airflow.configuration import conf
+from airflow.serialization.typing import is_pydantic_model
from airflow.stats import Stats
from airflow.utils.module_loading import import_string, iter_namespace,
qualname
@@ -52,6 +53,7 @@ OLD_TYPE = "__type"
OLD_SOURCE = "__source"
OLD_DATA = "__var"
OLD_DICT = "dict"
+PYDANTIC_MODEL_QUALNAME = "pydantic.main.BaseModel"
DEFAULT_VERSION = 0
@@ -145,6 +147,12 @@ def serialize(o: object, depth: int = 0) -> U | None:
qn = "builtins.tuple"
classname = qn
+ if is_pydantic_model(o):
+ # to match the generic Pydantic serializer and deserializer in
_serializers and _deserializers
+ qn = PYDANTIC_MODEL_QUALNAME
+ # the actual Pydantic model class to encode
+ classname = qualname(o)
+
# if there is a builtin serializer available use that
if qn in _serializers:
data, serialized_classname, version, is_serialized =
_serializers[qn].serialize(o)
@@ -256,7 +264,10 @@ def deserialize(o: T | None, full=True, type_hint: Any =
None) -> object:
# registered deserializer
if classname in _deserializers:
- return _deserializers[classname].deserialize(classname, version,
deserialize(value))
+ return _deserializers[classname].deserialize(cls, version,
deserialize(value))
+ if is_pydantic_model(cls):
+ if PYDANTIC_MODEL_QUALNAME in _deserializers:
+ return _deserializers[PYDANTIC_MODEL_QUALNAME].deserialize(cls,
version, deserialize(value))
# class has deserialization function
if hasattr(cls, "deserialize"):
diff --git a/airflow-core/src/airflow/serialization/serializers/bignum.py
b/airflow-core/src/airflow/serialization/serializers/bignum.py
index 769e78491e9..5bb89cb386c 100644
--- a/airflow-core/src/airflow/serialization/serializers/bignum.py
+++ b/airflow-core/src/airflow/serialization/serializers/bignum.py
@@ -47,13 +47,13 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return float(o), name, __version__, True
-def deserialize(classname: str, version: int, data: object) -> decimal.Decimal:
+def deserialize(cls: type, version: int, data: object) -> decimal.Decimal:
from decimal import Decimal
if version > __version__:
- raise TypeError(f"serialized {version} of {classname} > {__version__}")
+ raise TypeError(f"serialized {version} of {qualname(cls)} >
{__version__}")
- if classname != qualname(Decimal):
- raise TypeError(f"{classname} != {qualname(Decimal)}")
+ if cls is not Decimal:
+ raise TypeError(f"do not know how to deserialize {qualname(cls)}")
return Decimal(str(data))
diff --git a/airflow-core/src/airflow/serialization/serializers/builtin.py
b/airflow-core/src/airflow/serialization/serializers/builtin.py
index b0ee8cb713d..076831a05da 100644
--- a/airflow-core/src/airflow/serialization/serializers/builtin.py
+++ b/airflow-core/src/airflow/serialization/serializers/builtin.py
@@ -35,20 +35,20 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return list(cast("list", o)), qualname(o), __version__, True
-def deserialize(classname: str, version: int, data: list) -> tuple | set |
frozenset:
+def deserialize(cls: type, version: int, data: list) -> tuple | set |
frozenset:
if version > __version__:
- raise TypeError("serialized version is newer than class version")
+ raise TypeError(f"serialized version {version} is newer than class
version {__version__}")
- if classname == qualname(tuple):
+ if cls is tuple:
return tuple(data)
- if classname == qualname(set):
+ if cls is set:
return set(data)
- if classname == qualname(frozenset):
+ if cls is frozenset:
return frozenset(data)
- raise TypeError(f"do not know how to deserialize {classname}")
+ raise TypeError(f"do not know how to deserialize {qualname(cls)}")
def stringify(classname: str, version: int, data: list) -> str:
diff --git a/airflow-core/src/airflow/serialization/serializers/datetime.py
b/airflow-core/src/airflow/serialization/serializers/datetime.py
index 69058b8c02a..38009bbc468 100644
--- a/airflow-core/src/airflow/serialization/serializers/datetime.py
+++ b/airflow-core/src/airflow/serialization/serializers/datetime.py
@@ -59,7 +59,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return "", "", 0, False
-def deserialize(classname: str, version: int, data: dict | str) ->
datetime.date | datetime.timedelta:
+def deserialize(cls: type, version: int, data: dict | str) -> datetime.date |
datetime.timedelta:
import datetime
from pendulum import DateTime
@@ -86,16 +86,16 @@ def deserialize(classname: str, version: int, data: dict |
str) -> datetime.date
else None
)
- if classname == qualname(datetime.datetime) and isinstance(data, dict):
+ if cls is datetime.datetime and isinstance(data, dict):
return datetime.datetime.fromtimestamp(float(data[TIMESTAMP]), tz=tz)
- if classname == qualname(DateTime) and isinstance(data, dict):
+ if cls is DateTime and isinstance(data, dict):
return DateTime.fromtimestamp(float(data[TIMESTAMP]), tz=tz)
- if classname == qualname(datetime.timedelta) and isinstance(data, (str,
float)):
+ if cls is datetime.timedelta and isinstance(data, (str, float)):
return datetime.timedelta(seconds=float(data))
- if classname == qualname(datetime.date) and isinstance(data, str):
+ if cls is datetime.date and isinstance(data, str):
return datetime.date.fromisoformat(data)
- raise TypeError(f"unknown date/time format {classname}")
+ raise TypeError(f"unknown date/time format {qualname(cls)}")
diff --git a/airflow-core/src/airflow/serialization/serializers/deltalake.py
b/airflow-core/src/airflow/serialization/serializers/deltalake.py
index 60456baf800..a79b2317881 100644
--- a/airflow-core/src/airflow/serialization/serializers/deltalake.py
+++ b/airflow-core/src/airflow/serialization/serializers/deltalake.py
@@ -55,7 +55,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return data, qualname(o), __version__, True
-def deserialize(classname: str, version: int, data: dict):
+def deserialize(cls: type, version: int, data: dict):
from deltalake.table import DeltaTable
from airflow.models.crypto import get_fernet
@@ -63,7 +63,7 @@ def deserialize(classname: str, version: int, data: dict):
if version > __version__:
raise TypeError("serialized version is newer than class version")
- if classname == qualname(DeltaTable):
+ if cls is DeltaTable:
fernet = get_fernet()
properties = {}
for k, v in data["storage_options"].items():
@@ -76,4 +76,4 @@ def deserialize(classname: str, version: int, data: dict):
return DeltaTable(data["table_uri"], version=data["version"],
storage_options=storage_options)
- raise TypeError(f"do not know how to deserialize {classname}")
+ raise TypeError(f"do not know how to deserialize {qualname(cls)}")
diff --git a/airflow-core/src/airflow/serialization/serializers/iceberg.py
b/airflow-core/src/airflow/serialization/serializers/iceberg.py
index 3b03381fef3..018732c29fe 100644
--- a/airflow-core/src/airflow/serialization/serializers/iceberg.py
+++ b/airflow-core/src/airflow/serialization/serializers/iceberg.py
@@ -55,7 +55,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return data, qualname(o), __version__, True
-def deserialize(classname: str, version: int, data: dict):
+def deserialize(cls: type, version: int, data: dict):
from pyiceberg.catalog import load_catalog
from pyiceberg.table import Table
@@ -64,7 +64,7 @@ def deserialize(classname: str, version: int, data: dict):
if version > __version__:
raise TypeError("serialized version is newer than class version")
- if classname == qualname(Table):
+ if cls is Table:
fernet = get_fernet()
properties = {}
for k, v in data["catalog_properties"].items():
@@ -73,4 +73,4 @@ def deserialize(classname: str, version: int, data: dict):
catalog = load_catalog(data["identifier"][0], **properties)
return catalog.load_table((data["identifier"][1],
data["identifier"][2]))
- raise TypeError(f"do not know how to deserialize {classname}")
+ raise TypeError(f"do not know how to deserialize {qualname(cls)}")
diff --git a/airflow-core/src/airflow/serialization/serializers/numpy.py
b/airflow-core/src/airflow/serialization/serializers/numpy.py
index c31244c5878..35620692e45 100644
--- a/airflow-core/src/airflow/serialization/serializers/numpy.py
+++ b/airflow-core/src/airflow/serialization/serializers/numpy.py
@@ -80,11 +80,13 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return "", "", 0, False
-def deserialize(classname: str, version: int, data: str) -> Any:
+def deserialize(cls: type, version: int, data: str) -> Any:
if version > __version__:
raise TypeError("serialized version is newer than class version")
- if classname not in deserializers:
- raise TypeError(f"unsupported {classname} found for numpy
deserialization")
+ allowed_deserialize_classes = [import_string(classname) for classname in
deserializers]
- return import_string(classname)(data)
+ if cls not in allowed_deserialize_classes:
+ raise TypeError(f"unsupported {qualname(cls)} found for numpy
deserialization")
+
+ return cls(data)
diff --git a/airflow-core/src/airflow/serialization/serializers/pandas.py
b/airflow-core/src/airflow/serialization/serializers/pandas.py
index d805e4b95c0..73f64ce86b4 100644
--- a/airflow-core/src/airflow/serialization/serializers/pandas.py
+++ b/airflow-core/src/airflow/serialization/serializers/pandas.py
@@ -53,17 +53,22 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return buf.getvalue().hex().decode("utf-8"), qualname(o), __version__, True
-def deserialize(classname: str, version: int, data: object) -> pd.DataFrame:
+def deserialize(cls: type, version: int, data: object) -> pd.DataFrame:
if version > __version__:
- raise TypeError(f"serialized {version} of {classname} > {__version__}")
+ raise TypeError(f"serialized {version} of {qualname(cls)} >
{__version__}")
- from pyarrow import parquet as pq
+ import pandas as pd
+
+ if cls is not pd.DataFrame:
+ raise TypeError(f"do not know how to deserialize {qualname(cls)}")
if not isinstance(data, str):
- raise TypeError(f"serialized {classname} has wrong data type
{type(data)}")
+ raise TypeError(f"serialized {qualname(cls)} has wrong data type
{type(data)}")
from io import BytesIO
+ from pyarrow import parquet as pq
+
with BytesIO(bytes.fromhex(data)) as buf:
df = pq.read_table(buf).to_pandas()
diff --git a/airflow-core/src/airflow/serialization/serializers/pydantic.py
b/airflow-core/src/airflow/serialization/serializers/pydantic.py
new file mode 100644
index 00000000000..b1a483a152a
--- /dev/null
+++ b/airflow-core/src/airflow/serialization/serializers/pydantic.py
@@ -0,0 +1,79 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, cast
+
+from airflow.serialization.typing import is_pydantic_model
+from airflow.utils.module_loading import qualname
+
+if TYPE_CHECKING:
+ from pydantic import BaseModel
+
+ from airflow.serialization.serde import U
+
+serializers = [
+ "pydantic.main.BaseModel",
+]
+deserializers = serializers
+
+__version__ = 1
+
+
+def serialize(o: object) -> tuple[U, str, int, bool]:
+ """
+ Serialize a Pydantic BaseModel instance into a dict of built-in types.
+
+ Returns a tuple of:
+ - serialized data (as built-in types)
+ - fixed class name for registration (BaseModel)
+ - version number
+ - is_serialized flag (True if handled)
+ """
+ if not is_pydantic_model(o):
+ return "", "", 0, False
+
+ model = cast("BaseModel", o) # for mypy
+ data = model.model_dump()
+
+ return data, qualname(o), __version__, True
+
+
+def deserialize(cls: type, version: int, data: dict):
+ """
+ Deserialize a Pydantic class.
+
+ Pydantic models can be serialized into a Python dictionary via
`pydantic.main.BaseModel.model_dump`
+ and the dictionary can be deserialized through
`pydantic.main.BaseModel.model_validate`. This function
+ can deserialize arbitrary Pydantic models that are in
`allowed_deserialization_classes`.
+
+ :param cls: The actual model class
+ :param version: Serialization version (must not exceed __version__)
+ :param data: Dictionary with built-in types, typically from model_dump()
+ :return: An instance of the actual Pydantic model
+ """
+ if version > __version__:
+ raise TypeError(f"Serialized version {version} is newer than the
supported version {__version__}")
+
+ if not is_pydantic_model(cls):
+ # no deserializer available
+ raise TypeError(f"No deserializer found for {qualname(cls)}")
+
+ # Perform validation-based reconstruction
+ model = cast("BaseModel", cls) # for mypy
+ return model.model_validate(data)
diff --git a/airflow-core/src/airflow/serialization/serializers/timezone.py
b/airflow-core/src/airflow/serialization/serializers/timezone.py
index 9f2ef7cef65..8bc067bc2ff 100644
--- a/airflow-core/src/airflow/serialization/serializers/timezone.py
+++ b/airflow-core/src/airflow/serialization/serializers/timezone.py
@@ -67,16 +67,16 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return "", "", 0, False
-def deserialize(classname: str, version: int, data: object) -> Any:
+def deserialize(cls: type, version: int, data: object) -> Any:
from airflow.utils.timezone import parse_timezone
if not isinstance(data, (str, int)):
raise TypeError(f"{data} is not of type int or str but of
{type(data)}")
if version > __version__:
- raise TypeError(f"serialized {version} of {classname} > {__version__}")
+ raise TypeError(f"serialized {version} of {qualname(cls)} >
{__version__}")
- if classname == "backports.zoneinfo.ZoneInfo" and isinstance(data, str):
+ if qualname(cls) == "backports.zoneinfo.ZoneInfo" and isinstance(data,
str):
from zoneinfo import ZoneInfo
return ZoneInfo(data)
diff --git a/airflow-core/src/airflow/serialization/typing.py
b/airflow-core/src/airflow/serialization/typing.py
new file mode 100644
index 00000000000..a6169b23a78
--- /dev/null
+++ b/airflow-core/src/airflow/serialization/typing.py
@@ -0,0 +1,32 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any
+
+
+def is_pydantic_model(cls: Any) -> bool:
+ """
+ Return True if the class is a pydantic.main.BaseModel.
+
+ Checking is done by attributes as it is significantly faster than
+ using isinstance.
+ """
+ # __pydantic_fields__ is always present on Pydantic V2 models and is a
dict[str, FieldInfo]
+ # __pydantic_validator__ is an internal validator object, always set after
model build
+ return hasattr(cls, "__pydantic_fields__") and hasattr(cls,
"__pydantic_validator__")
diff --git
a/airflow-core/tests/unit/serialization/serializers/test_serializers.py
b/airflow-core/tests/unit/serialization/serializers/test_serializers.py
index 53eb160ec1a..b3d1507f1e4 100644
--- a/airflow-core/tests/unit/serialization/serializers/test_serializers.py
+++ b/airflow-core/tests/unit/serialization/serializers/test_serializers.py
@@ -23,6 +23,7 @@ from unittest.mock import patch
from zoneinfo import ZoneInfo
import numpy as np
+import pandas as pd
import pendulum
import pendulum.tz
import pytest
@@ -31,11 +32,11 @@ from kubernetes.client import models as k8s
from packaging import version
from pendulum import DateTime
from pendulum.tz.timezone import FixedTimezone, Timezone
+from pydantic import BaseModel, Field
from airflow.sdk.definitions.param import Param, ParamsDict
from airflow.serialization.serde import CLASSNAME, DATA, VERSION, _stringify,
decode, deserialize, serialize
from airflow.serialization.serializers import builtin
-from airflow.utils.module_loading import qualname
from tests_common.test_utils.markers import
skip_if_force_lowest_dependencies_marker
@@ -60,6 +61,13 @@ class NoNameTZ(datetime.tzinfo):
return datetime.timedelta(hours=2)
+class FooBarModel(BaseModel):
+ """Pydantic BaseModel for testing Pydantic
Serialization/Deserialization."""
+
+ banana: float = 1.1
+ foo: str = Field()
+
+
@skip_if_force_lowest_dependencies_marker
class TestSerializers:
def test_datetime(self):
@@ -190,20 +198,26 @@ class TestSerializers:
assert serialize(12345) == ("", "", 0, False)
+ def test_bignum_deserialize_decimal(self):
+ from airflow.serialization.serializers.bignum import deserialize
+
+ res = deserialize(decimal.Decimal, 1, decimal.Decimal(12345))
+ assert res == decimal.Decimal(12345)
+
@pytest.mark.parametrize(
("klass", "version", "payload", "msg"),
[
(
- "decimal.Decimal",
+ decimal.Decimal,
999,
"0",
- r"serialized 999 of decimal\.Decimal", # newer version
+ r"serialized 999 of decimal\.Decimal > 1", # newer version
),
(
- "wrong.ClassName",
+ str,
1,
"0",
- r"wrong\.ClassName != .*Decimal", # wrong classname
+ r"do not know how to deserialize builtins\.str", # wrong
classname
),
],
)
@@ -235,8 +249,8 @@ class TestSerializers:
@pytest.mark.parametrize(
("klass", "ver", "value", "msg"),
[
- ("numpy.int32", 999, 123, r"serialized version is newer"),
- ("numpy.float32", 1, 123, r"unsupported numpy\.float32"),
+ (np.int32, 999, 123, r"serialized version is newer"),
+ (np.float32, 1, 123, r"unsupported numpy\.float32"),
],
)
def test_numpy_deserialize_errors(self, klass, ver, value, msg):
@@ -265,17 +279,23 @@ class TestSerializers:
assert serialize(123) == ("", "", 0, False)
@pytest.mark.parametrize(
- ("version", "data", "msg"),
+ ("klass", "version", "data", "msg"),
[
- (999, "", r"serialized 999 .* > 1"), # version too new
- (1, 123, r"wrong data type .*<class 'int'>"), # bad payload type
+ (pd.DataFrame, 999, "", r"serialized 999 of
pandas.core.frame.DataFrame > 1"), # version too new
+ (
+ pd.DataFrame,
+ 1,
+ 123,
+ r"serialized pandas.core.frame.DataFrame has wrong data type
.*<class 'int'>",
+ ), # bad payload type
+ (str, 1, "", r"do not know how to deserialize builtins.str"), #
bad class
],
)
- def test_pandas_deserialize_errors(self, version, data, msg):
+ def test_pandas_deserialize_errors(self, klass, version, data, msg):
from airflow.serialization.serializers.pandas import deserialize
with pytest.raises(TypeError, match=msg):
- deserialize("pandas.core.frame.DataFrame", version, data)
+ deserialize(klass, version, data)
def test_iceberg(self):
pytest.importorskip("pyiceberg", minversion="2.0.0")
@@ -367,6 +387,32 @@ class TestSerializers:
assert serialize(pod) == ("", "", 0, False)
assert serialize(123) == ("", "", 0, False)
+ def test_pydantic(self):
+ m = FooBarModel(banana=3.14, foo="hello")
+ e = serialize(m)
+ d = deserialize(e)
+
+ assert m.banana == d.banana
+ assert m.foo == d.foo
+
+ @pytest.mark.parametrize(
+ "klass, version, data, msg",
+ [
+ (
+ FooBarModel,
+ 999,
+ FooBarModel(banana=3.14, foo="hello"),
+ "Serialized version 999 is newer than the supported version 1",
+ ),
+ (str, 1, "", r"No deserializer found for builtins\.str"),
+ ],
+ )
+ def test_pydantic_deserialize_errors(self, klass, version, data, msg):
+ from airflow.serialization.serializers.pydantic import deserialize
+
+ with pytest.raises(TypeError, match=msg):
+ deserialize(klass, version, data)
+
@pytest.mark.skipif(not PENDULUM3, reason="Test case for pendulum~=3")
@pytest.mark.parametrize(
"ser_value, expected",
@@ -528,15 +574,15 @@ class TestSerializers:
def test_timezone_deserialize_zoneinfo(self):
from airflow.serialization.serializers.timezone import deserialize
- zi = deserialize("backports.zoneinfo.ZoneInfo", 1, "Asia/Taipei")
+ zi = deserialize(ZoneInfo, 1, "Asia/Taipei")
assert isinstance(zi, ZoneInfo)
assert zi.key == "Asia/Taipei"
@pytest.mark.parametrize(
"klass, version, data, msg",
[
- ("pendulum.tz.timezone.FixedTimezone", 1, 1.23, "is not of type
int or str"),
- ("pendulum.tz.timezone.FixedTimezone", 999, "UTC", "serialized 999
.* > 1"),
+ (FixedTimezone, 1, 1.23, "is not of type int or str"),
+ (FixedTimezone, 999, "UTC", "serialized 999 .* > 1"),
],
)
def test_timezone_deserialize_errors(self, klass, version, data, msg):
@@ -570,14 +616,39 @@ class TestSerializers:
load_dag_schema_dict()
assert "Schema file schema.json does not exists" in str(ctx.value)
- def test_builtin_deserialize_frozenset(self):
- res = builtin.deserialize(qualname(frozenset), 1, [13, 14])
- assert isinstance(res, frozenset)
- assert res == frozenset({13, 14})
+ @pytest.mark.parametrize(
+ "klass, version, data",
+ [(tuple, 1, [11, 12]), (set, 1, [11, 12]), (frozenset, 1, [11, 12])],
+ )
+ def test_builtin_deserialize(self, klass, version, data):
+ res = builtin.deserialize(klass, version, klass(data))
+ assert isinstance(res, klass)
+ assert res == klass(data)
- def test_builtin_deserialize_version_too_new(self):
- with pytest.raises(TypeError, match="serialized version is newer than
class version"):
- builtin.deserialize(qualname(tuple), 999, [1, 2])
+ @pytest.mark.parametrize(
+ "klass, version, data, msg",
+ [
+ (tuple, 999, [11, 12], r"serialized version 999 is newer than
class version 1"),
+ (set, 2, [11, 12], r"serialized version 2 is newer than class
version 1"),
+ (frozenset, 13, [11, 12], r"serialized version 13 is newer than
class version 1"),
+ ],
+ )
+ def test_builtin_deserialize_version_too_new(self, klass, version, data,
msg):
+ with pytest.raises(TypeError, match=msg):
+ builtin.deserialize(klass, version, data)
+
+ @pytest.mark.parametrize(
+ "klass, version, data, msg",
+ [
+ (str, 1, "11, 12", r"do not know how to deserialize
builtins\.str"),
+ (int, 1, 11, r"do not know how to deserialize builtins\.int"),
+ (bool, 1, True, r"do not know how to deserialize builtins\.bool"),
+ (float, 1, 0.999, r"do not know how to deserialize
builtins\.float"),
+ ],
+ )
+ def test_builtin_deserialize_wrong_types(self, klass, version, data, msg):
+ with pytest.raises(TypeError, match=msg):
+ builtin.deserialize(klass, version, data)
@pytest.mark.parametrize(
"func, msg",
diff --git a/airflow-core/tests/unit/serialization/test_serde.py
b/airflow-core/tests/unit/serialization/test_serde.py
index 6b716f581fc..560569689cf 100644
--- a/airflow-core/tests/unit/serialization/test_serde.py
+++ b/airflow-core/tests/unit/serialization/test_serde.py
@@ -259,6 +259,21 @@ class TestSerDe:
assert f"{qualname(Z)} was not found in allow list" in str(ex.value)
+ @conf_vars(
+ {
+ ("core", "allowed_deserialization_classes"): "airflow.*",
+ }
+ )
+ @pytest.mark.usefixtures("recalculate_patterns")
+ def test_allow_list_for_deserialize_pydantic_model(self):
+ # for Pydantic model to be deserialized, it must be in
`allowed_deserialization_classes`
+ i = U(x=7, v=V(w=W(x=42), s=["hello", "world"], t=(1, 2, 3), c=99),
u=("extra", 123))
+ e = serialize(i)
+ with pytest.raises(ImportError) as ex:
+ deserialize(e)
+
+ assert f"{qualname(U)} was not found in allow list" in str(ex.value)
+
@conf_vars(
{
("core", "allowed_deserialization_classes"): "unit.airflow.*",