https://github.com/python/cpython/commit/e4fd5d542a8d7b735cce6e551a1b33bbd42e1d78
commit: e4fd5d542a8d7b735cce6e551a1b33bbd42e1d78
branch: 3.12
author: Miss Islington (bot) <[email protected]>
committer: ethanfurman <[email protected]>
date: 2024-02-19T16:18:40-08:00
summary:

[3.12] gh-115539: Allow enum.Flag to have None members (GH-115636) (GH-115694)

gh-115539: Allow enum.Flag to have None members (GH-115636)
(cherry picked from commit c2cb31bbe1262213085c425bc853d6587c66cae9)

Co-authored-by: Jason Zhang <[email protected]>

files:
M Lib/enum.py
M Lib/test/test_enum.py

diff --git a/Lib/enum.py b/Lib/enum.py
index b9c7f9cead9706..6d3bbdc31625f7 100644
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -283,9 +283,10 @@ def __set_name__(self, enum_class, member_name):
         enum_member._sort_order_ = len(enum_class._member_names_)
 
         if Flag is not None and issubclass(enum_class, Flag):
-            enum_class._flag_mask_ |= value
-            if _is_single_bit(value):
-                enum_class._singles_mask_ |= value
+            if isinstance(value, int):
+                enum_class._flag_mask_ |= value
+                if _is_single_bit(value):
+                    enum_class._singles_mask_ |= value
             enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1
 
         # If another member with the same value was already defined, the
@@ -313,6 +314,7 @@ def __set_name__(self, enum_class, member_name):
             elif (
                     Flag is not None
                     and issubclass(enum_class, Flag)
+                    and isinstance(value, int)
                     and _is_single_bit(value)
                 ):
                 # no other instances found, record this member in _member_names_
@@ -1534,37 +1536,50 @@ def __str__(self):
     def __bool__(self):
         return bool(self._value_)
 
+    def _get_value(self, flag):
+        if isinstance(flag, self.__class__):
+            return flag._value_
+        elif self._member_type_ is not object and isinstance(flag, self._member_type_):
+            return flag
+        return NotImplemented
+
     def __or__(self, other):
-        if isinstance(other, self.__class__):
-            other = other._value_
-        elif self._member_type_ is not object and isinstance(other, self._member_type_):
-            other = other
-        else:
+        other_value = self._get_value(other)
+        if other_value is NotImplemented:
             return NotImplemented
+
+        for flag in self, other:
+            if self._get_value(flag) is None:
+                raise TypeError(f"'{flag}' cannot be combined with other flags with |")
         value = self._value_
-        return self.__class__(value | other)
+        return self.__class__(value | other_value)
 
     def __and__(self, other):
-        if isinstance(other, self.__class__):
-            other = other._value_
-        elif self._member_type_ is not object and isinstance(other, self._member_type_):
-            other = other
-        else:
+        other_value = self._get_value(other)
+        if other_value is NotImplemented:
             return NotImplemented
+
+        for flag in self, other:
+            if self._get_value(flag) is None:
+                raise TypeError(f"'{flag}' cannot be combined with other flags with &")
         value = self._value_
-        return self.__class__(value & other)
+        return self.__class__(value & other_value)
 
     def __xor__(self, other):
-        if isinstance(other, self.__class__):
-            other = other._value_
-        elif self._member_type_ is not object and isinstance(other, self._member_type_):
-            other = other
-        else:
+        other_value = self._get_value(other)
+        if other_value is NotImplemented:
             return NotImplemented
+
+        for flag in self, other:
+            if self._get_value(flag) is None:
+                raise TypeError(f"'{flag}' cannot be combined with other flags with ^")
         value = self._value_
-        return self.__class__(value ^ other)
+        return self.__class__(value ^ other_value)
 
     def __invert__(self):
+        if self._get_value(self) is None:
+            raise TypeError(f"'{self}' cannot be inverted")
+
         if self._inverted_ is None:
             if self._boundary_ in (EJECT, KEEP):
                 self._inverted_ = self.__class__(~self._value_)
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
index ec082c46d10405..c58dc36fe84134 100644
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -1007,6 +1007,22 @@ class TestPlainEnumFunction(_EnumTests, _PlainOutputTests, unittest.TestCase):
 class TestPlainFlagClass(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase):
     enum_type = Flag
 
+    def test_none_member(self):
+        class FlagWithNoneMember(Flag):
+            A = 1
+            E = None
+
+        self.assertEqual(FlagWithNoneMember.A.value, 1)
+        self.assertIs(FlagWithNoneMember.E.value, None)
+        with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with |"):
+            FlagWithNoneMember.A | FlagWithNoneMember.E
+        with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with &"):
+            FlagWithNoneMember.E & FlagWithNoneMember.A
+        with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with \^"):
+            FlagWithNoneMember.A ^ FlagWithNoneMember.E
+        with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be inverted"):
+            ~FlagWithNoneMember.E
+
 
 class TestPlainFlagFunction(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase):
     enum_type = Flag

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]

Reply via email to