mirror of
				https://github.com/python/cpython.git
				synced 2025-10-30 21:21:22 +00:00 
			
		
		
		
	[Enum] improve pickle support (#26666)
search all bases for a __reduce__ style method; if a __new__ method is found first the enum will be made unpicklable
This commit is contained in:
		
							parent
							
								
									3ce35bfbbe
								
							
						
					
					
						commit
						e9726314df
					
				
					 3 changed files with 65 additions and 3 deletions
				
			
		
							
								
								
									
										30
									
								
								Lib/enum.py
									
										
									
									
									
								
							
							
						
						
									
										30
									
								
								Lib/enum.py
									
										
									
									
									
								
							|  | @ -242,8 +242,32 @@ def __new__(metacls, cls, bases, classdict, **kwds): | ||||||
|                 methods = ('__getnewargs_ex__', '__getnewargs__', |                 methods = ('__getnewargs_ex__', '__getnewargs__', | ||||||
|                         '__reduce_ex__', '__reduce__') |                         '__reduce_ex__', '__reduce__') | ||||||
|                 if not any(m in member_type.__dict__ for m in methods): |                 if not any(m in member_type.__dict__ for m in methods): | ||||||
|                     _make_class_unpicklable(enum_class) |                     if '__new__' in classdict: | ||||||
| 
 |                         # too late, sabotage | ||||||
|  |                         _make_class_unpicklable(enum_class) | ||||||
|  |                     else: | ||||||
|  |                         # final attempt to verify that pickling would work: | ||||||
|  |                         # travel mro until __new__ is found, checking for | ||||||
|  |                         # __reduce__ and friends along the way -- if any of them | ||||||
|  |                         # are found before/when __new__ is found, pickling should | ||||||
|  |                         # work | ||||||
|  |                         sabotage = None | ||||||
|  |                         for chain in bases: | ||||||
|  |                             for base in chain.__mro__: | ||||||
|  |                                 if base is object: | ||||||
|  |                                     continue | ||||||
|  |                                 elif any(m in base.__dict__ for m in methods): | ||||||
|  |                                     # found one, we're good | ||||||
|  |                                     sabotage = False | ||||||
|  |                                     break | ||||||
|  |                                 elif '__new__' in base.__dict__: | ||||||
|  |                                     # not good | ||||||
|  |                                     sabotage = True | ||||||
|  |                                     break | ||||||
|  |                             if sabotage is not None: | ||||||
|  |                                 break | ||||||
|  |                         if sabotage: | ||||||
|  |                             _make_class_unpicklable(enum_class) | ||||||
|         # instantiate them, checking for duplicates as we go |         # instantiate them, checking for duplicates as we go | ||||||
|         # we instantiate first instead of checking for duplicates first in case |         # we instantiate first instead of checking for duplicates first in case | ||||||
|         # a custom __new__ is doing something funky with the values -- such as |         # a custom __new__ is doing something funky with the values -- such as | ||||||
|  | @ -572,7 +596,7 @@ def _find_data_type(bases): | ||||||
|                         data_types.add(candidate or base) |                         data_types.add(candidate or base) | ||||||
|                         break |                         break | ||||||
|                     else: |                     else: | ||||||
|                         candidate = base |                         candidate = candidate or base | ||||||
|             if len(data_types) > 1: |             if len(data_types) > 1: | ||||||
|                 raise TypeError('%r: too many data types: %r' % (class_name, data_types)) |                 raise TypeError('%r: too many data types: %r' % (class_name, data_types)) | ||||||
|             elif data_types: |             elif data_types: | ||||||
|  |  | ||||||
|  | @ -594,13 +594,49 @@ class Test2Enum(MyStrEnum, MyMethodEnum): | ||||||
| 
 | 
 | ||||||
|     def test_inherited_data_type(self): |     def test_inherited_data_type(self): | ||||||
|         class HexInt(int): |         class HexInt(int): | ||||||
|  |             __qualname__ = 'HexInt' | ||||||
|             def __repr__(self): |             def __repr__(self): | ||||||
|                 return hex(self) |                 return hex(self) | ||||||
|         class MyEnum(HexInt, enum.Enum): |         class MyEnum(HexInt, enum.Enum): | ||||||
|  |             __qualname__ = 'MyEnum' | ||||||
|             A = 1 |             A = 1 | ||||||
|             B = 2 |             B = 2 | ||||||
|             C = 3 |             C = 3 | ||||||
|         self.assertEqual(repr(MyEnum.A), '<MyEnum.A: 0x1>') |         self.assertEqual(repr(MyEnum.A), '<MyEnum.A: 0x1>') | ||||||
|  |         globals()['HexInt'] = HexInt | ||||||
|  |         globals()['MyEnum'] = MyEnum | ||||||
|  |         test_pickle_dump_load(self.assertIs, MyEnum.A) | ||||||
|  |         test_pickle_dump_load(self.assertIs, MyEnum) | ||||||
|  |         # | ||||||
|  |         class SillyInt(HexInt): | ||||||
|  |             __qualname__ = 'SillyInt' | ||||||
|  |             pass | ||||||
|  |         class MyOtherEnum(SillyInt, enum.Enum): | ||||||
|  |             __qualname__ = 'MyOtherEnum' | ||||||
|  |             D = 4 | ||||||
|  |             E = 5 | ||||||
|  |             F = 6 | ||||||
|  |         self.assertIs(MyOtherEnum._member_type_, SillyInt) | ||||||
|  |         globals()['SillyInt'] = SillyInt | ||||||
|  |         globals()['MyOtherEnum'] = MyOtherEnum | ||||||
|  |         test_pickle_dump_load(self.assertIs, MyOtherEnum.E) | ||||||
|  |         test_pickle_dump_load(self.assertIs, MyOtherEnum) | ||||||
|  |         # | ||||||
|  |         class BrokenInt(int): | ||||||
|  |             __qualname__ = 'BrokenInt' | ||||||
|  |             def __new__(cls, value): | ||||||
|  |                 return int.__new__(cls, value) | ||||||
|  |         class MyBrokenEnum(BrokenInt, Enum): | ||||||
|  |             __qualname__ = 'MyBrokenEnum' | ||||||
|  |             G = 7 | ||||||
|  |             H = 8 | ||||||
|  |             I = 9 | ||||||
|  |         self.assertIs(MyBrokenEnum._member_type_, BrokenInt) | ||||||
|  |         self.assertIs(MyBrokenEnum(7), MyBrokenEnum.G) | ||||||
|  |         globals()['BrokenInt'] = BrokenInt | ||||||
|  |         globals()['MyBrokenEnum'] = MyBrokenEnum | ||||||
|  |         test_pickle_exception(self.assertRaises, TypeError, MyBrokenEnum.G) | ||||||
|  |         test_pickle_exception(self.assertRaises, PicklingError, MyBrokenEnum) | ||||||
| 
 | 
 | ||||||
|     def test_too_many_data_types(self): |     def test_too_many_data_types(self): | ||||||
|         with self.assertRaisesRegex(TypeError, 'too many data types'): |         with self.assertRaisesRegex(TypeError, 'too many data types'): | ||||||
|  |  | ||||||
|  | @ -0,0 +1,2 @@ | ||||||
|  | [Enum] Be more robust in searching for pickle support before making an enum | ||||||
|  | class unpicklable. | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Ethan Furman
						Ethan Furman