Mirror of https://github.com/python/cpython.git (synced 2026-01-04 06:22:20 +00:00)

gh-123471: make itertools.batched thread-safe (#129416)

Parent: 155c44b901
Commit: 405a2d74cb

3 changed files with 51 additions and 2 deletions
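This commit makes it safe for multiple threads to call next() on a single itertools.batched iterator under the free-threaded build. For context, a quick sketch of the documented batched() behavior, including the strict mode whose error path the C changes below touch (standard library semantics, not part of this diff):

    from itertools import batched

    # batched() yields successive n-length tuples; the last batch may be short.
    assert list(batched("ABCDEFG", 3)) == [("A", "B", "C"), ("D", "E", "F"), ("G",)]

    # With strict=True (added in 3.13), a short final batch raises ValueError
    # instead; the message matches the one set in the C code below.
    try:
        list(batched("ABCDEFG", 3, strict=True))
    except ValueError as e:
        assert str(e) == "batched(): incomplete batch"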
Lib/test/test_free_threading/test_itertools_batched.py (new file, 39 lines)

@@ -0,0 +1,39 @@
+import unittest
+import sys
+from threading import Thread, Barrier
+from itertools import batched
+from test.support import threading_helper
+
+
+threading_helper.requires_working_threading(module=True)
+
+class EnumerateThreading(unittest.TestCase):
+
+    @threading_helper.reap_threads
+    def test_threading(self):
+        number_of_threads = 10
+        number_of_iterations = 20
+        barrier = Barrier(number_of_threads)
+        def work(it):
+            barrier.wait()
+            while True:
+                try:
+                    _ = next(it)
+                except StopIteration:
+                    break
+
+        data = tuple(range(1000))
+        for it in range(number_of_iterations):
+            batch_iterator = batched(data, 2)
+            worker_threads = []
+            for ii in range(number_of_threads):
+                worker_threads.append(
+                    Thread(target=work, args=[batch_iterator]))
+
+            with threading_helper.start_threads(worker_threads):
+                pass
+
+            barrier.reset()
+
+if __name__ == "__main__":
+    unittest.main()
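The test above only asserts that concurrent next() calls on a shared iterator finish without crashing or hanging. A sketch (not part of the commit) of the stronger property the fix is meant to provide: threads consuming one batched() iterator partition the input, so every element is delivered exactly once:

    # Sketch: verify exactly-once delivery across threads. Meaningful mainly
    # on a free-threaded build, but it must also pass on GIL builds.
    from itertools import batched
    from threading import Barrier, Thread

    data = tuple(range(1000))
    it = batched(data, 2)
    barrier = Barrier(10)
    results = [[] for _ in range(10)]

    def work(out):
        barrier.wait()
        for batch in it:          # repeated next() on the shared iterator
            out.extend(batch)

    threads = [Thread(target=work, args=(results[i],)) for i in range(10)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    seen = sorted(x for chunk in results for x in chunk)
    assert seen == sorted(data)   # no element lost or duplicated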
Misc/NEWS.d entry (new file, 1 line)

@@ -0,0 +1 @@
+Make concurrent iterations over :class:`itertools.batched` safe under free-threading.
Modules/itertoolsmodule.c

@@ -191,12 +191,12 @@ batched_next(PyObject *op)
 {
     batchedobject *bo = batchedobject_CAST(op);
     Py_ssize_t i;
-    Py_ssize_t n = bo->batch_size;
+    Py_ssize_t n = FT_ATOMIC_LOAD_SSIZE_RELAXED(bo->batch_size);
     PyObject *it = bo->it;
     PyObject *item;
     PyObject *result;

-    if (it == NULL) {
+    if (n < 0) {
         return NULL;
     }
     result = PyTuple_New(n);
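The first hunk replaces the old exhaustion test. Previously bo->it was cleared once the source was exhausted, so `it == NULL` meant "done"; clearing a pointer that other threads may be reading concurrently is unsafe without the GIL, so exhaustion is now recorded by atomically storing -1 into batch_size and checking the relaxed-loaded `n < 0` instead. The observable Python behavior is unchanged (a sketch of standard iterator semantics):

    from itertools import batched

    it = batched([1, 2, 3], 2)
    assert next(it) == (1, 2)
    assert next(it) == (3,)
    # Exhausted: every further next() raises StopIteration again. In the
    # C code this is the new n < 0 fast path (batch_size was atomically
    # set to -1 when the source ran dry).
    try:
        next(it)
    except StopIteration:
        pass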
@@ -218,19 +218,28 @@ batched_next(PyObject *op)
         if (PyErr_Occurred()) {
             if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
                 /* Input raised an exception other than StopIteration */
+                FT_ATOMIC_STORE_SSIZE_RELAXED(bo->batch_size, -1);
+#ifndef Py_GIL_DISABLED
                 Py_CLEAR(bo->it);
+#endif
                 Py_DECREF(result);
                 return NULL;
             }
             PyErr_Clear();
         }
         if (i == 0) {
+            FT_ATOMIC_STORE_SSIZE_RELAXED(bo->batch_size, -1);
+#ifndef Py_GIL_DISABLED
             Py_CLEAR(bo->it);
+#endif
             Py_DECREF(result);
             return NULL;
         }
         if (bo->strict) {
+            FT_ATOMIC_STORE_SSIZE_RELAXED(bo->batch_size, -1);
+#ifndef Py_GIL_DISABLED
             Py_CLEAR(bo->it);
+#endif
             Py_DECREF(result);
             PyErr_SetString(PyExc_ValueError, "batched(): incomplete batch");
             return NULL;
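In the second hunk every terminal path first publishes the -1 sentinel with a relaxed atomic store, and Py_CLEAR(bo->it) is compiled only on GIL builds (#ifndef Py_GIL_DISABLED): dropping the reference while another thread may still be inside a call on the same iterator pointer would be unsafe, at the cost of the free-threaded build keeping the source iterator alive until the batched object itself is freed. A hypothetical pure-Python analogue of the same "flag, don't clear" idea, with a Lock standing in for the C code's atomics (a sketch, not the CPython implementation):

    from itertools import islice
    from threading import Lock

    class LockedBatched:
        def __init__(self, iterable, n, *, strict=False):
            if n < 1:
                raise ValueError("n must be at least one")
            self._it = iter(iterable)   # never dropped mid-iteration
            self._n = n
            self._strict = strict
            self._lock = Lock()
            self._done = False          # plays the role of batch_size == -1

        def __iter__(self):
            return self

        def __next__(self):
            with self._lock:
                if self._done:                       # the n < 0 fast path
                    raise StopIteration
                batch = tuple(islice(self._it, self._n))
                if not batch:
                    self._done = True                # flag instead of clearing
                    raise StopIteration
                if self._strict and len(batch) != self._n:
                    self._done = True
                    raise ValueError("batched(): incomplete batch")
                return batch

The design trade-off is the same as in the C code: exhaustion state lives in a dedicated field that can be updated safely, while the shared source reference stays immutable for the iterator's lifetime.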