This is an automated email from the ASF dual-hosted git repository.

sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git


The following commit(s) were added to refs/heads/main by this push:
     new 04f309a  Allow a single enum to be used as a form field type
04f309a is described below

commit 04f309ac95c675b820e57e900c368bc8640fb2e3
Author: Sean B. Palmer <[email protected]>
AuthorDate: Wed Nov 12 20:23:07 2025 +0000

    Allow a single enum to be used as a form field type
---
 atr/form.py | 74 +++++++++++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 60 insertions(+), 14 deletions(-)

diff --git a/atr/form.py b/atr/form.py
index 280e5d6..a836e9c 100644
--- a/atr/form.py
+++ b/atr/form.py
@@ -333,6 +333,17 @@ def to_bool(v: Any) -> bool:
     raise ValueError(f"Cannot convert {v!r} to boolean")
 
 
+def to_enum[EnumType: enum.Enum](v: Any, enum_class: type[EnumType]) -> EnumType:
+    members: dict[str, EnumType] = {member.value: member for member in enum_class}
+    if isinstance(v, enum_class):
+        return v
+    if isinstance(v, str):
+        if v in members:
+            return members[v]
+        raise ValueError(f"Invalid enum value: {v!r}")
+    raise ValueError(f"Expected an enum value, got {type(v).__name__}")
+
+
 def to_enum_set[EnumType: enum.Enum](v: Any, enum_class: type[EnumType]) -> set[EnumType]:
     members: dict[str, EnumType] = {member.value: member for member in enum_class}
     if isinstance(v, set):
@@ -438,6 +449,21 @@ Bool = Annotated[
 Email = pydantic.EmailStr
 
 
+class Enum[EnumType: enum.Enum]:
+    @staticmethod
+    def __class_getitem__(enum_class: type[EnumType]):
+        def validator(v: Any) -> EnumType:
+            return to_enum(v, enum_class)
+
+        # Get the first enum member as the default
+        first_member = next(iter(enum_class))
+        return Annotated[
+            enum_class,
+            functional_validators.BeforeValidator(validator),
+            pydantic.Field(default=first_member),
+        ]
+
+
 File = Annotated[
     datastructures.FileStorage | None,
     functional_validators.BeforeValidator(to_filestorage),
@@ -456,6 +482,11 @@ Filename = Annotated[
     pydantic.Field(default=None),
 ]
 
+Int = Annotated[
+    int,
+    functional_validators.BeforeValidator(to_int),
+]
+
 OptionalURL = Annotated[
     pydantic.HttpUrl | None,
     functional_validators.BeforeValidator(to_optional_url),
@@ -468,14 +499,6 @@ StrList = Annotated[
     pydantic.Field(default_factory=list),
 ]
 
-Int = Annotated[
-    int,
-    functional_validators.BeforeValidator(to_int),
-]
-
-
-URL = pydantic.HttpUrl
-
 
 class Set[EnumType: enum.Enum]:
     def __iter__(self) -> Iterator[EnumType]:
@@ -494,6 +517,9 @@ class Set[EnumType: enum.Enum]:
         ]
 
 
+URL = pydantic.HttpUrl
+
+
 def validate(model_cls: Any, form: dict[str, Any], context: dict[str, Any] | None = None) -> pydantic.BaseModel:
     # Since pydantic.TypeAdapter accepts Any, we do the same
     return pydantic.TypeAdapter(model_cls).validate_python(form, context=context)
@@ -657,7 +683,11 @@ def _render_widget(  # noqa: C901
                 selected_value = field_value[0] if field_value else None
             else:
                 choices = _get_choices(field_info)
-                selected_value = field_value
+                # If field_value is an enum, extract its value for comparison
+                if isinstance(field_value, enum.Enum):
+                    selected_value = field_value.value
+                else:
+                    selected_value = field_value
 
             options = [
                 htpy.option(
@@ -695,24 +725,38 @@ def _render_widget(  # noqa: C901
     return htm.div[elements] if len(elements) > 1 else elements[0]
 
 
-def _get_choices(field_info: pydantic.fields.FieldInfo) -> list[tuple[str, str]]:
+def _get_choices(field_info: pydantic.fields.FieldInfo) -> list[tuple[str, str]]:  # noqa: C901
     annotation = field_info.annotation
     origin = get_origin(annotation)
 
     if origin is Literal:
         return [(v, v) for v in get_args(annotation)]
 
+    if origin is Annotated:
+        # Check whether this is an Enum[T] or Set[T] annotation
+        args = get_args(annotation)
+        if args:
+            inner_type = args[0]
+            if isinstance(inner_type, type) and issubclass(inner_type, enum.Enum):
+                # This is an enum type wrapped in Annotated, from Enum[T] or Set[T]
+                return [(member.value, member.value) for member in inner_type]
+
     if origin is set:
         args = get_args(annotation)
-        if args and hasattr(args[0], "__members__"):
+        if args:
             enum_class = args[0]
-            return [(member.value, member.value) for member in enum_class]
+            if isinstance(enum_class, type) and issubclass(enum_class, enum.Enum):
+                return [(member.value, member.value) for member in enum_class]
 
     if origin is list:
         args = get_args(annotation)
         if args and get_origin(args[0]) is Literal:
             return [(v, v) for v in get_args(args[0])]
 
+    # Check for plain enum types, e.g. when Pydantic unwraps form.Enum[T]
+    if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
+        return [(member.value, member.value) for member in annotation]
+
     return []
 
 
@@ -779,8 +823,10 @@ def _get_widget_type(field_info: pydantic.fields.FieldInfo) -> Widget:  # noqa:
 
     if origin is set:
         args = get_args(annotation)
-        if args and hasattr(args[0], "__members__"):
-            return Widget.CHECKBOXES
+        if args:
+            first_arg = args[0]
+            if isinstance(first_arg, type) and issubclass(first_arg, enum.Enum):
+                return Widget.CHECKBOXES
 
     if origin is list:
         args = get_args(annotation)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to