[Python-checkins] gh-103365: [Enum] STRICT boundary corrections (GH-103494)

ethanfurman webhook-mailer at python.org
Thu Apr 13 11:24:40 EDT 2023


https://github.com/python/cpython/commit/2194071540313e2bbdc7214d77453b9ce3034a5c
commit: 2194071540313e2bbdc7214d77453b9ce3034a5c
branch: main
author: Ethan Furman <ethan at stoneleaf.us>
committer: ethanfurman <ethan at stoneleaf.us>
date: 2023-04-13T08:24:33-07:00
summary:

gh-103365: [Enum] STRICT boundary corrections (GH-103494)

STRICT boundary:

- fix bitwise operations
- make STRICT the default boundary for Flag

files:
A Misc/NEWS.d/next/Library/2023-04-12-17-59-55.gh-issue-103365.UBEE0U.rst
M Doc/library/enum.rst
M Lib/enum.py
M Lib/test/test_enum.py

diff --git a/Doc/library/enum.rst b/Doc/library/enum.rst
index c690c837309e..07acf9da33e2 100644
--- a/Doc/library/enum.rst
+++ b/Doc/library/enum.rst
@@ -696,7 +696,8 @@ Data Types
 
    .. attribute:: STRICT
 
-      Out-of-range values cause a :exc:`ValueError` to be raised::
+      Out-of-range values cause a :exc:`ValueError` to be raised. This is the
+      default for :class:`Flag`::
 
          >>> from enum import Flag, STRICT, auto
          >>> class StrictFlag(Flag, boundary=STRICT):
@@ -714,7 +715,7 @@ Data Types
    .. attribute:: CONFORM
 
       Out-of-range values have invalid values removed, leaving a valid *Flag*
-      value. This is the default for :class:`Flag`::
+      value::
 
          >>> from enum import Flag, CONFORM, auto
          >>> class ConformFlag(Flag, boundary=CONFORM):
diff --git a/Lib/enum.py b/Lib/enum.py
index 10902c4b202a..432d7456b4b9 100644
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -275,6 +275,13 @@ def __set_name__(self, enum_class, member_name):
         enum_member.__objclass__ = enum_class
         enum_member.__init__(*args)
         enum_member._sort_order_ = len(enum_class._member_names_)
+
+        if Flag is not None and issubclass(enum_class, Flag):
+            enum_class._flag_mask_ |= value
+            if _is_single_bit(value):
+                enum_class._singles_mask_ |= value
+            enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1
+
         # If another member with the same value was already defined, the
         # new member becomes an alias to the existing one.
         try:
@@ -532,12 +539,8 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
         classdict['_use_args_'] = use_args
         #
         # convert future enum members into temporary _proto_members
-        # and record integer values in case this will be a Flag
-        flag_mask = 0
         for name in member_names:
             value = classdict[name]
-            if isinstance(value, int):
-                flag_mask |= value
             classdict[name] = _proto_member(value)
         #
         # house-keeping structures
@@ -554,8 +557,9 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
                 boundary
                 or getattr(first_enum, '_boundary_', None)
                 )
-        classdict['_flag_mask_'] = flag_mask
-        classdict['_all_bits_'] = 2 ** ((flag_mask).bit_length()) - 1
+        classdict['_flag_mask_'] = 0
+        classdict['_singles_mask_'] = 0
+        classdict['_all_bits_'] = 0
         classdict['_inverted_'] = None
         try:
             exc = None
@@ -644,21 +648,10 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
             ):
             delattr(enum_class, '_boundary_')
             delattr(enum_class, '_flag_mask_')
+            delattr(enum_class, '_singles_mask_')
             delattr(enum_class, '_all_bits_')
             delattr(enum_class, '_inverted_')
         elif Flag is not None and issubclass(enum_class, Flag):
-            # ensure _all_bits_ is correct and there are no missing flags
-            single_bit_total = 0
-            multi_bit_total = 0
-            for flag in enum_class._member_map_.values():
-                flag_value = flag._value_
-                if _is_single_bit(flag_value):
-                    single_bit_total |= flag_value
-                else:
-                    # multi-bit flags are considered aliases
-                    multi_bit_total |= flag_value
-            enum_class._flag_mask_ = single_bit_total
-            #
             # set correct __iter__
             member_list = [m._value_ for m in enum_class]
             if member_list != sorted(member_list):
