mirror of
				https://github.com/python/cpython.git
				synced 2025-11-03 23:21:29 +00:00 
			
		
		
		
	bpo-30064: Fix asyncio loop.sock_* race condition issue (#20369)
This commit is contained in:
		
							parent
							
								
									526e23f153
								
							
						
					
					
						commit
						210a137396
					
				
					 3 changed files with 157 additions and 16 deletions
				
			
		| 
						 | 
				
			
			@ -266,6 +266,7 @@ def _add_reader(self, fd, callback, *args):
 | 
			
		|||
                                  (handle, writer))
 | 
			
		||||
            if reader is not None:
 | 
			
		||||
                reader.cancel()
 | 
			
		||||
        return handle
 | 
			
		||||
 | 
			
		||||
    def _remove_reader(self, fd):
 | 
			
		||||
        if self.is_closed():
 | 
			
		||||
| 
						 | 
				
			
			@ -302,6 +303,7 @@ def _add_writer(self, fd, callback, *args):
 | 
			
		|||
                                  (reader, handle))
 | 
			
		||||
            if writer is not None:
 | 
			
		||||
                writer.cancel()
 | 
			
		||||
        return handle
 | 
			
		||||
 | 
			
		||||
    def _remove_writer(self, fd):
 | 
			
		||||
        """Remove a writer callback."""
 | 
			
		||||
| 
						 | 
				
			
			@ -329,7 +331,7 @@ def _remove_writer(self, fd):
 | 
			
		|||
    def add_reader(self, fd, callback, *args):
 | 
			
		||||
        """Add a reader callback."""
 | 
			
		||||
        self._ensure_fd_no_transport(fd)
 | 
			
		||||
        return self._add_reader(fd, callback, *args)
 | 
			
		||||
        self._add_reader(fd, callback, *args)
 | 
			
		||||
 | 
			
		||||
    def remove_reader(self, fd):
 | 
			
		||||
        """Remove a reader callback."""
 | 
			
		||||
| 
						 | 
				
			
			@ -339,7 +341,7 @@ def remove_reader(self, fd):
 | 
			
		|||
    def add_writer(self, fd, callback, *args):
 | 
			
		||||
        """Add a writer callback.."""
 | 
			
		||||
        self._ensure_fd_no_transport(fd)
 | 
			
		||||
        return self._add_writer(fd, callback, *args)
 | 
			
		||||
        self._add_writer(fd, callback, *args)
 | 
			
		||||
 | 
			
		||||
    def remove_writer(self, fd):
 | 
			
		||||
        """Remove a writer callback."""
 | 
			
		||||
| 
						 | 
				
			
			@ -362,12 +364,14 @@ async def sock_recv(self, sock, n):
 | 
			
		|||
            pass
 | 
			
		||||
        fut = self.create_future()
 | 
			
		||||
        fd = sock.fileno()
 | 
			
		||||
        self.add_reader(fd, self._sock_recv, fut, sock, n)
 | 
			
		||||
        self._ensure_fd_no_transport(fd)
 | 
			
		||||
        handle = self._add_reader(fd, self._sock_recv, fut, sock, n)
 | 
			
		||||
        fut.add_done_callback(
 | 
			
		||||
            functools.partial(self._sock_read_done, fd))
 | 
			
		||||
            functools.partial(self._sock_read_done, fd, handle=handle))
 | 
			
		||||
        return await fut
 | 
			
		||||
 | 
			
		||||
    def _sock_read_done(self, fd, fut):
 | 
			
		||||
    def _sock_read_done(self, fd, fut, handle=None):
 | 
			
		||||
        if handle is None or not handle.cancelled():
 | 
			
		||||
            self.remove_reader(fd)
 | 
			
		||||
 | 
			
		||||
    def _sock_recv(self, fut, sock, n):
 | 
			
		||||
| 
						 | 
				
			
			@ -401,9 +405,10 @@ async def sock_recv_into(self, sock, buf):
 | 
			
		|||
            pass
 | 
			
		||||
        fut = self.create_future()
 | 
			
		||||
        fd = sock.fileno()
 | 
			
		||||
        self.add_reader(fd, self._sock_recv_into, fut, sock, buf)
 | 
			
		||||
        self._ensure_fd_no_transport(fd)
 | 
			
		||||
        handle = self._add_reader(fd, self._sock_recv_into, fut, sock, buf)
 | 
			
		||||
        fut.add_done_callback(
 | 
			
		||||
            functools.partial(self._sock_read_done, fd))
 | 
			
		||||
            functools.partial(self._sock_read_done, fd, handle=handle))
 | 
			
		||||
        return await fut
 | 
			
		||||
 | 
			
		||||
    def _sock_recv_into(self, fut, sock, buf):
 | 
			
		||||
