Skip to content

Commit 678f29e

Browse files
committed
Make itertools.chain thread-safe
1 parent 17ac393 commit 678f29e

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

Lib/test/test_free_threading/test_itertools.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
from threading import Thread, Barrier
3-
from itertools import batched, cycle
3+
from itertools import batched, chain, cycle
44
from test.support import threading_helper
55

66

@@ -62,6 +62,35 @@ def work(it):
6262

6363
barrier.reset()
6464

65+
@threading_helper.reap_threads
66+
def test_chain(self):
67+
number_of_threads = 6
68+
number_of_iterations = 20
69+
70+
barrier = Barrier(number_of_threads)
71+
def work(it):
72+
barrier.wait()
73+
while True:
74+
try:
75+
_ = next(it)
76+
except StopIteration:
77+
break
78+
79+
80+
data = [(1, )] * 200
81+
for it in range(number_of_iterations):
82+
chain_iterator = chain(*data)
83+
worker_threads = []
84+
for ii in range(number_of_threads):
85+
worker_threads.append(
86+
Thread(target=work, args=[chain_iterator]))
87+
88+
with threading_helper.start_threads(worker_threads):
89+
pass
90+
91+
barrier.reset()
92+
93+
6594

6695
if __name__ == "__main__":
6796
unittest.main()

Modules/itertoolsmodule.c

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1880,8 +1880,8 @@ chain_traverse(PyObject *op, visitproc visit, void *arg)
18801880
return 0;
18811881
}
18821882

1883-
static PyObject *
1884-
chain_next(PyObject *op)
1883+
static inline PyObject *
1884+
chain_next_lock_held(PyObject *op)
18851885
{
18861886
chainobject *lz = chainobject_CAST(op);
18871887
PyObject *item;
@@ -1919,6 +1919,16 @@ chain_next(PyObject *op)
19191919
return NULL;
19201920
}
19211921

1922+
static PyObject *
1923+
chain_next(PyObject *op)
1924+
{
1925+
PyObject * result;
1926+
Py_BEGIN_CRITICAL_SECTION(op);
1927+
result = chain_next_lock_held(op);
1928+
Py_END_CRITICAL_SECTION()
1929+
return result;
1930+
}
1931+
19221932
PyDoc_STRVAR(chain_doc,
19231933
"chain(*iterables)\n\
19241934
--\n\

0 commit comments

Comments
 (0)