@@ -1303,8 +1296,8 @@ def _reduce_ex_by_global_name(self, proto):
 class FlagBoundary(StrEnum):
     """
     control how out of range values are handled
-    "strict" -> error is raised
-    "conform" -> extra bits are discarded   [default for Flag]
+    "strict" -> error is raised             [default for Flag]
+    "conform" -> extra bits are discarded
     "eject" -> lose flag status
     "keep" -> keep flag status and all bits [default for IntFlag]
     """
@@ -1315,7 +1308,7 @@ class FlagBoundary(StrEnum):
 STRICT, CONFORM, EJECT, KEEP = FlagBoundary
 
 
-class Flag(Enum, boundary=CONFORM):
+class Flag(Enum, boundary=STRICT):
     """
     Support for flags
     """
@@ -1394,6 +1387,7 @@ def _missing_(cls, value):
         # - value must not include any skipped flags (e.g. if bit 2 is not
         #   defined, then 0d10 is invalid)
         flag_mask = cls._flag_mask_
+        singles_mask = cls._singles_mask_
         all_bits = cls._all_bits_
         neg_value = None
         if (
@@ -1425,7 +1419,8 @@ def _missing_(cls, value):
             value = all_bits + 1 + value
         # get members and unknown
         unknown = value & ~flag_mask
-        member_value = value & flag_mask
+        aliases = value & ~singles_mask
+        member_value = value & singles_mask
         if unknown and cls._boundary_ is not KEEP:
             raise ValueError(
                     '%s(%r) -->  unknown values %r [%s]'
@@ -1439,11 +1434,25 @@ def _missing_(cls, value):
             pseudo_member = cls._member_type_.__new__(cls, value)
         if not hasattr(pseudo_member, '_value_'):
             pseudo_member._value_ = value
-        if member_value:
-            pseudo_member._name_ = '|'.join([
-                m._name_ for m in cls._iter_member_(member_value)
-                ])
-            if unknown:
+        if member_value or aliases:
+            members = []
+            combined_value = 0
+            for m in cls._iter_member_(member_value):
+                members.append(m)
+                combined_value |= m._value_
+            if aliases:
+                value = member_value | aliases
+                for n, pm in cls._member_map_.items():
+                    if pm not in members and pm._value_ and pm._value_ & value == pm._value_:
+                        members.append(pm)
+                        combined_value |= pm._value_
+            unknown = value ^ combined_value
+            pseudo_member._name_ = '|'.join([m._name_ for m in members])
+            if not combined_value:
+                pseudo_member._name_ = None
+            elif unknown and cls._boundary_ is STRICT:
+                raise ValueError('%r: no members with value %r' % (cls, unknown))
+            elif unknown:
                 pseudo_member._name_ += '|%s' % cls._numeric_repr_(unknown)
         else:
             pseudo_member._name_ = None
@@ -1675,6 +1684,7 @@ def convert_class(cls):
             body['_boundary_'] = boundary or etype._boundary_
             body['_flag_mask_'] = None
             body['_all_bits_'] = None
+            body['_singles_mask_'] = None
             body['_inverted_'] = None
             body['__or__'] = Flag.__or__
             body['__xor__'] = Flag.__xor__
@@ -1750,7 +1760,8 @@ def convert_class(cls):
                     else:
                         multi_bits |= value
                     gnv_last_values.append(value)
-            enum_class._flag_mask_ = single_bits
+            enum_class._flag_mask_ = single_bits | multi_bits
+            enum_class._singles_mask_ = single_bits
             enum_class._all_bits_ = 2 ** ((single_bits|multi_bits).bit_length()) - 1
             # set correct __iter__
             member_list = [m._value_ for m in enum_class]
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
index e4151bf9e6d9..89294e95df2a 100644
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -2873,6 +2873,8 @@ def __new__(cls, c):
             #
             a = ord('a')
         #
+        self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343)
+        self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672)
         self.assertEqual(FlagFromChar.a, 158456325028528675187087900672)
         self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673)
         #