| 
						 | 
				
			
			@ -446,11 +451,12 @@ async def sock_sendall(self, sock, data):
 | 
			
		|||
 | 
			
		||||
        fut = self.create_future()
 | 
			
		||||
        fd = sock.fileno()
 | 
			
		||||
        fut.add_done_callback(
 | 
			
		||||
            functools.partial(self._sock_write_done, fd))
 | 
			
		||||
        self._ensure_fd_no_transport(fd)
 | 
			
		||||
        # use a trick with a list in closure to store a mutable state
 | 
			
		||||
        self.add_writer(fd, self._sock_sendall, fut, sock,
 | 
			
		||||
        handle = self._add_writer(fd, self._sock_sendall, fut, sock,
 | 
			
		||||
                                  memoryview(data), [n])
 | 
			
		||||
        fut.add_done_callback(
 | 
			
		||||
            functools.partial(self._sock_write_done, fd, handle=handle))
 | 
			
		||||
        return await fut
 | 
			
		||||
 | 
			
		||||
    def _sock_sendall(self, fut, sock, view, pos):
 | 
			
		||||
| 
						 | 
				
			
			@ -502,9 +508,11 @@ def _sock_connect(self, fut, sock, address):
 | 
			
		|||
            # connection runs in background. We have to wait until the socket
 | 
			
		||||
            # becomes writable to be notified when the connection succeed or
 | 
			
		||||
            # fails.
 | 
			
		||||
            self._ensure_fd_no_transport(fd)
 | 
			
		||||
            handle = self._add_writer(
 | 
			
		||||
                fd, self._sock_connect_cb, fut, sock, address)
 | 
			
		||||
            fut.add_done_callback(
 | 
			
		||||
                functools.partial(self._sock_write_done, fd))
 | 
			
		||||
            self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
 | 
			
		||||
                functools.partial(self._sock_write_done, fd, handle=handle))
 | 
			
		||||
        except (SystemExit, KeyboardInterrupt):
 | 
			
		||||
            raise
 | 
			
		||||
        except BaseException as exc:
 | 
			
		||||
| 
						 | 
				
			
			@ -512,7 +520,8 @@ def _sock_connect(self, fut, sock, address):
 | 
			
		|||
        else:
 | 
			
		||||
            fut.set_result(None)
 | 
			
		||||
 | 
			
		||||
    def _sock_write_done(self, fd, fut):
 | 
			
		||||
    def _sock_write_done(self, fd, fut, handle=None):
 | 
			
		||||
        if handle is None or not handle.cancelled():
 | 
			
		||||
            self.remove_writer(fd)
 | 
			
		||||
 | 
			
		||||
    def _sock_connect_cb(self, fut, sock, address):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,4 +1,5 @@
 | 
			
		|||
import socket
 | 
			
		||||
import time
 | 
			
		||||
import asyncio
 | 
			
		||||
import sys
 | 
			
		||||
from asyncio import proactor_events
 | 
			
		||||
| 
						 | 
				
			
			@ -122,6 +123,136 @@ def test_sock_client_ops(self):
 | 
			
		|||
            sock = socket.socket()
 | 
			
		||||
            self._basetest_sock_recv_into(httpd, sock)
 | 
			
		||||
 | 
			
		||||
    async def _basetest_sock_recv_racing(self, httpd, sock):
 | 
			
		||||
        sock.setblocking(False)
 | 
			
		||||
        await self.loop.sock_connect(sock, httpd.address)
 | 
			
		||||
 | 
			
		||||
        task = asyncio.create_task(self.loop.sock_recv(sock, 1024))
 | 
			
		||||
        await asyncio.sleep(0)
 | 
			
		||||
        task.cancel()
 | 
			
		||||
 | 
			
		||||
        asyncio.create_task(
 | 
			
		||||
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
 | 
			
		||||
        data = await self.loop.sock_recv(sock, 1024)
 | 
			
		||||
        # consume data
 | 
			
		||||
        await self.loop.sock_recv(sock, 1024)
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
 | 
			
		||||
 | 
			
		||||
    async def _basetest_sock_recv_into_racing(self, httpd, sock):
 | 
			
		||||
        sock.setblocking(False)
 | 
			
		||||
        await self.loop.sock_connect(sock, httpd.address)
 | 
			
		||||
 | 
			
		||||
        data = bytearray(1024)
 | 
			
		||||
        with memoryview(data) as buf:
 | 
			
		||||
            task = asyncio.create_task(
 | 
			
		||||
                self.loop.sock_recv_into(sock, buf[:1024]))
 | 
			
		||||
            await asyncio.sleep(0)
 | 
			
		||||
            task.cancel()
 | 
			
		||||
 | 
			
		||||
            task = asyncio.create_task(
 | 
			
		||||
                self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
 | 
			
		||||
            nbytes = await self.loop.sock_recv_into(sock, buf[:1024])
 | 
			
		||||
            # consume data
 | 
			
		||||
            await self.loop.sock_recv_into(sock, buf[nbytes:])
 | 
			
		||||
            self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
 | 
			
		||||
 | 
			
		||||
        await task
 | 
			
		||||
 | 
			
		||||
    async def _basetest_sock_send_racing(self, listener, sock):
 | 
			
		||||
        listener.bind(('127.0.0.1', 0))
 | 
			
		||||
        listener.listen(1)
 | 
			
		||||
 | 
			
		||||
        # make connection
 | 
			
		||||
        sock.setblocking(False)
 | 
			
		||||
        task = asyncio.create_task(
 | 
			
		||||
            self.loop.sock_connect(sock, listener.getsockname()))
 | 
			
		||||
        await asyncio.sleep(0)
 | 
			
		||||
        server = listener.accept()[0]
 | 
			
		||||
        server.setblocking(False)
 | 
			
		||||
 | 
			
		||||
        with server:
 | 
			
		||||
            await task
 | 
			
		||||
 | 
			
		||||
            # fill the buffer
 | 
			
		||||
            with self.assertRaises(BlockingIOError):
 | 
			
		||||
                while True:
 | 
			
		||||
                    sock.send(b' ' * 5)
 | 
			
		||||
 | 
			
		||||
            # cancel a blocked sock_sendall
 | 
			
		||||
            task = asyncio.create_task(
 | 
			
		||||
                self.loop.sock_sendall(sock, b'hello'))
 | 
			
		||||
            await asyncio.sleep(0)
 | 
			
		||||
            task.cancel()
 | 
			
		||||
 | 
			
		||||
            # clear the buffer
 | 
			
		||||
            async def recv_until():
 | 
			
		||||
                data = b''
 | 
			
		||||
                while not data:
 | 
			
		||||
                    data = await self.loop.sock_recv(server, 1024)
 | 
			
		||||
                    data = data.strip()
 | 
			
		||||
                return data
 | 
			
		||||
            task = asyncio.create_task(recv_until())
 | 
			
		||||
 | 
			
		||||
            # immediately register another sock_sendall
 | 
			
		||||
            await self.loop.sock_sendall(sock, b'world')
 | 
			
		||||
            data = await task
 | 
			
		||||
            # ProactorEventLoop could deliver hello
 | 
			
		||||
            self.assertTrue(data.endswith(b'world'))
 | 
			
		||||
 | 
			
		||||
    async def _basetest_sock_connect_racing(self, listener, sock):
 | 
			
		||||
        listener.bind(('127.0.0.1', 0))
 | 
			
		||||
        addr = listener.getsockname()
 | 
			
		||||
        sock.setblocking(False)
 | 
			
		||||
 | 
			
		||||
        task = asyncio.create_task(self.loop.sock_connect(sock, addr))
 | 
			
		||||
        await asyncio.sleep(0)
 | 
			
		||||
        task.cancel()
 | 
			
		||||
 | 
			
		||||
        listener.listen(1)
 | 
			
		||||
        i = 0
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                await self.loop.sock_connect(sock, addr)
 | 
			
		||||
                break
 | 
			
		||||
            except ConnectionRefusedError:  # on Linux we need another retry
 | 
			
		||||
                await self.loop.sock_connect(sock, addr)
 | 
			
		||||
                break
 | 
			
		||||
            except OSError as e:  # on Windows we need more retries
 | 
			
		||||
                # A connect request was made on an already connected socket
 | 
			
		||||
                if getattr(e, 'winerror', 0) == 10056:
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
                # https://stackoverflow.com/a/54437602/3316267
 | 
			
		||||
                if getattr(e, 'winerror', 0) != 10022:
 | 
			
		||||
                    raise
 | 
			
		||||
                i += 1
 | 
			
		||||
                if i >= 128:
 | 
			
		||||
                    raise  # too many retries
 | 
			
		||||
                # avoid touching event loop to maintain race condition
 | 
			
		||||
                time.sleep(0.01)
 | 
			
		||||
 | 
			
		||||
    def test_sock_client_racing(self):
 | 
			
		||||
        with test_utils.run_test_server() as httpd:
 | 
			
		||||
            sock = socket.socket()
 | 
			
		||||
            with sock:
 | 
			
		||||
                self.loop.run_until_complete(asyncio.wait_for(
 | 
			
		||||
                    self._basetest_sock_recv_racing(httpd, sock), 10))
 | 
			
		||||
            sock = socket.socket()
 | 
			
		||||
            with sock:
 | 
			
		||||
                self.loop.run_until_complete(asyncio.wait_for(
 | 
			
		||||
                    self._basetest_sock_recv_into_racing(httpd, sock), 10))
 | 
			
		||||
        listener = socket.socket()
 | 
			
		||||
        sock = socket.socket()
 | 
			
		||||
        with listener, sock:
 | 
			
		||||
            self.loop.run_until_complete(asyncio.wait_for(
 | 
			
		||||
                self._basetest_sock_send_racing(listener, sock), 10))
 | 
			
		||||
        listener = socket.socket()
 | 
			
		||||
        sock = socket.socket()
 | 
			
		||||
        with listener, sock:
 | 
			
		||||
            self.loop.run_until_complete(asyncio.wait_for(
 | 
			
		||||
                self._basetest_sock_connect_racing(listener, sock), 10))
 | 
			
		||||
 | 
			
		||||
    async def _basetest_huge_content(self, address):
 | 
			
		||||
        sock = socket.socket()
 | 
			
		||||
        sock.setblocking(False)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
Fix asyncio ``loop.sock_*`` race condition issue
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue