mirror of
				https://github.com/python/cpython.git
				synced 2025-10-31 13:41:24 +00:00 
			
		
		
		
	gh-103365: [Enum] STRICT boundary corrections (GH-103494)
STRICT boundary: - fix bitwise operations - make default for Flag
This commit is contained in:
		
							parent
							
								
									efb8a2553c
								
							
						
					
					
						commit
						2194071540
					
				
					 4 changed files with 82 additions and 38 deletions
				
			
		|  | @ -696,7 +696,8 @@ Data Types | ||||||
| 
 | 
 | ||||||
|    .. attribute:: STRICT |    .. attribute:: STRICT | ||||||
| 
 | 
 | ||||||
|       Out-of-range values cause a :exc:`ValueError` to be raised:: |       Out-of-range values cause a :exc:`ValueError` to be raised. This is the | ||||||
|  |       default for :class:`Flag`:: | ||||||
| 
 | 
 | ||||||
|          >>> from enum import Flag, STRICT, auto |          >>> from enum import Flag, STRICT, auto | ||||||
|          >>> class StrictFlag(Flag, boundary=STRICT): |          >>> class StrictFlag(Flag, boundary=STRICT): | ||||||
|  | @ -714,7 +715,7 @@ Data Types | ||||||
|    .. attribute:: CONFORM |    .. attribute:: CONFORM | ||||||
| 
 | 
 | ||||||
|       Out-of-range values have invalid values removed, leaving a valid *Flag* |       Out-of-range values have invalid values removed, leaving a valid *Flag* | ||||||
|       value. This is the default for :class:`Flag`:: |       value:: | ||||||
| 
 | 
 | ||||||
|          >>> from enum import Flag, CONFORM, auto |          >>> from enum import Flag, CONFORM, auto | ||||||
|          >>> class ConformFlag(Flag, boundary=CONFORM): |          >>> class ConformFlag(Flag, boundary=CONFORM): | ||||||
|  |  | ||||||
							
								
								
									
										67
									
								
								Lib/enum.py
									
										
									
									
									
								
							
							
						
						
									
										67
									
								
								Lib/enum.py
									
										
									
									
									
								
							|  | @ -275,6 +275,13 @@ def __set_name__(self, enum_class, member_name): | ||||||
|         enum_member.__objclass__ = enum_class |         enum_member.__objclass__ = enum_class | ||||||
|         enum_member.__init__(*args) |         enum_member.__init__(*args) | ||||||
|         enum_member._sort_order_ = len(enum_class._member_names_) |         enum_member._sort_order_ = len(enum_class._member_names_) | ||||||
|  | 
 | ||||||
|  |         if Flag is not None and issubclass(enum_class, Flag): | ||||||
|  |             enum_class._flag_mask_ |= value | ||||||
|  |             if _is_single_bit(value): | ||||||
|  |                 enum_class._singles_mask_ |= value | ||||||
|  |             enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1 | ||||||
|  | 
 | ||||||
