https://github.com/python/cpython/commit/aaed91cabcedc16c089c4b1c9abb1114659a83d3
commit: aaed91cabcedc16c089c4b1c9abb1114659a83d3
branch: main
author: Ethan Furman <[email protected]>
committer: ethanfurman <[email protected]>
date: 2024-10-22T11:04:00-07:00
summary:

gh-125710: [Enum] fix hashable<->nonhashable comparisons for member values 
(GH-125735)

files:
A Misc/NEWS.d/next/Library/2024-10-19-13-37-37.gh-issue-125710.FyFAAr.rst
M Lib/enum.py
M Lib/test/test_enum.py

diff --git a/Lib/enum.py b/Lib/enum.py
index 17d72738792982..4f9912229603a6 100644
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -327,6 +327,8 @@ def __set_name__(self, enum_class, member_name):
             # to the map, and by-value lookups for this value will be
             # linear.
             enum_class._value2member_map_.setdefault(value, enum_member)
+            if value not in enum_class._hashable_values_:
+                enum_class._hashable_values_.append(value)
         except TypeError:
             # keep track of the value in a list so containment checks are quick
             enum_class._unhashable_values_.append(value)
@@ -538,7 +540,8 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
         classdict['_member_names_'] = []
         classdict['_member_map_'] = {}
         classdict['_value2member_map_'] = {}
-        classdict['_unhashable_values_'] = []
+        classdict['_hashable_values_'] = []          # for comparing with non-hashable types
+        classdict['_unhashable_values_'] = []       # e.g. frozenset() with set()
         classdict['_unhashable_values_map_'] = {}
         classdict['_member_type_'] = member_type
         # now set the __repr__ for the value
@@ -748,7 +751,10 @@ def __contains__(cls, value):
         try:
             return value in cls._value2member_map_
         except TypeError:
-            return value in cls._unhashable_values_
+            return (
+                    value in cls._unhashable_values_    # both structures are lists
+                    or value in cls._hashable_values_
+                    )
 
     def __delattr__(cls, attr):
         # nicer error message when someone tries to delete an attribute
@@ -1166,8 +1172,11 @@ def __new__(cls, value):
             pass
         except TypeError:
             # not there, now do long search -- O(n) behavior
-            for name, values in cls._unhashable_values_map_.items():
-                if value in values:
+            for name, unhashable_values in cls._unhashable_values_map_.items():
+                if value in unhashable_values:
+                    return cls[name]
+            for name, member in cls._member_map_.items():
+                if value == member._value_:
                     return cls[name]
         # still not found -- verify that members exist, in-case somebody got here mistakenly
         # (such as via super when trying to override __new__)
@@ -1233,6 +1242,7 @@ def _add_value_alias_(self, value):
             # to the map, and by-value lookups for this value will be
             # linear.
             cls._value2member_map_.setdefault(value, self)
+            cls._hashable_values_.append(value)
         except TypeError:
             # keep track of the value in a list so containment checks are quick
             cls._unhashable_values_.append(value)
@@ -1763,6 +1773,7 @@ def convert_class(cls):
         body['_member_names_'] = member_names = []
         body['_member_map_'] = member_map = {}
         body['_value2member_map_'] = value2member_map = {}
+        body['_hashable_values_'] = hashable_values = []
         body['_unhashable_values_'] = unhashable_values = []
         body['_unhashable_values_map_'] = {}
         body['_member_type_'] = member_type = etype._member_type_
@@ -1826,7 +1837,7 @@ def convert_class(cls):
                     contained = value2member_map.get(member._value_)
                 except TypeError:
                     contained = None
-                    if member._value_ in unhashable_values:
+                    if member._value_ in unhashable_values or member.value in hashable_values:
                         for m in enum_class:
                             if m._value_ == member._value_:
                                 contained = m
@@ -1846,6 +1857,7 @@ def convert_class(cls):
                     else:
                         enum_class._add_member_(name, member)
                     value2member_map[value] = member
+                    hashable_values.append(value)
                     if _is_single_bit(value):
                         # not a multi-bit alias, record in _member_names_ and _flag_mask_
                         member_names.append(name)
@@ -1882,7 +1894,7 @@ def convert_class(cls):
                     contained = value2member_map.get(member._value_)
                 except TypeError:
                     contained = None
-                    if member._value_ in unhashable_values:
+                    if member._value_ in unhashable_values or member._value_ in hashable_values:
                         for m in enum_class:
                             if m._value_ == member._value_:
                                 contained = m
@@ -1908,6 +1920,8 @@ def convert_class(cls):
                         # to the map, and by-value lookups for this value will be
                         # linear.
                         enum_class._value2member_map_.setdefault(value, member)
+                        if value not in hashable_values:
+                            hashable_values.append(value)
                     except TypeError:
                         # keep track of the value in a list so containment checks are quick
                         enum_class._unhashable_values_.append(value)
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
index 5b4a8070526fcf..7184769bfd6fc3 100644
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -3460,6 +3460,13 @@ def test_empty_names(self):
         self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', names=0)
         self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', 0, type=int)
 
+    def test_nonhashable_matches_hashable(self):    # issue 125710
+        class Directions(Enum):
+            DOWN_ONLY = frozenset({"sc"})
+            UP_ONLY = frozenset({"cs"})
+            UNRESTRICTED = frozenset({"sc", "cs"})
+        self.assertIs(Directions({"sc"}), Directions.DOWN_ONLY)
+
 
 class TestOrder(unittest.TestCase):
     "test usage of the `_order_` attribute"
diff --git a/Misc/NEWS.d/next/Library/2024-10-19-13-37-37.gh-issue-125710.FyFAAr.rst b/Misc/NEWS.d/next/Library/2024-10-19-13-37-37.gh-issue-125710.FyFAAr.rst
b/Misc/NEWS.d/next/Library/2024-10-19-13-37-37.gh-issue-125710.FyFAAr.rst
new file mode 100644
index 00000000000000..8d5220e9889c3a
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2024-10-19-13-37-37.gh-issue-125710.FyFAAr.rst
@@ -0,0 +1 @@
+[Enum] fix hashable<->nonhashable comparisons for member values

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]

Reply via email to