mirror of
				https://github.com/python/cpython.git
				synced 2025-10-31 21:51:50 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			267 lines
		
	
	
	
		
			8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			267 lines
		
	
	
	
		
			8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import unittest
 | |
| 
 | |
| import heapq
 | |
| 
 | |
| from enum import Enum
 | |
| from threading import Thread, Barrier, Lock
 | |
| from random import shuffle, randint
 | |
| 
 | |
| from test.support import threading_helper
 | |
| from test import test_heapq
 | |
| 
 | |
| 
 | |
# Number of threads each test_racing_* test starts against one shared list.
NTHREADS = 10
# Number of items per heap / per item list (5000).  Expressed as a multiple
# of NTHREADS because the pop tests split it evenly across the threads.
OBJECT_COUNT = NTHREADS * 500
 | |
| 
 | |
| 
 | |
class Heap(Enum):
    """Selects which kind of heap TestHeapq.create_heap builds."""

    # Built with heapq.heapify.
    MIN = 1
    # Built with heapq.heapify_max.
    MAX = 2
 | |
| 
 | |
| 
 | |
@threading_helper.requires_working_threading()
class TestHeapq(unittest.TestCase):
    """Race heapq operations against each other on a shared list.

    The ``test_racing_*`` tests run one heapq function (or a loop over it)
    in NTHREADS threads at once on the same list, then check that the list
    still satisfies the min- or max-heap invariant and, where applicable,
    has the expected length.  Any exception raised in a worker thread,
    including a failed assertion, fails the test via run_concurrently().
    """

    def setUp(self):
        # Borrow check_invariant() / check_max_invariant() from the
        # single-threaded heapq test suite.
        self.test_heapq = test_heapq.TestHeapPython()

    def test_racing_heapify(self):
        heap = list(range(OBJECT_COUNT))
        shuffle(heap)

        self.run_concurrently(
            worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS
        )
        self.test_heapq.check_invariant(heap)

    def test_racing_heappush(self):
        heap = []

        def heappush_func(heap):
            for item in reversed(range(OBJECT_COUNT)):
                heapq.heappush(heap, item)

        self.run_concurrently(
            worker_func=heappush_func, args=(heap,), nthreads=NTHREADS
        )
        # Every thread pushed OBJECT_COUNT items; none may be lost.
        self.assertEqual(len(heap), OBJECT_COUNT * NTHREADS)
        self.test_heapq.check_invariant(heap)

    def test_racing_heappop(self):
        heap = self.create_heap(OBJECT_COUNT, Heap.MIN)

        # Each thread pops (OBJECT_COUNT / NTHREADS) items
        self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
        per_thread_pop_count = OBJECT_COUNT // NTHREADS

        def heappop_func(heap, pop_count):
            local_list = []
            for _ in range(pop_count):
                item = heapq.heappop(heap)
                local_list.append(item)

            # Each local list should be sorted
            self.assertTrue(self.is_sorted_ascending(local_list))

        self.run_concurrently(
            worker_func=heappop_func,
            args=(heap, per_thread_pop_count),
            nthreads=NTHREADS,
        )
        self.assertEqual(len(heap), 0)

    def test_racing_heappushpop(self):
        heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
        pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)

        def heappushpop_func(heap, pushpop_items):
            for item in pushpop_items:
                popped_item = heapq.heappushpop(heap, item)
                # heappushpop pushes item and then pops the smallest entry,
                # so the returned value can never exceed item.
                self.assertLessEqual(popped_item, item)

        self.run_concurrently(
            worker_func=heappushpop_func,
            args=(heap, pushpop_items),
            nthreads=NTHREADS,
        )
        self.assertEqual(len(heap), OBJECT_COUNT)
        self.test_heapq.check_invariant(heap)

    def test_racing_heapreplace(self):
        heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
        replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)

        def heapreplace_func(heap, replace_items):
            for item in replace_items:
                heapq.heapreplace(heap, item)

        self.run_concurrently(
            worker_func=heapreplace_func,
            args=(heap, replace_items),
            nthreads=NTHREADS,
        )
        self.assertEqual(len(heap), OBJECT_COUNT)
        self.test_heapq.check_invariant(heap)

    def test_racing_heapify_max(self):
        max_heap = list(range(OBJECT_COUNT))
        shuffle(max_heap)

        self.run_concurrently(
            worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS
        )
        self.test_heapq.check_max_invariant(max_heap)

    def test_racing_heappush_max(self):
        max_heap = []

        def heappush_max_func(max_heap):
            for item in range(OBJECT_COUNT):
                heapq.heappush_max(max_heap, item)

        self.run_concurrently(
            worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS
        )
        # Every thread pushed OBJECT_COUNT items; none may be lost.
        self.assertEqual(len(max_heap), OBJECT_COUNT * NTHREADS)
        self.test_heapq.check_max_invariant(max_heap)

    def test_racing_heappop_max(self):
        max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)

        # Each thread pops (OBJECT_COUNT / NTHREADS) items
        self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
        per_thread_pop_count = OBJECT_COUNT // NTHREADS

        def heappop_max_func(max_heap, pop_count):
            local_list = []
            for _ in range(pop_count):
                item = heapq.heappop_max(max_heap)
                local_list.append(item)

            # Each local list should be sorted
            self.assertTrue(self.is_sorted_descending(local_list))

        self.run_concurrently(
            worker_func=heappop_max_func,
            args=(max_heap, per_thread_pop_count),
            nthreads=NTHREADS,
        )
        self.assertEqual(len(max_heap), 0)

    def test_racing_heappushpop_max(self):
        max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
        pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)

        def heappushpop_max_func(max_heap, pushpop_items):
            for item in pushpop_items:
                popped_item = heapq.heappushpop_max(max_heap, item)
                # The max-heap counterpart pops the largest entry, so the
                # returned value can never be smaller than item.
                self.assertGreaterEqual(popped_item, item)

        self.run_concurrently(
            worker_func=heappushpop_max_func,
            args=(max_heap, pushpop_items),
            nthreads=NTHREADS,
        )
        self.assertEqual(len(max_heap), OBJECT_COUNT)
        self.test_heapq.check_max_invariant(max_heap)

    def test_racing_heapreplace_max(self):
        max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
        replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)

        def heapreplace_max_func(max_heap, replace_items):
            for item in replace_items:
                heapq.heapreplace_max(max_heap, item)

        self.run_concurrently(
            worker_func=heapreplace_max_func,
            args=(max_heap, replace_items),
            nthreads=NTHREADS,
        )
        self.assertEqual(len(max_heap), OBJECT_COUNT)
        self.test_heapq.check_max_invariant(max_heap)

    def test_lock_free_list_read(self):
        # Half the threads repeatedly push/pop on a shared list while the
        # other half index it directly (without a lock).  The test passes
        # if no thread raises anything other than the IndexError the
        # readers tolerate.
        n, n_threads = 1_000, 10
        shared = []
        barrier = Barrier(n_threads * 2)

        count = 0
        lock = Lock()

        def worker():
            nonlocal count
            # Give each thread a unique number: odd threads write, even
            # threads read.
            with lock:
                x = count
                count += 1

            # Start only once every thread has been assigned its role.
            barrier.wait()
            for _ in range(n):
                if x % 2:
                    heapq.heappush(shared, 1)
                    heapq.heappop(shared)
                else:
                    try:
                        shared[0]
                    except IndexError:
                        # The writers may momentarily leave the list empty.
                        pass

        self.run_concurrently(worker, (), n_threads * 2)

    @staticmethod
    def is_sorted_ascending(lst):
        """
        Check if the list is sorted in ascending order (non-decreasing).
        """
        return all(lst[i - 1] <= lst[i] for i in range(1, len(lst)))

    @staticmethod
    def is_sorted_descending(lst):
        """
        Check if the list is sorted in descending order (non-increasing).
        """
        return all(lst[i - 1] >= lst[i] for i in range(1, len(lst)))

    @staticmethod
    def create_heap(size, heap_kind):
        """
        Create a heap holding the integers 0 .. size - 1 (inclusive), shuffled
        before being heapified.

        heap_kind == Heap.MIN builds a min-heap with heapq.heapify; any other
        value builds a max-heap with heapq.heapify_max.
        """
        # Use the requested size (previously this ignored ``size`` and
        # always used OBJECT_COUNT).
        heap = list(range(size))
        shuffle(heap)
        if heap_kind == Heap.MIN:
            heapq.heapify(heap)
        else:
            heapq.heapify_max(heap)

        return heap

    @staticmethod
    def create_random_list(a, b, size):
        """
        Create a list of ``size`` random integers between a and b (inclusive).
        """
        # randint's bounds are inclusive on both ends; pass ``a`` through
        # unchanged (it was previously negated, so callers passing a
        # negative lower bound never received negative values).
        return [randint(a, b) for _ in range(size)]

    def run_concurrently(self, worker_func, args, nthreads):
        """
        Run ``worker_func(*args)`` in ``nthreads`` threads at the same time.

        All threads wait on a shared barrier so they call ``worker_func``
        together.  Fails the test if any worker thread raised an exception
        (as recorded by threading_helper.catch_threading_exception).
        """
        barrier = Barrier(nthreads)

        def wrapper_func(*worker_args):
            # Wait for all threads to reach this point before proceeding.
            barrier.wait()
            worker_func(*worker_args)

        with threading_helper.catch_threading_exception() as cm:
            workers = (
                Thread(target=wrapper_func, args=args) for _ in range(nthreads)
            )
            with threading_helper.start_threads(workers):
                pass

            # Worker threads should not raise any exceptions
            self.assertIsNone(cm.exc_value)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Only when this file is executed directly (not imported): discover and
    # run the test cases defined in this module.
    unittest.main(module="__main__")
 | 
