This is an automated email from the ASF dual-hosted git repository.
sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
The following commit(s) were added to refs/heads/main by this push:
new 04f309a Allow a single enum to be used as a form field type
04f309a is described below
commit 04f309ac95c675b820e57e900c368bc8640fb2e3
Author: Sean B. Palmer <[email protected]>
AuthorDate: Wed Nov 12 20:23:07 2025 +0000
Allow a single enum to be used as a form field type
---
atr/form.py | 74 +++++++++++++++++++++++++++++++++++++++++++++++++------------
1 file changed, 60 insertions(+), 14 deletions(-)
diff --git a/atr/form.py b/atr/form.py
index 280e5d6..a836e9c 100644
--- a/atr/form.py
+++ b/atr/form.py
@@ -333,6 +333,17 @@ def to_bool(v: Any) -> bool:
raise ValueError(f"Cannot convert {v!r} to boolean")
+def to_enum[EnumType: enum.Enum](v: Any, enum_class: type[EnumType]) ->
EnumType:
+ members: dict[str, EnumType] = {member.value: member for member in
enum_class}
+ if isinstance(v, enum_class):
+ return v
+ if isinstance(v, str):
+ if v in members:
+ return members[v]
+ raise ValueError(f"Invalid enum value: {v!r}")
+ raise ValueError(f"Expected an enum value, got {type(v).__name__}")
+
+
def to_enum_set[EnumType: enum.Enum](v: Any, enum_class: type[EnumType]) ->
set[EnumType]:
members: dict[str, EnumType] = {member.value: member for member in
enum_class}
if isinstance(v, set):
@@ -438,6 +449,21 @@ Bool = Annotated[
Email = pydantic.EmailStr
+class Enum[EnumType: enum.Enum]:
+ @staticmethod
+ def __class_getitem__(enum_class: type[EnumType]):
+ def validator(v: Any) -> EnumType:
+ return to_enum(v, enum_class)
+
+ # Get the first enum member as the default
+ first_member = next(iter(enum_class))
+ return Annotated[
+ enum_class,
+ functional_validators.BeforeValidator(validator),
+ pydantic.Field(default=first_member),
+ ]
+
+
File = Annotated[
datastructures.FileStorage | None,
functional_validators.BeforeValidator(to_filestorage),
@@ -456,6 +482,11 @@ Filename = Annotated[
pydantic.Field(default=None),
]
# Integer form field: raw input is passed through to_int by a pydantic
# before-validator, ahead of pydantic's own int validation.
Int = Annotated[
    int,
    functional_validators.BeforeValidator(to_int),
]
+
OptionalURL = Annotated[
pydantic.HttpUrl | None,
functional_validators.BeforeValidator(to_optional_url),
@@ -468,14 +499,6 @@ StrList = Annotated[
pydantic.Field(default_factory=list),
]
-Int = Annotated[
- int,
- functional_validators.BeforeValidator(to_int),
-]
-
-
-URL = pydantic.HttpUrl
-
class Set[EnumType: enum.Enum]:
def __iter__(self) -> Iterator[EnumType]:
@@ -494,6 +517,9 @@ class Set[EnumType: enum.Enum]:
]
# Required URL form field: a plain alias of pydantic's HttpUrl with no custom
# pre-validation (contrast OptionalURL, which adds a before-validator).
URL = pydantic.HttpUrl
+
+
def validate(model_cls: Any, form: dict[str, Any], context: dict[str, Any] | None = None) -> pydantic.BaseModel:
    """Validate raw *form* data against *model_cls* using a pydantic TypeAdapter.

    *context* is forwarded unchanged to ``TypeAdapter.validate_python``.
    Any validation error raised by pydantic propagates to the caller.
    """
    # Since pydantic.TypeAdapter accepts Any, we do the same
    return pydantic.TypeAdapter(model_cls).validate_python(form, context=context)
@@ -657,7 +683,11 @@ def _render_widget( # noqa: C901
selected_value = field_value[0] if field_value else None
else:
choices = _get_choices(field_info)
- selected_value = field_value
+ # If field_value is an enum, extract its value for comparison
+ if isinstance(field_value, enum.Enum):
+ selected_value = field_value.value
+ else:
+ selected_value = field_value
options = [
htpy.option(
@@ -695,24 +725,38 @@ def _render_widget( # noqa: C901
return htm.div[elements] if len(elements) > 1 else elements[0]
def _get_choices(field_info: pydantic.fields.FieldInfo) -> list[tuple[str, str]]:
    """Derive (value, label) option pairs for a choice widget from a field's annotation.

    Any ``Annotated[...]`` wrapper (such as the one produced by ``Enum[T]``) is
    stripped first. The remaining annotation is then matched as follows:

    - ``Literal[...]`` and ``list[Literal[...]]``: one option per literal value
    - an ``enum.Enum`` subclass, or ``set[EnumSubclass]``: one option per member
      value

    Returns an empty list when no choices can be derived from the annotation.
    """
    annotation = field_info.annotation
    # Unwrap Annotated[T, ...] so that wrapped forms such as Annotated[enum_class, ...]
    # or Annotated[set[enum_class], ...] are treated exactly like their inner type.
    # typing flattens nested Annotated, so a single level of unwrapping suffices.
    if get_origin(annotation) is Annotated:
        annotation = get_args(annotation)[0]
    origin = get_origin(annotation)
    args = get_args(annotation)

    if origin is Literal:
        return [(v, v) for v in args]

    if (origin is list) and args and (get_origin(args[0]) is Literal):
        return [(v, v) for v in get_args(args[0])]

    # The enum class is either the element type of set[...] or the annotation itself.
    # The latter also covers Pydantic having already unwrapped form.Enum[T].
    enum_candidate = args[0] if (origin is set and args) else annotation
    if isinstance(enum_candidate, type) and issubclass(enum_candidate, enum.Enum):
        return [(member.value, member.value) for member in enum_candidate]

    return []
@@ -779,8 +823,10 @@ def _get_widget_type(field_info:
pydantic.fields.FieldInfo) -> Widget: # noqa:
if origin is set:
args = get_args(annotation)
- if args and hasattr(args[0], "__members__"):
- return Widget.CHECKBOXES
+ if args:
+ first_arg = args[0]
+ if isinstance(first_arg, type) and issubclass(first_arg,
enum.Enum):
+ return Widget.CHECKBOXES
if origin is list:
args = get_args(annotation)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]