mirror of
				https://github.com/python/cpython.git
				synced 2025-10-30 21:21:22 +00:00 
			
		
		
		
	gh-115539: Allow enum.Flag to have None members (GH-115636)
This commit is contained in:
		
							parent
							
								
									6cd18c75a4
								
							
						
					
					
						commit
						c2cb31bbe1
					
				
					 2 changed files with 52 additions and 21 deletions
				
			
		
							
								
								
									
										53
									
								
								Lib/enum.py
									
										
									
									
									
								
							
							
						
						
									
										53
									
								
								Lib/enum.py
									
										
									
									
									
								
							|  | @ -279,6 +279,7 @@ def __set_name__(self, enum_class, member_name): | |||
|         enum_member._sort_order_ = len(enum_class._member_names_) | ||||
| 
 | ||||
|         if Flag is not None and issubclass(enum_class, Flag): | ||||
|             if isinstance(value, int): | ||||
|                 enum_class._flag_mask_ |= value | ||||
|                 if _is_single_bit(value): | ||||
|                     enum_class._singles_mask_ |= value | ||||
|  | @ -309,6 +310,7 @@ def __set_name__(self, enum_class, member_name): | |||
|             elif ( | ||||
|                     Flag is not None | ||||
|                     and issubclass(enum_class, Flag) | ||||
|                     and isinstance(value, int) | ||||
|                     and _is_single_bit(value) | ||||
|                 ): | ||||
|                 # no other instances found, record this member in _member_names_ | ||||
|  | @ -1558,37 +1560,50 @@ def __str__(self): | |||
|     def __bool__(self): | ||||
|         return bool(self._value_) | ||||
| 
 | ||||
|     def __or__(self, other): | ||||
|         if isinstance(other, self.__class__): | ||||
|             other = other._value_ | ||||
|         elif self._member_type_ is not object and isinstance(other, self._member_type_): | ||||
|             other = other | ||||
|         else: | ||||
|     def _get_value(self, flag): | ||||
|         if isinstance(flag, self.__class__): | ||||
|             return flag._value_ | ||||
|         elif self._member_type_ is not object and isinstance(flag, self._member_type_): | ||||
|             return flag | ||||
|         return NotImplemented | ||||
| 
 | ||||
|     def __or__(self, other): | ||||
|         other_value = self._get_value(other) | ||||
|         if other_value is NotImplemented: | ||||
|             return NotImplemented | ||||
| 
 | ||||
|         for flag in self, other: | ||||
|             if self._get_value(flag) is None: | ||||
|                 raise TypeError(f"'{flag}' cannot be combined with other flags with |") | ||||
|         value = self._value_ | ||||
|         return self.__class__(value | other) | ||||
|         return self.__class__(value | other_value) | ||||
| 
 | ||||
|     def __and__(self, other): | ||||
|         if isinstance(other, self.__class__): | ||||
|             other = other._value_ | ||||
|         elif self._member_type_ is not object and isinstance(other, self._member_type_): | ||||
|             other = other | ||||
|         else: | ||||
|         other_value = self._get_value(other) | ||||
|         if other_value is NotImplemented: | ||||
|             return NotImplemented | ||||
| 
 | ||||
|         for flag in self, other: | ||||
|             if self._get_value(flag) is None: | ||||
|                 raise TypeError(f"'{flag}' cannot be combined with other flags with &") | ||||
|         value = self._value_ | ||||
|         return self.__class__(value & other) | ||||
|         return self.__class__(value & other_value) | ||||
| 
 | ||||
|     def __xor__(self, other): | ||||
|         if isinstance(other, self.__class__): | ||||
|             other = other._value_ | ||||
|         elif self._member_type_ is not object and isinstance(other, self._member_type_): | ||||
|             other = other | ||||
|         else: | ||||
|         other_value = self._get_value(other) | ||||
|         if other_value is NotImplemented: | ||||
|             return NotImplemented | ||||
| 
 | ||||
|         for flag in self, other: | ||||
|             if self._get_value(flag) is None: | ||||
|                 raise TypeError(f"'{flag}' cannot be combined with other flags with ^") | ||||
|         value = self._value_ | ||||
|         return self.__class__(value ^ other) | ||||
|         return self.__class__(value ^ other_value) | ||||
| 
 | ||||
|     def __invert__(self): | ||||
|         if self._get_value(self) is None: | ||||
|             raise TypeError(f"'{self}' cannot be inverted") | ||||
| 
 | ||||
|         if self._inverted_ is None: | ||||
|             if self._boundary_ in (EJECT, KEEP): | ||||
|                 self._inverted_ = self.__class__(~self._value_) | ||||
|  |  | |||
|  | @ -1048,6 +1048,22 @@ class TestPlainEnumFunction(_EnumTests, _PlainOutputTests, unittest.TestCase): | |||
| class TestPlainFlagClass(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase): | ||||
|     enum_type = Flag | ||||
| 
 | ||||
|     def test_none_member(self): | ||||
|         class FlagWithNoneMember(Flag): | ||||
|             A = 1 | ||||
|             E = None | ||||
| 
 | ||||
|         self.assertEqual(FlagWithNoneMember.A.value, 1) | ||||
|         self.assertIs(FlagWithNoneMember.E.value, None) | ||||
|         with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with |"): | ||||
|             FlagWithNoneMember.A | FlagWithNoneMember.E | ||||
|         with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with &"): | ||||
|             FlagWithNoneMember.E & FlagWithNoneMember.A | ||||
|         with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with \^"): | ||||
|             FlagWithNoneMember.A ^ FlagWithNoneMember.E | ||||
|         with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be inverted"): | ||||
|             ~FlagWithNoneMember.E | ||||
| 
 | ||||
| 
 | ||||
| class TestPlainFlagFunction(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase): | ||||
|     enum_type = Flag | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Jason Zhang
						Jason Zhang