|         # If another member with the same value was already defined, the |         # If another member with the same value was already defined, the | ||||||
|         # new member becomes an alias to the existing one. |         # new member becomes an alias to the existing one. | ||||||
|         try: |         try: | ||||||
|  | @ -532,12 +539,8 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k | ||||||
|         classdict['_use_args_'] = use_args |         classdict['_use_args_'] = use_args | ||||||
|         # |         # | ||||||
|         # convert future enum members into temporary _proto_members |         # convert future enum members into temporary _proto_members | ||||||
|         # and record integer values in case this will be a Flag |  | ||||||
|         flag_mask = 0 |  | ||||||
|         for name in member_names: |         for name in member_names: | ||||||
|             value = classdict[name] |             value = classdict[name] | ||||||
|             if isinstance(value, int): |  | ||||||
|                 flag_mask |= value |  | ||||||
|             classdict[name] = _proto_member(value) |             classdict[name] = _proto_member(value) | ||||||
|         # |         # | ||||||
|         # house-keeping structures |         # house-keeping structures | ||||||
|  | @ -554,8 +557,9 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k | ||||||
|                 boundary |                 boundary | ||||||
|                 or getattr(first_enum, '_boundary_', None) |                 or getattr(first_enum, '_boundary_', None) | ||||||
|                 ) |                 ) | ||||||
|         classdict['_flag_mask_'] = flag_mask |         classdict['_flag_mask_'] = 0 | ||||||
|         classdict['_all_bits_'] = 2 ** ((flag_mask).bit_length()) - 1 |         classdict['_singles_mask_'] = 0 | ||||||
|  |         classdict['_all_bits_'] = 0 | ||||||
|         classdict['_inverted_'] = None |         classdict['_inverted_'] = None | ||||||
|         try: |         try: | ||||||
|             exc = None |             exc = None | ||||||
|  | @ -644,21 +648,10 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k | ||||||
|             ): |             ): | ||||||
|             delattr(enum_class, '_boundary_') |             delattr(enum_class, '_boundary_') | ||||||
|             delattr(enum_class, '_flag_mask_') |             delattr(enum_class, '_flag_mask_') | ||||||
|  |             delattr(enum_class, '_singles_mask_') | ||||||
|             delattr(enum_class, '_all_bits_') |             delattr(enum_class, '_all_bits_') | ||||||
|             delattr(enum_class, '_inverted_') |             delattr(enum_class, '_inverted_') | ||||||
|         elif Flag is not None and issubclass(enum_class, Flag): |         elif Flag is not None and issubclass(enum_class, Flag): | ||||||
|             # ensure _all_bits_ is correct and there are no missing flags |  | ||||||
|             single_bit_total = 0 |  | ||||||
|             multi_bit_total = 0 |  | ||||||
|             for flag in enum_class._member_map_.values(): |  | ||||||
|                 flag_value = flag._value_ |  | ||||||
|                 if _is_single_bit(flag_value): |  | ||||||
|                     single_bit_total |= flag_value |  | ||||||
|                 else: |  | ||||||
|                     # multi-bit flags are considered aliases |  | ||||||
|                     multi_bit_total |= flag_value |  | ||||||
|             enum_class._flag_mask_ = single_bit_total |  | ||||||
|             # |  | ||||||
|             # set correct __iter__ |             # set correct __iter__ | ||||||
|             member_list = [m._value_ for m in enum_class] |             member_list = [m._value_ for m in enum_class] | ||||||
|             if member_list != sorted(member_list): |             if member_list != sorted(member_list): | ||||||
|  | @ -1303,8 +1296,8 @@ def _reduce_ex_by_global_name(self, proto): | ||||||
| class FlagBoundary(StrEnum): | class FlagBoundary(StrEnum): | ||||||
|     """ |     """ | ||||||
|     control how out of range values are handled |     control how out of range values are handled | ||||||
|     "strict" -> error is raised |     "strict" -> error is raised             [default for Flag] | ||||||
|     "conform" -> extra bits are discarded   [default for Flag] |     "conform" -> extra bits are discarded | ||||||
|     "eject" -> lose flag status |     "eject" -> lose flag status | ||||||
|     "keep" -> keep flag status and all bits [default for IntFlag] |     "keep" -> keep flag status and all bits [default for IntFlag] | ||||||
|     """ |     """ | ||||||
|  | @ -1315,7 +1308,7 @@ class FlagBoundary(StrEnum): | ||||||
| STRICT, CONFORM, EJECT, KEEP = FlagBoundary | STRICT, CONFORM, EJECT, KEEP = FlagBoundary | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Flag(Enum, boundary=CONFORM): | class Flag(Enum, boundary=STRICT): | ||||||
|     """ |     """ | ||||||
|     Support for flags |     Support for flags | ||||||
|     """ |     """ | ||||||
|  | @ -1394,6 +1387,7 @@ def _missing_(cls, value): | ||||||
|         # - value must not include any skipped flags (e.g. if bit 2 is not |         # - value must not include any skipped flags (e.g. if bit 2 is not | ||||||
|         #   defined, then 0d10 is invalid) |         #   defined, then 0d10 is invalid) | ||||||
|         flag_mask = cls._flag_mask_ |         flag_mask = cls._flag_mask_ | ||||||
|  |         singles_mask = cls._singles_mask_ | ||||||
|         all_bits = cls._all_bits_ |         all_bits = cls._all_bits_ | ||||||
|         neg_value = None |         neg_value = None | ||||||
|         if ( |         if ( | ||||||
|  | @ -1425,7 +1419,8 @@ def _missing_(cls, value): | ||||||
|             value = all_bits + 1 + value |             value = all_bits + 1 + value | ||||||
|         # get members and unknown |         # get members and unknown | ||||||
|         unknown = value & ~flag_mask |         unknown = value & ~flag_mask | ||||||
|         member_value = value & flag_mask |         aliases = value & ~singles_mask | ||||||
|  |         member_value = value & singles_mask | ||||||
|         if unknown and cls._boundary_ is not KEEP: |         if unknown and cls._boundary_ is not KEEP: | ||||||
|             raise ValueError( |             raise ValueError( | ||||||
|                     '%s(%r) -->  unknown values %r [%s]' |                     '%s(%r) -->  unknown values %r [%s]' | ||||||
|  | @ -1439,11 +1434,25 @@ def _missing_(cls, value): | ||||||
|             pseudo_member = cls._member_type_.__new__(cls, value) |             pseudo_member = cls._member_type_.__new__(cls, value) | ||||||
|         if not hasattr(pseudo_member, '_value_'): |         if not hasattr(pseudo_member, '_value_'): | ||||||
|             pseudo_member._value_ = value |             pseudo_member._value_ = value | ||||||
|         if member_value: |         if member_value or aliases: | ||||||
|             pseudo_member._name_ = '|'.join([ |             members = [] | ||||||
|                 m._name_ for m in cls._iter_member_(member_value) |             combined_value = 0 | ||||||
|                 ]) |             for m in cls._iter_member_(member_value): | ||||||
|             if unknown: |                 members.append(m) | ||||||
|  |                 combined_value |= m._value_ | ||||||
|  |             if aliases: | ||||||
|  |                 value = member_value | aliases | ||||||
|  |                 for n, pm in cls._member_map_.items(): | ||||||
|  |                     if pm not in members and pm._value_ and pm._value_ & value == pm._value_: | ||||||
|  |                         members.append(pm) | ||||||
|  |                         combined_value |= pm._value_ | ||||||
|  |             unknown = value ^ combined_value | ||||||
|  |             pseudo_member._name_ = '|'.join([m._name_ for m in members]) | ||||||
|  |             if not combined_value: | ||||||
|  |                 pseudo_member._name_ = None | ||||||
|  |             elif unknown and cls._boundary_ is STRICT: | ||||||
|  |                 raise ValueError('%r: no members with value %r' % (cls, unknown)) | ||||||
|  |             elif unknown: | ||||||
|                 pseudo_member._name_ += '|%s' % cls._numeric_repr_(unknown) |                 pseudo_member._name_ += '|%s' % cls._numeric_repr_(unknown) | ||||||
|         else: |         else: | ||||||
|             pseudo_member._name_ = None |             pseudo_member._name_ = None | ||||||
|  | @ -1675,6 +1684,7 @@ def convert_class(cls): | ||||||
|             body['_boundary_'] = boundary or etype._boundary_ |             body['_boundary_'] = boundary or etype._boundary_ | ||||||
|             body['_flag_mask_'] = None |             body['_flag_mask_'] = None | ||||||
|             body['_all_bits_'] = None |             body['_all_bits_'] = None | ||||||
|  |             body['_singles_mask_'] = None | ||||||
|             body['_inverted_'] = None |             body['_inverted_'] = None | ||||||
|             body['__or__'] = Flag.__or__ |             body['__or__'] = Flag.__or__ | ||||||
|             body['__xor__'] = Flag.__xor__ |             body['__xor__'] = Flag.__xor__ | ||||||
|  | @ -1750,7 +1760,8 @@ def convert_class(cls): | ||||||
|                     else: |                     else: | ||||||
|                         multi_bits |= value |                         multi_bits |= value | ||||||
|                     gnv_last_values.append(value) |                     gnv_last_values.append(value) | ||||||
|             enum_class._flag_mask_ = single_bits |             enum_class._flag_mask_ = single_bits | multi_bits | ||||||
|  |             enum_class._singles_mask_ = single_bits | ||||||
|             enum_class._all_bits_ = 2 ** ((single_bits|multi_bits).bit_length()) - 1 |             enum_class._all_bits_ = 2 ** ((single_bits|multi_bits).bit_length()) - 1 | ||||||
|             # set correct __iter__ |             # set correct __iter__ | ||||||
|             member_list = [m._value_ for m in enum_class] |             member_list = [m._value_ for m in enum_class] | ||||||
|  |  | ||||||
|  | @ -2873,6 +2873,8 @@ def __new__(cls, c): | ||||||
|             # |             # | ||||||
|             a = ord('a') |             a = ord('a') | ||||||
|         # |         # | ||||||
|  |         self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343) | ||||||
|  |         self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672) | ||||||
|         self.assertEqual(FlagFromChar.a, 158456325028528675187087900672) |         self.assertEqual(FlagFromChar.a, 158456325028528675187087900672) | ||||||
|         self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673) |         self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673) | ||||||
|         # |         # | ||||||
|  | @ -2887,6 +2889,8 @@ def __new__(cls, c): | ||||||
|             a = ord('a') |             a = ord('a') | ||||||
|             z = 1 |             z = 1 | ||||||
|         # |         # | ||||||
|  |         self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343) | ||||||
|  |         self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900674) | ||||||
|         self.assertEqual(FlagFromChar.a.value, 158456325028528675187087900672) |         self.assertEqual(FlagFromChar.a.value, 158456325028528675187087900672) | ||||||
|         self.assertEqual((FlagFromChar.a|FlagFromChar.z).value, 158456325028528675187087900674) |         self.assertEqual((FlagFromChar.a|FlagFromChar.z).value, 158456325028528675187087900674) | ||||||
|         # |         # | ||||||
|  | @ -2900,6 +2904,8 @@ def __new__(cls, c): | ||||||
|             # |             # | ||||||
|             a = ord('a') |             a = ord('a') | ||||||
|         # |         # | ||||||
|  |         self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343) | ||||||
|  |         self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672) | ||||||
|         self.assertEqual(FlagFromChar.a, 158456325028528675187087900672) |         self.assertEqual(FlagFromChar.a, 158456325028528675187087900672) | ||||||
|         self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673) |         self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673) | ||||||
| 
 | 
 | ||||||
|  | @ -3077,18 +3083,18 @@ def test_bool(self): | ||||||
|             self.assertEqual(bool(f.value), bool(f)) |             self.assertEqual(bool(f.value), bool(f)) | ||||||
| 
 | 
 | ||||||
|     def test_boundary(self): |     def test_boundary(self): | ||||||
|         self.assertIs(enum.Flag._boundary_, CONFORM) |         self.assertIs(enum.Flag._boundary_, STRICT) | ||||||
|         class Iron(Flag, boundary=STRICT): |         class Iron(Flag, boundary=CONFORM): | ||||||
|             ONE = 1 |             ONE = 1 | ||||||
|             TWO = 2 |             TWO = 2 | ||||||
|             EIGHT = 8 |             EIGHT = 8 | ||||||
|         self.assertIs(Iron._boundary_, STRICT) |         self.assertIs(Iron._boundary_, CONFORM) | ||||||
|         # |         # | ||||||
|         class Water(Flag, boundary=CONFORM): |         class Water(Flag, boundary=STRICT): | ||||||
|             ONE = 1 |             ONE = 1 | ||||||
|             TWO = 2 |             TWO = 2 | ||||||
|             EIGHT = 8 |             EIGHT = 8 | ||||||
|         self.assertIs(Water._boundary_, CONFORM) |         self.assertIs(Water._boundary_, STRICT) | ||||||
|         # |         # | ||||||
|         class Space(Flag, boundary=EJECT): |         class Space(Flag, boundary=EJECT): | ||||||
|             ONE = 1 |             ONE = 1 | ||||||
|  | @ -3101,10 +3107,10 @@ class Bizarre(Flag, boundary=KEEP): | ||||||
|             c = 4 |             c = 4 | ||||||
|             d = 6 |             d = 6 | ||||||
|         # |         # | ||||||
|         self.assertRaisesRegex(ValueError, 'invalid value 7', Iron, 7) |         self.assertRaisesRegex(ValueError, 'invalid value 7', Water, 7) | ||||||
|         # |         # | ||||||
|         self.assertIs(Water(7), Water.ONE|Water.TWO) |         self.assertIs(Iron(7), Iron.ONE|Iron.TWO) | ||||||
|         self.assertIs(Water(~9), Water.TWO) |         self.assertIs(Iron(~9), Iron.TWO) | ||||||
|         # |         # | ||||||
|         self.assertEqual(Space(7), 7) |         self.assertEqual(Space(7), 7) | ||||||
|         self.assertTrue(type(Space(7)) is int) |         self.assertTrue(type(Space(7)) is int) | ||||||
|  | @ -3112,6 +3118,31 @@ class Bizarre(Flag, boundary=KEEP): | ||||||
|         self.assertEqual(list(Bizarre), [Bizarre.c]) |         self.assertEqual(list(Bizarre), [Bizarre.c]) | ||||||
|         self.assertIs(Bizarre(3), Bizarre.b) |         self.assertIs(Bizarre(3), Bizarre.b) | ||||||
|         self.assertIs(Bizarre(6), Bizarre.d) |         self.assertIs(Bizarre(6), Bizarre.d) | ||||||
|  |         # | ||||||
|  |         class SkipFlag(enum.Flag): | ||||||
|  |             A = 1 | ||||||
|  |             B = 2 | ||||||
|  |             C = 4 | B | ||||||
|  |         # | ||||||
|  |         self.assertTrue(SkipFlag.C in (SkipFlag.A|SkipFlag.C)) | ||||||
|  |         self.assertRaisesRegex(ValueError, 'SkipFlag.. invalid value 42', SkipFlag, 42) | ||||||
|  |         # | ||||||
|  |         class SkipIntFlag(enum.IntFlag): | ||||||
|  |             A = 1 | ||||||
|  |             B = 2 | ||||||
|  |             C = 4 | B | ||||||
|  |         # | ||||||
|  |         self.assertTrue(SkipIntFlag.C in (SkipIntFlag.A|SkipIntFlag.C)) | ||||||
|  |         self.assertEqual(SkipIntFlag(42).value, 42) | ||||||
|  |         # | ||||||
|  |         class MethodHint(Flag): | ||||||
|  |             HiddenText = 0x10 | ||||||
|  |             DigitsOnly = 0x01 | ||||||
|  |             LettersOnly = 0x02 | ||||||
|  |             OnlyMask = 0x0f | ||||||
|  |         # | ||||||
|  |         self.assertEqual(str(MethodHint.HiddenText|MethodHint.OnlyMask), 'MethodHint.HiddenText|DigitsOnly|LettersOnly|OnlyMask') | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
|     def test_iter(self): |     def test_iter(self): | ||||||
|         Color = self.Color |         Color = self.Color | ||||||
|  |  | ||||||
|  | @ -0,0 +1 @@ | ||||||
|  | Set default Flag boundary to ``STRICT`` and fix bitwise operations. | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Ethan Furman
						Ethan Furman