Mirror of https://github.com/python/cpython.git (synced 2025-10-31 05:31:20 +00:00)
			
		
		
		
	
		
			
				
	
	
		
			177 lines
		
	
	
	
		
			4.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			177 lines
		
	
	
	
		
			4.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import gc
 | |
| import time
 | |
| import unittest
 | |
| import weakref
 | |
| 
 | |
| from ast import Or
 | |
| from functools import partial
 | |
| from threading import Thread
 | |
| from unittest import TestCase
 | |
| 
 | |
| from _testcapi import dict_version
 | |
| 
 | |
| from test.support import threading_helper
 | |
| 
 | |
| 
 | |
@threading_helper.requires_working_threading()
class TestDict(TestCase):
    """Thread-safety stress tests for dict creation, ``__dict__``
    assignment and ``_testcapi.dict_version``."""

    def test_racing_creation_shared_keys(self):
        """Verify that creating dictionaries is thread safe when we
        have a type with shared keys"""
        class C(int):
            pass

        self.racing_creation(C)

    def test_racing_creation_no_shared_keys(self):
        """Verify that creating dictionaries is thread safe when we
        have a type with an ordinary dict"""
        self.racing_creation(Or)

    def test_racing_creation_inline_values_invalid(self):
        """Verify that re-creating a dict after we have invalid inline values
        is thread safe"""
        class C:
            pass

        def make_obj():
            a = C()
            # Make object, make inline values invalid, and then delete dict
            a.__dict__ = {}
            del a.__dict__
            return a

        self.racing_creation(make_obj)

    def test_racing_creation_nonmanaged_dict(self):
        """Verify that explicit creation of an unmanaged dict is thread safe
        outside of the normal attribute setting code path"""
        def make_obj():
            def f(): pass
            return f

        def set_via_dict(func, name, val):
            # Force creation of the dict via PyObject_GenericGetDict
            func.__dict__[name] = val

        self.racing_creation(make_obj, set_via_dict)

    def racing_creation(self, cls, set=setattr):
        """Race THREAD_COUNT writers storing one attribute each on the same
        fresh object, one object at a time, for OBJECT_COUNT objects.

        *cls* is a zero-argument callable that produces the objects. *set*
        is called as ``set(obj, name, value)`` to store an attribute. The
        parameter name shadows the builtin but is kept for interface
        compatibility. Afterwards every object's ``__dict__`` must hold
        exactly one entry per writer thread.
        """
        OBJECT_COUNT = 100
        THREAD_COUNT = 10
        # Index of the object the writers currently target. It is a closure
        # cell shared with writer_func. The main thread advances it only
        # after every writer has reported in via ``processed``.
        CUR = 0

        objects = [cls() for _ in range(OBJECT_COUNT)]
        processed = []

        def writer_func(name):
            last = -1
            while True:
                # Deliberate busy-wait so all writers hit the object at once.
                if CUR == last:
                    continue
                elif CUR == OBJECT_COUNT:
                    break

                obj = objects[CUR]
                set(obj, name, name)
                last = CUR
                processed.append(name)

        writers = []
        for x in range(THREAD_COUNT):
            writer = Thread(target=partial(writer_func, f"a{x:02}"))
            writers.append(writer)
            writer.start()

        for i in range(OBJECT_COUNT):
            CUR = i
            while len(processed) != THREAD_COUNT:
                # Writers only exit once CUR == OBJECT_COUNT, so a dead
                # writer here means it raised. Without this check the test
                # would spin forever instead of failing.
                if not all(w.is_alive() for w in writers):
                    CUR = OBJECT_COUNT  # let the surviving writers exit
                    for writer in writers:
                        writer.join()
                    self.fail(f"a writer thread died while processing object {i}")
                time.sleep(0.001)
            processed.clear()

        CUR = OBJECT_COUNT

        for writer in writers:
            writer.join()

        for obj_idx, obj in enumerate(objects):
            self.assertEqual(
                len(obj.__dict__), THREAD_COUNT,
                f"{len(obj.__dict__)} {obj.__dict__!r} {obj_idx}")
            for i in range(THREAD_COUNT):
                self.assertIn(f"a{i:02}", obj.__dict__,
                              f"a{i:02} missing at {obj_idx}")

    def test_racing_set_dict(self):
        """Races assigning to __dict__ should be thread safe"""

        def f(): pass
        THREAD_COUNT = 10
        class MyDict(dict): pass

        def writer_func(refs):
            for _ in range(1000):
                d = MyDict()
                refs.append(weakref.ref(d))
                f.__dict__ = d

        lists = []
        writers = []
        for _ in range(THREAD_COUNT):
            thread_list = []
            lists.append(thread_list)
            writers.append(Thread(target=partial(writer_func, thread_list)))

        for writer in writers:
            writer.start()

        for writer in writers:
            writer.join()

        # Replace the last installed MyDict so none of them stays referenced
        # by f, then collect: every weakref must now be dead.
        f.__dict__ = {}
        gc.collect()

        for thread_list in lists:
            for ref in thread_list:
                self.assertIsNone(ref())

    def test_dict_version(self):
        """Values returned by dict_version() for dicts created concurrently
        in several threads must all be distinct."""
        THREAD_COUNT = 10
        DICT_COUNT = 10000
        lists = []
        writers = []

        def writer_func(thread_list):
            for _ in range(DICT_COUNT):
                thread_list.append(dict_version({}))

        for _ in range(THREAD_COUNT):
            thread_list = []
            lists.append(thread_list)
            writers.append(Thread(target=partial(writer_func, thread_list)))

        for writer in writers:
            writer.start()

        for writer in writers:
            writer.join()

        # Any duplicate (or missing) value shrinks the set below the
        # expected total.
        expected = THREAD_COUNT * DICT_COUNT
        versions = {version for thread_list in lists for version in thread_list}
        self.assertEqual(
            len(versions), expected,
            f"{expected - len(versions)} duplicate or missing versions")
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Executed directly as a script: discover and run this module's tests.
    unittest.main(module="__main__")
 | 
