[3.13] gh-142752: add more thread safety tests for mock (GH-142791) (#142857)

gh-142752: add more thread safety tests for mock (GH-142791)
(cherry picked from commit 4fd006e712)

Co-authored-by: Kumar Aditya <kumaraditya@python.org>
This commit is contained in:
Miss Islington (bot) 2025-12-17 09:09:59 +01:00 committed by GitHub
parent fb5474726c
commit e07cda302e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -219,5 +219,81 @@ def test_function():
self.assertEqual(m.call_count, LOOPS * THREADS)
def test_call_args_thread_safe(self):
m = ThreadingMock()
LOOPS = 100
THREADS = 10
def test_function(thread_id):
for i in range(LOOPS):
m(thread_id, i)
oldswitchinterval = sys.getswitchinterval()
setswitchinterval(1e-6)
try:
threads = [
threading.Thread(target=test_function, args=(thread_id,))
for thread_id in range(THREADS)
]
with threading_helper.start_threads(threads):
pass
finally:
sys.setswitchinterval(oldswitchinterval)
expected_calls = {
(thread_id, i)
for thread_id in range(THREADS)
for i in range(LOOPS)
}
self.assertSetEqual({call.args for call in m.call_args_list}, expected_calls)
def test_method_calls_thread_safe(self):
m = ThreadingMock()
LOOPS = 100
THREADS = 10
def test_function(thread_id):
for i in range(LOOPS):
getattr(m, f"method_{thread_id}")(i)
oldswitchinterval = sys.getswitchinterval()
setswitchinterval(1e-6)
try:
threads = [
threading.Thread(target=test_function, args=(thread_id,))
for thread_id in range(THREADS)
]
with threading_helper.start_threads(threads):
pass
finally:
sys.setswitchinterval(oldswitchinterval)
for thread_id in range(THREADS):
self.assertEqual(getattr(m, f"method_{thread_id}").call_count, LOOPS)
self.assertEqual({call.args for call in getattr(m, f"method_{thread_id}").call_args_list},
{(i,) for i in range(LOOPS)})
def test_mock_calls_thread_safe(self):
m = ThreadingMock()
LOOPS = 100
THREADS = 10
def test_function(thread_id):
for i in range(LOOPS):
m(thread_id, i)
oldswitchinterval = sys.getswitchinterval()
setswitchinterval(1e-6)
try:
threads = [
threading.Thread(target=test_function, args=(thread_id,))
for thread_id in range(THREADS)
]
with threading_helper.start_threads(threads):
pass
finally:
sys.setswitchinterval(oldswitchinterval)
expected_calls = {
(thread_id, i)
for thread_id in range(THREADS)
for i in range(LOOPS)
}
self.assertSetEqual({call.args for call in m.mock_calls}, expected_calls)
if __name__ == "__main__":
unittest.main()