[3.14] gh-116738: Make grp module thread-safe (GH-135434) (#136658)

gh-116738: Make grp module thread-safe (GH-135434)

Make grp module methods getgrgid() and getgrnam() thread-safe when the GIL is disabled and getgrgid_r()/getgrnam_r() C APIs are not available.
---------
(cherry picked from commit 9363703bd3)

Co-authored-by: Alper <alperyoney@fb.com>
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
This commit is contained in:
Miss Islington (bot) 2025-07-15 07:33:33 +02:00 committed by GitHub
parent bbbbb2e2d1
commit 55eaaab8a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 115 additions and 46 deletions

View file

@ -1384,6 +1384,13 @@ The :mod:`test.support.threading_helper` module provides support for threading t
.. versionadded:: 3.8 .. versionadded:: 3.8
.. function:: run_concurrently(worker_func, nthreads, args=(), kwargs={})
Run *worker_func* concurrently in *nthreads* threads, calling it as
``worker_func(*args, **kwargs)`` in each thread.
If any thread raises an exception, it is re-raised once all threads have
finished.
:mod:`test.support.os_helper` --- Utilities for os tests :mod:`test.support.os_helper` --- Utilities for os tests
======================================================================== ========================================================================

View file

@ -248,3 +248,27 @@ def requires_working_threading(*, module=False):
raise unittest.SkipTest(msg) raise unittest.SkipTest(msg)
else: else:
return unittest.skipUnless(can_start_thread, msg) return unittest.skipUnless(can_start_thread, msg)
def run_concurrently(worker_func, nthreads, args=(), kwargs={}):
    """
    Call ``worker_func(*args, **kwargs)`` from ``nthreads`` threads at once.

    Every thread first waits on a shared barrier, so no worker call begins
    until all of the threads have reached that point.  If a worker thread
    raised an exception, it is re-raised here.
    """
    start_gate = threading.Barrier(nthreads)

    def gated_worker(*call_args, **call_kwargs):
        # Hold this thread until every other thread has arrived as well.
        start_gate.wait()
        worker_func(*call_args, **call_kwargs)

    with catch_threading_exception() as caught:
        threads = []
        for _ in range(nthreads):
            thread = threading.Thread(
                target=gated_worker, args=args, kwargs=kwargs
            )
            threads.append(thread)

        with start_threads(threads):
            pass

        # Propagate a failure from a worker thread, if there was one.
        # NOTE(review): checked while still inside the context manager, as
        # the captured exception may not be available after it exits.
        failure = caught.exc_value
        if failure is not None:
            raise failure

View file

@ -0,0 +1,35 @@
import unittest
from test.support import import_helper, threading_helper
from test.support.threading_helper import run_concurrently
grp = import_helper.import_module("grp")
from test import test_grp
NTHREADS = 10
@threading_helper.requires_working_threading()
class TestGrp(unittest.TestCase):
    """Run the checks of the regular grp test case from several threads."""

    def setUp(self):
        # The concurrent tests reuse the checks of this test case instance.
        self.test_grp = test_grp.GroupDatabaseTestCase()

    def test_racing_test_values(self):
        # test_grp.test_values() calls grp.getgrall() and checks the entries
        check = self.test_grp.test_values
        run_concurrently(worker_func=check, nthreads=NTHREADS)

    def test_racing_test_values_extended(self):
        # test_grp.test_values_extended() calls grp.getgrall(), grp.getgrgid(),
        # grp.getgrnam() and checks the entries
        check = self.test_grp.test_values_extended
        run_concurrently(worker_func=check, nthreads=NTHREADS)
if __name__ == "__main__":
unittest.main()

View file

@ -3,10 +3,11 @@
import heapq import heapq
from enum import Enum from enum import Enum
from threading import Thread, Barrier, Lock from threading import Barrier, Lock
from random import shuffle, randint from random import shuffle, randint
from test.support import threading_helper from test.support import threading_helper
from test.support.threading_helper import run_concurrently
from test import test_heapq from test import test_heapq
@ -28,8 +29,8 @@ def test_racing_heapify(self):
heap = list(range(OBJECT_COUNT)) heap = list(range(OBJECT_COUNT))
shuffle(heap) shuffle(heap)
self.run_concurrently( run_concurrently(
worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS worker_func=heapq.heapify, nthreads=NTHREADS, args=(heap,)
) )
self.test_heapq.check_invariant(heap) self.test_heapq.check_invariant(heap)
@ -40,8 +41,8 @@ def heappush_func(heap):
for item in reversed(range(OBJECT_COUNT)): for item in reversed(range(OBJECT_COUNT)):
heapq.heappush(heap, item) heapq.heappush(heap, item)
self.run_concurrently( run_concurrently(
worker_func=heappush_func, args=(heap,), nthreads=NTHREADS worker_func=heappush_func, nthreads=NTHREADS, args=(heap,)
) )
self.test_heapq.check_invariant(heap) self.test_heapq.check_invariant(heap)
@ -61,10 +62,10 @@ def heappop_func(heap, pop_count):
# Each local list should be sorted # Each local list should be sorted
self.assertTrue(self.is_sorted_ascending(local_list)) self.assertTrue(self.is_sorted_ascending(local_list))
self.run_concurrently( run_concurrently(
worker_func=heappop_func, worker_func=heappop_func,
args=(heap, per_thread_pop_count),
nthreads=NTHREADS, nthreads=NTHREADS,
args=(heap, per_thread_pop_count),
) )
self.assertEqual(len(heap), 0) self.assertEqual(len(heap), 0)
@ -77,10 +78,10 @@ def heappushpop_func(heap, pushpop_items):
popped_item = heapq.heappushpop(heap, item) popped_item = heapq.heappushpop(heap, item)
self.assertTrue(popped_item <= item) self.assertTrue(popped_item <= item)
self.run_concurrently( run_concurrently(
worker_func=heappushpop_func, worker_func=heappushpop_func,
args=(heap, pushpop_items),
nthreads=NTHREADS, nthreads=NTHREADS,
args=(heap, pushpop_items),
) )
self.assertEqual(len(heap), OBJECT_COUNT) self.assertEqual(len(heap), OBJECT_COUNT)
self.test_heapq.check_invariant(heap) self.test_heapq.check_invariant(heap)
@ -93,10 +94,10 @@ def heapreplace_func(heap, replace_items):
for item in replace_items: for item in replace_items:
heapq.heapreplace(heap, item) heapq.heapreplace(heap, item)
self.run_concurrently( run_concurrently(
worker_func=heapreplace_func, worker_func=heapreplace_func,
args=(heap, replace_items),
nthreads=NTHREADS, nthreads=NTHREADS,
args=(heap, replace_items),
) )
self.assertEqual(len(heap), OBJECT_COUNT) self.assertEqual(len(heap), OBJECT_COUNT)
self.test_heapq.check_invariant(heap) self.test_heapq.check_invariant(heap)
@ -105,8 +106,8 @@ def test_racing_heapify_max(self):
max_heap = list(range(OBJECT_COUNT)) max_heap = list(range(OBJECT_COUNT))
shuffle(max_heap) shuffle(max_heap)
self.run_concurrently( run_concurrently(
worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS worker_func=heapq.heapify_max, nthreads=NTHREADS, args=(max_heap,)
) )
self.test_heapq.check_max_invariant(max_heap) self.test_heapq.check_max_invariant(max_heap)
@ -117,8 +118,8 @@ def heappush_max_func(max_heap):
for item in range(OBJECT_COUNT): for item in range(OBJECT_COUNT):
heapq.heappush_max(max_heap, item) heapq.heappush_max(max_heap, item)
self.run_concurrently( run_concurrently(
worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS worker_func=heappush_max_func, nthreads=NTHREADS, args=(max_heap,)
) )
self.test_heapq.check_max_invariant(max_heap) self.test_heapq.check_max_invariant(max_heap)
@ -138,10 +139,10 @@ def heappop_max_func(max_heap, pop_count):
# Each local list should be sorted # Each local list should be sorted
self.assertTrue(self.is_sorted_descending(local_list)) self.assertTrue(self.is_sorted_descending(local_list))
self.run_concurrently( run_concurrently(
worker_func=heappop_max_func, worker_func=heappop_max_func,
args=(max_heap, per_thread_pop_count),
nthreads=NTHREADS, nthreads=NTHREADS,
args=(max_heap, per_thread_pop_count),
) )
self.assertEqual(len(max_heap), 0) self.assertEqual(len(max_heap), 0)
@ -154,10 +155,10 @@ def heappushpop_max_func(max_heap, pushpop_items):
popped_item = heapq.heappushpop_max(max_heap, item) popped_item = heapq.heappushpop_max(max_heap, item)
self.assertTrue(popped_item >= item) self.assertTrue(popped_item >= item)
self.run_concurrently( run_concurrently(
worker_func=heappushpop_max_func, worker_func=heappushpop_max_func,
args=(max_heap, pushpop_items),
nthreads=NTHREADS, nthreads=NTHREADS,
args=(max_heap, pushpop_items),
) )
self.assertEqual(len(max_heap), OBJECT_COUNT) self.assertEqual(len(max_heap), OBJECT_COUNT)
self.test_heapq.check_max_invariant(max_heap) self.test_heapq.check_max_invariant(max_heap)
@ -170,10 +171,10 @@ def heapreplace_max_func(max_heap, replace_items):
for item in replace_items: for item in replace_items:
heapq.heapreplace_max(max_heap, item) heapq.heapreplace_max(max_heap, item)
self.run_concurrently( run_concurrently(
worker_func=heapreplace_max_func, worker_func=heapreplace_max_func,
args=(max_heap, replace_items),
nthreads=NTHREADS, nthreads=NTHREADS,
args=(max_heap, replace_items),
) )
self.assertEqual(len(max_heap), OBJECT_COUNT) self.assertEqual(len(max_heap), OBJECT_COUNT)
self.test_heapq.check_max_invariant(max_heap) self.test_heapq.check_max_invariant(max_heap)
@ -203,7 +204,7 @@ def worker():
except IndexError: except IndexError:
pass pass
self.run_concurrently(worker, (), n_threads * 2) run_concurrently(worker, n_threads * 2)
@staticmethod @staticmethod
def is_sorted_ascending(lst): def is_sorted_ascending(lst):
@ -241,27 +242,6 @@ def create_random_list(a, b, size):
""" """
return [randint(-a, b) for _ in range(size)] return [randint(-a, b) for _ in range(size)]
def run_concurrently(self, worker_func, args, nthreads):
    """
    Call ``worker_func(*args)`` from ``nthreads`` threads at once and
    assert that no exception was captured from a worker thread.
    """
    start_gate = Barrier(nthreads)

    def gated_worker(*call_args):
        # Hold this thread until every other thread has arrived as well.
        start_gate.wait()
        worker_func(*call_args)

    with threading_helper.catch_threading_exception() as caught:
        # Kept as a generator: the Thread objects are only created when
        # start_threads() iterates over it.
        threads = (
            Thread(target=gated_worker, args=args)
            for _ in range(nthreads)
        )
        with threading_helper.start_threads(threads):
            pass

        # Worker threads should not raise any exceptions
        captured = caught.exc_value
        self.assertIsNone(captured)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View file

@ -0,0 +1 @@
Make functions in :mod:`grp` thread-safe on the :term:`free threaded <free threading>` build.

View file

@ -55,6 +55,11 @@ get_grp_state(PyObject *module)
static struct PyModuleDef grpmodule; static struct PyModuleDef grpmodule;
/* Mutex to protect calls to getgrgid(), getgrnam(), and getgrent().
* These functions return a pointer to a static data structure, which
* may be overwritten by any subsequent call. */
static PyMutex group_db_mutex = {0};
#define DEFAULT_BUFFER_SIZE 1024 #define DEFAULT_BUFFER_SIZE 1024
static PyObject * static PyObject *
@ -168,9 +173,15 @@ grp_getgrgid_impl(PyObject *module, PyObject *id)
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
#else #else
PyMutex_Lock(&group_db_mutex);
// The getgrgid() function need not be thread-safe.
// https://pubs.opengroup.org/onlinepubs/9699919799/functions/getgrgid.html
p = getgrgid(gid); p = getgrgid(gid);
#endif #endif
if (p == NULL) { if (p == NULL) {
#ifndef HAVE_GETGRGID_R
PyMutex_Unlock(&group_db_mutex);
#endif
PyMem_RawFree(buf); PyMem_RawFree(buf);
if (nomem == 1) { if (nomem == 1) {
return PyErr_NoMemory(); return PyErr_NoMemory();
@ -185,6 +196,8 @@ grp_getgrgid_impl(PyObject *module, PyObject *id)
retval = mkgrent(module, p); retval = mkgrent(module, p);
#ifdef HAVE_GETGRGID_R #ifdef HAVE_GETGRGID_R
PyMem_RawFree(buf); PyMem_RawFree(buf);
#else
PyMutex_Unlock(&group_db_mutex);
#endif #endif
return retval; return retval;
} }
@ -249,9 +262,15 @@ grp_getgrnam_impl(PyObject *module, PyObject *name)
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
#else #else
PyMutex_Lock(&group_db_mutex);
// The getgrnam() function need not be thread-safe.
// https://pubs.opengroup.org/onlinepubs/9699919799/functions/getgrnam.html
p = getgrnam(name_chars); p = getgrnam(name_chars);
#endif #endif
if (p == NULL) { if (p == NULL) {
#ifndef HAVE_GETGRNAM_R
PyMutex_Unlock(&group_db_mutex);
#endif
if (nomem == 1) { if (nomem == 1) {
PyErr_NoMemory(); PyErr_NoMemory();
} }
@ -261,6 +280,9 @@ grp_getgrnam_impl(PyObject *module, PyObject *name)
goto out; goto out;
} }
retval = mkgrent(module, p); retval = mkgrent(module, p);
#ifndef HAVE_GETGRNAM_R
PyMutex_Unlock(&group_db_mutex);
#endif
out: out:
PyMem_RawFree(buf); PyMem_RawFree(buf);
Py_DECREF(bytes); Py_DECREF(bytes);
@ -285,8 +307,7 @@ grp_getgrall_impl(PyObject *module)
return NULL; return NULL;
} }
static PyMutex getgrall_mutex = {0}; PyMutex_Lock(&group_db_mutex);
PyMutex_Lock(&getgrall_mutex);
setgrent(); setgrent();
struct group *p; struct group *p;
@ -306,7 +327,7 @@ grp_getgrall_impl(PyObject *module)
done: done:
endgrent(); endgrent();
PyMutex_Unlock(&getgrall_mutex); PyMutex_Unlock(&group_db_mutex);
return d; return d;
} }

View file

@ -167,6 +167,7 @@ Python/sysmodule.c - _preinit_xoptions -
# XXX need race protection? # XXX need race protection?
Modules/faulthandler.c faulthandler_dump_traceback reentrant - Modules/faulthandler.c faulthandler_dump_traceback reentrant -
Modules/faulthandler.c faulthandler_dump_c_stack reentrant - Modules/faulthandler.c faulthandler_dump_c_stack reentrant -
Modules/grpmodule.c - group_db_mutex -
Python/pylifecycle.c _Py_FatalErrorFormat reentrant - Python/pylifecycle.c _Py_FatalErrorFormat reentrant -
Python/pylifecycle.c fatal_error reentrant - Python/pylifecycle.c fatal_error reentrant -

Can't render this file because it has a wrong number of fields in line 4.