@@ -2887,6 +2889,8 @@ def __new__(cls, c):
             a = ord('a')
             z = 1
         #
+        self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343)
+        self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900674)
         self.assertEqual(FlagFromChar.a.value, 158456325028528675187087900672)
         self.assertEqual((FlagFromChar.a|FlagFromChar.z).value, 158456325028528675187087900674)
         #
@@ -2900,6 +2904,8 @@ def __new__(cls, c):
             #
             a = ord('a')
         #
+        self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343)
+        self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672)
         self.assertEqual(FlagFromChar.a, 158456325028528675187087900672)
         self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673)
 
@@ -3077,18 +3083,18 @@ def test_bool(self):
             self.assertEqual(bool(f.value), bool(f))
 
     def test_boundary(self):
-        self.assertIs(enum.Flag._boundary_, CONFORM)
-        class Iron(Flag, boundary=STRICT):
+        self.assertIs(enum.Flag._boundary_, STRICT)
+        class Iron(Flag, boundary=CONFORM):
             ONE = 1
             TWO = 2
             EIGHT = 8
-        self.assertIs(Iron._boundary_, STRICT)
+        self.assertIs(Iron._boundary_, CONFORM)
         #
-        class Water(Flag, boundary=CONFORM):
+        class Water(Flag, boundary=STRICT):
             ONE = 1
             TWO = 2
             EIGHT = 8
-        self.assertIs(Water._boundary_, CONFORM)
+        self.assertIs(Water._boundary_, STRICT)
         #
         class Space(Flag, boundary=EJECT):
             ONE = 1
@@ -3101,10 +3107,10 @@ class Bizarre(Flag, boundary=KEEP):
             c = 4
             d = 6
         #
-        self.assertRaisesRegex(ValueError, 'invalid value 7', Iron, 7)
+        self.assertRaisesRegex(ValueError, 'invalid value 7', Water, 7)
         #
-        self.assertIs(Water(7), Water.ONE|Water.TWO)
-        self.assertIs(Water(~9), Water.TWO)
+        self.assertIs(Iron(7), Iron.ONE|Iron.TWO)
+        self.assertIs(Iron(~9), Iron.TWO)
         #
         self.assertEqual(Space(7), 7)
         self.assertTrue(type(Space(7)) is int)
@@ -3112,6 +3118,31 @@ class Bizarre(Flag, boundary=KEEP):
         self.assertEqual(list(Bizarre), [Bizarre.c])
         self.assertIs(Bizarre(3), Bizarre.b)
         self.assertIs(Bizarre(6), Bizarre.d)
+        #
+        class SkipFlag(enum.Flag):
+            A = 1
+            B = 2
+            C = 4 | B
+        #
+        self.assertTrue(SkipFlag.C in (SkipFlag.A|SkipFlag.C))
+        self.assertRaisesRegex(ValueError, 'SkipFlag.. invalid value 42', SkipFlag, 42)
+        #
+        class SkipIntFlag(enum.IntFlag):
+            A = 1
+            B = 2
+            C = 4 | B
+        #
+        self.assertTrue(SkipIntFlag.C in (SkipIntFlag.A|SkipIntFlag.C))
+        self.assertEqual(SkipIntFlag(42).value, 42)
+        #
+        class MethodHint(Flag):
+            HiddenText = 0x10
+            DigitsOnly = 0x01
+            LettersOnly = 0x02
+            OnlyMask = 0x0f
+        #
+        self.assertEqual(str(MethodHint.HiddenText|MethodHint.OnlyMask), 'MethodHint.HiddenText|DigitsOnly|LettersOnly|OnlyMask')
+
 
     def test_iter(self):
         Color = self.Color
diff --git a/Misc/NEWS.d/next/Library/2023-04-12-17-59-55.gh-issue-103365.UBEE0U.rst b/Misc/NEWS.d/next/Library/2023-04-12-17-59-55.gh-issue-103365.UBEE0U.rst
new file mode 100644
index 000000000000..4d69f6f6fff7
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2023-04-12-17-59-55.gh-issue-103365.UBEE0U.rst
@@ -0,0 +1 @@
+Set default Flag boundary to ``STRICT`` and fix bitwise operations.



More information about the Python-checkins mailing list