mirror of
				https://github.com/python/cpython.git
				synced 2025-11-03 23:21:29 +00:00 
			
		
		
		
	bpo-30064: Fix asyncio loop.sock_* race condition issue (#20369)
This commit is contained in:
		
							parent
							
								
									526e23f153
								
							
						
					
					
						commit
						210a137396
					
				
					 3 changed files with 157 additions and 16 deletions
				
			
		| 
						 | 
					@ -266,6 +266,7 @@ def _add_reader(self, fd, callback, *args):
 | 
				
			||||||
                                  (handle, writer))
 | 
					                                  (handle, writer))
 | 
				
			||||||
            if reader is not None:
 | 
					            if reader is not None:
 | 
				
			||||||
                reader.cancel()
 | 
					                reader.cancel()
 | 
				
			||||||
 | 
					        return handle
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _remove_reader(self, fd):
 | 
					    def _remove_reader(self, fd):
 | 
				
			||||||
        if self.is_closed():
 | 
					        if self.is_closed():
 | 
				
			||||||
| 
						 | 
					@ -302,6 +303,7 @@ def _add_writer(self, fd, callback, *args):
 | 
				
			||||||
                                  (reader, handle))
 | 
					                                  (reader, handle))
 | 
				
			||||||
            if writer is not None:
 | 
					            if writer is not None:
 | 
				
			||||||
                writer.cancel()
 | 
					                writer.cancel()
 | 
				
			||||||
 | 
					        return handle
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _remove_writer(self, fd):
 | 
					    def _remove_writer(self, fd):
 | 
				
			||||||
        """Remove a writer callback."""
 | 
					        """Remove a writer callback."""
 | 
				
			||||||
| 
						 | 
					@ -329,7 +331,7 @@ def _remove_writer(self, fd):
 | 
				
			||||||
    def add_reader(self, fd, callback, *args):
 | 
					    def add_reader(self, fd, callback, *args):
 | 
				
			||||||
        """Add a reader callback."""
 | 
					        """Add a reader callback."""
 | 
				
			||||||
        self._ensure_fd_no_transport(fd)
 | 
					        self._ensure_fd_no_transport(fd)
 | 
				
			||||||
        return self._add_reader(fd, callback, *args)
 | 
					        self._add_reader(fd, callback, *args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def remove_reader(self, fd):
 | 
					    def remove_reader(self, fd):
 | 
				
			||||||
        """Remove a reader callback."""
 | 
					        """Remove a reader callback."""
 | 
				
			||||||
| 
						 | 
					@ -339,7 +341,7 @@ def remove_reader(self, fd):
 | 
				
			||||||
    def add_writer(self, fd, callback, *args):
 | 
					    def add_writer(self, fd, callback, *args):
 | 
				
			||||||
        """Add a writer callback.."""
 | 
					        """Add a writer callback.."""
 | 
				
			||||||
        self._ensure_fd_no_transport(fd)
 | 
					        self._ensure_fd_no_transport(fd)
 | 
				
			||||||
        return self._add_writer(fd, callback, *args)
 | 
					        self._add_writer(fd, callback, *args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def remove_writer(self, fd):
 | 
					    def remove_writer(self, fd):
 | 
				
			||||||
        """Remove a writer callback."""
 | 
					        """Remove a writer callback."""
 | 
				
			||||||
| 
						 | 
					@ -362,13 +364,15 @@ async def sock_recv(self, sock, n):
 | 
				
			||||||
            pass
 | 
					            pass
 | 
				
			||||||
        fut = self.create_future()
 | 
					        fut = self.create_future()
 | 
				
			||||||
        fd = sock.fileno()
 | 
					        fd = sock.fileno()
 | 
				
			||||||
        self.add_reader(fd, self._sock_recv, fut, sock, n)
 | 
					        self._ensure_fd_no_transport(fd)
 | 
				
			||||||
 | 
					        handle = self._add_reader(fd, self._sock_recv, fut, sock, n)
 | 
				
			||||||
        fut.add_done_callback(
 | 
					        fut.add_done_callback(
 | 
				
			||||||
            functools.partial(self._sock_read_done, fd))
 | 
					            functools.partial(self._sock_read_done, fd, handle=handle))
 | 
				
			||||||
        return await fut
 | 
					        return await fut
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _sock_read_done(self, fd, fut):
 | 
					    def _sock_read_done(self, fd, fut, handle=None):
 | 
				
			||||||
        self.remove_reader(fd)
 | 
					        if handle is None or not handle.cancelled():
 | 
				
			||||||
 | 
					            self.remove_reader(fd)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _sock_recv(self, fut, sock, n):
 | 
					    def _sock_recv(self, fut, sock, n):
 | 
				
			||||||
        # _sock_recv() can add itself as an I/O callback if the operation can't
 | 
					        # _sock_recv() can add itself as an I/O callback if the operation can't
 | 
				
			||||||
| 
						 | 
					@ -401,9 +405,10 @@ async def sock_recv_into(self, sock, buf):
 | 
				
			||||||
            pass
 | 
					            pass
 | 
				
			||||||
        fut = self.create_future()
 | 
					        fut = self.create_future()
 | 
				
			||||||
        fd = sock.fileno()
 | 
					        fd = sock.fileno()
 | 
				
			||||||
        self.add_reader(fd, self._sock_recv_into, fut, sock, buf)
 | 
					        self._ensure_fd_no_transport(fd)
 | 
				
			||||||
 | 
					        handle = self._add_reader(fd, self._sock_recv_into, fut, sock, buf)
 | 
				
			||||||
        fut.add_done_callback(
 | 
					        fut.add_done_callback(
 | 
				
			||||||
            functools.partial(self._sock_read_done, fd))
 | 
					            functools.partial(self._sock_read_done, fd, handle=handle))
 | 
				
			||||||
        return await fut
 | 
					        return await fut
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _sock_recv_into(self, fut, sock, buf):
 | 
					    def _sock_recv_into(self, fut, sock, buf):
 | 
				
			||||||
| 
						 | 
					@ -446,11 +451,12 @@ async def sock_sendall(self, sock, data):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        fut = self.create_future()
 | 
					        fut = self.create_future()
 | 
				
			||||||
        fd = sock.fileno()
 | 
					        fd = sock.fileno()
 | 
				
			||||||
        fut.add_done_callback(
 | 
					        self._ensure_fd_no_transport(fd)
 | 
				
			||||||
            functools.partial(self._sock_write_done, fd))
 | 
					 | 
				
			||||||
        # use a trick with a list in closure to store a mutable state
 | 
					        # use a trick with a list in closure to store a mutable state
 | 
				
			||||||
        self.add_writer(fd, self._sock_sendall, fut, sock,
 | 
					        handle = self._add_writer(fd, self._sock_sendall, fut, sock,
 | 
				
			||||||
                        memoryview(data), [n])
 | 
					                                  memoryview(data), [n])
 | 
				
			||||||
 | 
					        fut.add_done_callback(
 | 
				
			||||||
 | 
					            functools.partial(self._sock_write_done, fd, handle=handle))
 | 
				
			||||||
        return await fut
 | 
					        return await fut
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _sock_sendall(self, fut, sock, view, pos):
 | 
					    def _sock_sendall(self, fut, sock, view, pos):
 | 
				
			||||||
| 
						 | 
					@ -502,9 +508,11 @@ def _sock_connect(self, fut, sock, address):
 | 
				
			||||||
            # connection runs in background. We have to wait until the socket
 | 
					            # connection runs in background. We have to wait until the socket
 | 
				
			||||||
            # becomes writable to be notified when the connection succeed or
 | 
					            # becomes writable to be notified when the connection succeed or
 | 
				
			||||||
            # fails.
 | 
					            # fails.
 | 
				
			||||||
 | 
					            self._ensure_fd_no_transport(fd)
 | 
				
			||||||
 | 
					            handle = self._add_writer(
 | 
				
			||||||
 | 
					                fd, self._sock_connect_cb, fut, sock, address)
 | 
				
			||||||
            fut.add_done_callback(
 | 
					            fut.add_done_callback(
 | 
				
			||||||
                functools.partial(self._sock_write_done, fd))
 | 
					                functools.partial(self._sock_write_done, fd, handle=handle))
 | 
				
			||||||
            self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
 | 
					 | 
				
			||||||
        except (SystemExit, KeyboardInterrupt):
 | 
					        except (SystemExit, KeyboardInterrupt):
 | 
				
			||||||
            raise
 | 
					            raise
 | 
				
			||||||
        except BaseException as exc:
 | 
					        except BaseException as exc:
 | 
				
			||||||
| 
						 | 
					@ -512,8 +520,9 @@ def _sock_connect(self, fut, sock, address):
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            fut.set_result(None)
 | 
					            fut.set_result(None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _sock_write_done(self, fd, fut):
 | 
					    def _sock_write_done(self, fd, fut, handle=None):
 | 
				
			||||||
        self.remove_writer(fd)
 | 
					        if handle is None or not handle.cancelled():
 | 
				
			||||||
 | 
					            self.remove_writer(fd)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _sock_connect_cb(self, fut, sock, address):
 | 
					    def _sock_connect_cb(self, fut, sock, address):
 | 
				
			||||||
        if fut.done():
 | 
					        if fut.done():
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,5 @@
 | 
				
			||||||
import socket
 | 
					import socket
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
from asyncio import proactor_events
 | 
					from asyncio import proactor_events
 | 
				
			||||||
| 
						 | 
					@ -122,6 +123,136 @@ def test_sock_client_ops(self):
 | 
				
			||||||
            sock = socket.socket()
 | 
					            sock = socket.socket()
 | 
				
			||||||
            self._basetest_sock_recv_into(httpd, sock)
 | 
					            self._basetest_sock_recv_into(httpd, sock)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _basetest_sock_recv_racing(self, httpd, sock):
 | 
				
			||||||
 | 
					        sock.setblocking(False)
 | 
				
			||||||
 | 
					        await self.loop.sock_connect(sock, httpd.address)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        task = asyncio.create_task(self.loop.sock_recv(sock, 1024))
 | 
				
			||||||
 | 
					        await asyncio.sleep(0)
 | 
				
			||||||
 | 
					        task.cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        asyncio.create_task(
 | 
				
			||||||
 | 
					            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
 | 
				
			||||||
 | 
					        data = await self.loop.sock_recv(sock, 1024)
 | 
				
			||||||
 | 
					        # consume data
 | 
				
			||||||
 | 
					        await self.loop.sock_recv(sock, 1024)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _basetest_sock_recv_into_racing(self, httpd, sock):
 | 
				
			||||||
 | 
					        sock.setblocking(False)
 | 
				
			||||||
 | 
					        await self.loop.sock_connect(sock, httpd.address)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        data = bytearray(1024)
 | 
				
			||||||
 | 
					        with memoryview(data) as buf:
 | 
				
			||||||
 | 
					            task = asyncio.create_task(
 | 
				
			||||||
 | 
					                self.loop.sock_recv_into(sock, buf[:1024]))
 | 
				
			||||||
 | 
					            await asyncio.sleep(0)
 | 
				
			||||||
 | 
					            task.cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            task = asyncio.create_task(
 | 
				
			||||||
 | 
					                self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
 | 
				
			||||||
 | 
					            nbytes = await self.loop.sock_recv_into(sock, buf[:1024])
 | 
				
			||||||
 | 
					            # consume data
 | 
				
			||||||
 | 
					            await self.loop.sock_recv_into(sock, buf[nbytes:])
 | 
				
			||||||
 | 
					            self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        await task
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _basetest_sock_send_racing(self, listener, sock):
 | 
				
			||||||
 | 
					        listener.bind(('127.0.0.1', 0))
 | 
				
			||||||
 | 
					        listener.listen(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # make connection
 | 
				
			||||||
 | 
					        sock.setblocking(False)
 | 
				
			||||||
 | 
					        task = asyncio.create_task(
 | 
				
			||||||
 | 
					            self.loop.sock_connect(sock, listener.getsockname()))
 | 
				
			||||||
 | 
					        await asyncio.sleep(0)
 | 
				
			||||||
 | 
					        server = listener.accept()[0]
 | 
				
			||||||
 | 
					        server.setblocking(False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with server:
 | 
				
			||||||
 | 
					            await task
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # fill the buffer
 | 
				
			||||||
 | 
					            with self.assertRaises(BlockingIOError):
 | 
				
			||||||
 | 
					                while True:
 | 
				
			||||||
 | 
					                    sock.send(b' ' * 5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # cancel a blocked sock_sendall
 | 
				
			||||||
 | 
					            task = asyncio.create_task(
 | 
				
			||||||
 | 
					                self.loop.sock_sendall(sock, b'hello'))
 | 
				
			||||||
 | 
					            await asyncio.sleep(0)
 | 
				
			||||||
 | 
					            task.cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # clear the buffer
 | 
				
			||||||
 | 
					            async def recv_until():
 | 
				
			||||||
 | 
					                data = b''
 | 
				
			||||||
 | 
					                while not data:
 | 
				
			||||||
 | 
					                    data = await self.loop.sock_recv(server, 1024)
 | 
				
			||||||
 | 
					                    data = data.strip()
 | 
				
			||||||
 | 
					                return data
 | 
				
			||||||
 | 
					            task = asyncio.create_task(recv_until())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # immediately register another sock_sendall
 | 
				
			||||||
 | 
					            await self.loop.sock_sendall(sock, b'world')
 | 
				
			||||||
 | 
					            data = await task
 | 
				
			||||||
 | 
					            # ProactorEventLoop could deliver hello
 | 
				
			||||||
 | 
					            self.assertTrue(data.endswith(b'world'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _basetest_sock_connect_racing(self, listener, sock):
 | 
				
			||||||
 | 
					        listener.bind(('127.0.0.1', 0))
 | 
				
			||||||
 | 
					        addr = listener.getsockname()
 | 
				
			||||||
 | 
					        sock.setblocking(False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        task = asyncio.create_task(self.loop.sock_connect(sock, addr))
 | 
				
			||||||
 | 
					        await asyncio.sleep(0)
 | 
				
			||||||
 | 
					        task.cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        listener.listen(1)
 | 
				
			||||||
 | 
					        i = 0
 | 
				
			||||||
 | 
					        while True:
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                await self.loop.sock_connect(sock, addr)
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					            except ConnectionRefusedError:  # on Linux we need another retry
 | 
				
			||||||
 | 
					                await self.loop.sock_connect(sock, addr)
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					            except OSError as e:  # on Windows we need more retries
 | 
				
			||||||
 | 
					                # A connect request was made on an already connected socket
 | 
				
			||||||
 | 
					                if getattr(e, 'winerror', 0) == 10056:
 | 
				
			||||||
 | 
					                    break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # https://stackoverflow.com/a/54437602/3316267
 | 
				
			||||||
 | 
					                if getattr(e, 'winerror', 0) != 10022:
 | 
				
			||||||
 | 
					                    raise
 | 
				
			||||||
 | 
					                i += 1
 | 
				
			||||||
 | 
					                if i >= 128:
 | 
				
			||||||
 | 
					                    raise  # too many retries
 | 
				
			||||||
 | 
					                # avoid touching event loop to maintain race condition
 | 
				
			||||||
 | 
					                time.sleep(0.01)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_sock_client_racing(self):
 | 
				
			||||||
 | 
					        with test_utils.run_test_server() as httpd:
 | 
				
			||||||
 | 
					            sock = socket.socket()
 | 
				
			||||||
 | 
					            with sock:
 | 
				
			||||||
 | 
					                self.loop.run_until_complete(asyncio.wait_for(
 | 
				
			||||||
 | 
					                    self._basetest_sock_recv_racing(httpd, sock), 10))
 | 
				
			||||||
 | 
					            sock = socket.socket()
 | 
				
			||||||
 | 
					            with sock:
 | 
				
			||||||
 | 
					                self.loop.run_until_complete(asyncio.wait_for(
 | 
				
			||||||
 | 
					                    self._basetest_sock_recv_into_racing(httpd, sock), 10))
 | 
				
			||||||
 | 
					        listener = socket.socket()
 | 
				
			||||||
 | 
					        sock = socket.socket()
 | 
				
			||||||
 | 
					        with listener, sock:
 | 
				
			||||||
 | 
					            self.loop.run_until_complete(asyncio.wait_for(
 | 
				
			||||||
 | 
					                self._basetest_sock_send_racing(listener, sock), 10))
 | 
				
			||||||
 | 
					        listener = socket.socket()
 | 
				
			||||||
 | 
					        sock = socket.socket()
 | 
				
			||||||
 | 
					        with listener, sock:
 | 
				
			||||||
 | 
					            self.loop.run_until_complete(asyncio.wait_for(
 | 
				
			||||||
 | 
					                self._basetest_sock_connect_racing(listener, sock), 10))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def _basetest_huge_content(self, address):
 | 
					    async def _basetest_huge_content(self, address):
 | 
				
			||||||
        sock = socket.socket()
 | 
					        sock = socket.socket()
 | 
				
			||||||
        sock.setblocking(False)
 | 
					        sock.setblocking(False)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1 @@
 | 
				
			||||||
 | 
					Fix asyncio ``loop.sock_*`` race condition issue
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue