mirror of
				https://github.com/python/cpython.git
				synced 2025-11-03 23:21:29 +00:00 
			
		
		
		
	asyncio, Tulip issue 205: Fix a race condition in BaseSelectorEventLoop.sock_connect()
There is a race condition in create_connection() used with wait_for() to have a timeout. sock_connect() registers the file descriptor of the socket to be notified of write event (if connect() raises BlockingIOError). When create_connection() is cancelled with a TimeoutError, sock_connect() coroutine gets the exception, but it doesn't unregister the file descriptor for write event. create_connection() gets the TimeoutError and closes the socket. If you call again create_connection(), the new socket will likely gets the same file descriptor, which is still registered in the selector. When sock_connect() calls add_writer(), it tries to modify the entry instead of creating a new one. This issue was originally reported in the Trollius project, but the bug comes from Tulip in fact (Trollius is based on Tulip): https://bitbucket.org/enovance/trollius/issue/15/after-timeouterror-on-wait_for This change fixes the race condition. It also makes sock_connect() more reliable (and portable) is sock.connect() raises an InterruptedError.
This commit is contained in:
		
							parent
							
								
									41f3c3f226
								
							
						
					
					
						commit
						d5aeccf976
					
				
					 2 changed files with 83 additions and 35 deletions
				
			
		| 
						 | 
					@ -8,6 +8,7 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import collections
 | 
					import collections
 | 
				
			||||||
import errno
 | 
					import errno
 | 
				
			||||||
 | 
					import functools
 | 
				
			||||||
import socket
 | 
					import socket
 | 
				
			||||||
try:
 | 
					try:
 | 
				
			||||||
    import ssl
 | 
					    import ssl
 | 
				
			||||||
| 
						 | 
					@ -345,26 +346,43 @@ def sock_connect(self, sock, address):
 | 
				
			||||||
        except ValueError as err:
 | 
					        except ValueError as err:
 | 
				
			||||||
            fut.set_exception(err)
 | 
					            fut.set_exception(err)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            self._sock_connect(fut, False, sock, address)
 | 
					            self._sock_connect(fut, sock, address)
 | 
				
			||||||
        return fut
 | 
					        return fut
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _sock_connect(self, fut, registered, sock, address):
 | 
					    def _sock_connect(self, fut, sock, address):
 | 
				
			||||||
        fd = sock.fileno()
 | 
					        fd = sock.fileno()
 | 
				
			||||||
        if registered:
 | 
					        try:
 | 
				
			||||||
            self.remove_writer(fd)
 | 
					            while True:
 | 
				
			||||||
 | 
					                try:
 | 
				
			||||||
 | 
					                    sock.connect(address)
 | 
				
			||||||
 | 
					                except InterruptedError:
 | 
				
			||||||
 | 
					                    continue
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    break
 | 
				
			||||||
 | 
					        except BlockingIOError:
 | 
				
			||||||
 | 
					            fut.add_done_callback(functools.partial(self._sock_connect_done,
 | 
				
			||||||
 | 
					                                                    sock))
 | 
				
			||||||
 | 
					            self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
 | 
				
			||||||
 | 
					        except Exception as exc:
 | 
				
			||||||
 | 
					            fut.set_exception(exc)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            fut.set_result(None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _sock_connect_done(self, sock, fut):
 | 
				
			||||||
 | 
					        self.remove_writer(sock.fileno())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _sock_connect_cb(self, fut, sock, address):
 | 
				
			||||||
        if fut.cancelled():
 | 
					        if fut.cancelled():
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            if not registered:
 | 
					            err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
 | 
				
			||||||
                # First time around.
 | 
					            if err != 0:
 | 
				
			||||||
                sock.connect(address)
 | 
					                # Jump to any except clause below.
 | 
				
			||||||
            else:
 | 
					                raise OSError(err, 'Connect call failed %s' % (address,))
 | 
				
			||||||
                err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
 | 
					 | 
				
			||||||
                if err != 0:
 | 
					 | 
				
			||||||
                    # Jump to the except clause below.
 | 
					 | 
				
			||||||
                    raise OSError(err, 'Connect call failed %s' % (address,))
 | 
					 | 
				
			||||||
        except (BlockingIOError, InterruptedError):
 | 
					        except (BlockingIOError, InterruptedError):
 | 
				
			||||||
            self.add_writer(fd, self._sock_connect, fut, True, sock, address)
 | 
					            # socket is still registered, the callback will be retried later
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
        except Exception as exc:
 | 
					        except Exception as exc:
 | 
				
			||||||
            fut.set_exception(exc)
 | 
					            fut.set_exception(exc)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -40,8 +40,9 @@ def list_to_buffer(l=()):
 | 
				
			||||||
class BaseSelectorEventLoopTests(test_utils.TestCase):
 | 
					class BaseSelectorEventLoopTests(test_utils.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def setUp(self):
 | 
					    def setUp(self):
 | 
				
			||||||
        selector = mock.Mock()
 | 
					        self.selector = mock.Mock()
 | 
				
			||||||
        self.loop = TestBaseSelectorEventLoop(selector)
 | 
					        self.selector.select.return_value = []
 | 
				
			||||||
 | 
					        self.loop = TestBaseSelectorEventLoop(self.selector)
 | 
				
			||||||
        self.set_event_loop(self.loop, cleanup=False)
 | 
					        self.set_event_loop(self.loop, cleanup=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_make_socket_transport(self):
 | 
					    def test_make_socket_transport(self):
 | 
				
			||||||
| 
						 | 
					@ -303,63 +304,92 @@ def test_sock_connect(self):
 | 
				
			||||||
        f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
 | 
					        f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
 | 
				
			||||||
        self.assertIsInstance(f, asyncio.Future)
 | 
					        self.assertIsInstance(f, asyncio.Future)
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            (f, False, sock, ('127.0.0.1', 8080)),
 | 
					            (f, sock, ('127.0.0.1', 8080)),
 | 
				
			||||||
            self.loop._sock_connect.call_args[0])
 | 
					            self.loop._sock_connect.call_args[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_sock_connect_timeout(self):
 | 
				
			||||||
 | 
					        # Tulip issue #205: sock_connect() must unregister the socket on
 | 
				
			||||||
 | 
					        # timeout error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # prepare mocks
 | 
				
			||||||
 | 
					        self.loop.add_writer = mock.Mock()
 | 
				
			||||||
 | 
					        self.loop.remove_writer = mock.Mock()
 | 
				
			||||||
 | 
					        sock = test_utils.mock_nonblocking_socket()
 | 
				
			||||||
 | 
					        sock.connect.side_effect = BlockingIOError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # first call to sock_connect() registers the socket
 | 
				
			||||||
 | 
					        fut = self.loop.sock_connect(sock, ('127.0.0.1', 80))
 | 
				
			||||||
 | 
					        self.assertTrue(sock.connect.called)
 | 
				
			||||||
 | 
					        self.assertTrue(self.loop.add_writer.called)
 | 
				
			||||||
 | 
					        self.assertEqual(len(fut._callbacks), 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # on timeout, the socket must be unregistered
 | 
				
			||||||
 | 
					        sock.connect.reset_mock()
 | 
				
			||||||
 | 
					        fut.set_exception(asyncio.TimeoutError)
 | 
				
			||||||
 | 
					        with self.assertRaises(asyncio.TimeoutError):
 | 
				
			||||||
 | 
					            self.loop.run_until_complete(fut)
 | 
				
			||||||
 | 
					        self.assertTrue(self.loop.remove_writer.called)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test__sock_connect(self):
 | 
					    def test__sock_connect(self):
 | 
				
			||||||
        f = asyncio.Future(loop=self.loop)
 | 
					        f = asyncio.Future(loop=self.loop)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        sock = mock.Mock()
 | 
					        sock = mock.Mock()
 | 
				
			||||||
        sock.fileno.return_value = 10
 | 
					        sock.fileno.return_value = 10
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080))
 | 
					        self.loop._sock_connect(f, sock, ('127.0.0.1', 8080))
 | 
				
			||||||
        self.assertTrue(f.done())
 | 
					        self.assertTrue(f.done())
 | 
				
			||||||
        self.assertIsNone(f.result())
 | 
					        self.assertIsNone(f.result())
 | 
				
			||||||
        self.assertTrue(sock.connect.called)
 | 
					        self.assertTrue(sock.connect.called)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test__sock_connect_canceled_fut(self):
 | 
					    def test__sock_connect_cb_cancelled_fut(self):
 | 
				
			||||||
        sock = mock.Mock()
 | 
					        sock = mock.Mock()
 | 
				
			||||||
 | 
					        self.loop.remove_writer = mock.Mock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        f = asyncio.Future(loop=self.loop)
 | 
					        f = asyncio.Future(loop=self.loop)
 | 
				
			||||||
        f.cancel()
 | 
					        f.cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080))
 | 
					        self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
 | 
				
			||||||
        self.assertFalse(sock.connect.called)
 | 
					        self.assertFalse(sock.getsockopt.called)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test__sock_connect_writer(self):
 | 
				
			||||||
 | 
					        # check that the fd is registered and then unregistered
 | 
				
			||||||
 | 
					        self.loop._process_events = mock.Mock()
 | 
				
			||||||
 | 
					        self.loop.add_writer = mock.Mock()
 | 
				
			||||||
 | 
					        self.loop.remove_writer = mock.Mock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test__sock_connect_unregister(self):
 | 
					 | 
				
			||||||
        sock = mock.Mock()
 | 
					        sock = mock.Mock()
 | 
				
			||||||
        sock.fileno.return_value = 10
 | 
					        sock.fileno.return_value = 10
 | 
				
			||||||
 | 
					        sock.connect.side_effect = BlockingIOError
 | 
				
			||||||
 | 
					        sock.getsockopt.return_value = 0
 | 
				
			||||||
 | 
					        address = ('127.0.0.1', 8080)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        f = asyncio.Future(loop=self.loop)
 | 
					        f = asyncio.Future(loop=self.loop)
 | 
				
			||||||
        f.cancel()
 | 
					        self.loop._sock_connect(f, sock, address)
 | 
				
			||||||
 | 
					        self.assertTrue(self.loop.add_writer.called)
 | 
				
			||||||
 | 
					        self.assertEqual(10, self.loop.add_writer.call_args[0][0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.loop.remove_writer = mock.Mock()
 | 
					        self.loop._sock_connect_cb(f, sock, address)
 | 
				
			||||||
        self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
 | 
					        # need to run the event loop to execute _sock_connect_done() callback
 | 
				
			||||||
 | 
					        self.loop.run_until_complete(f)
 | 
				
			||||||
        self.assertEqual((10,), self.loop.remove_writer.call_args[0])
 | 
					        self.assertEqual((10,), self.loop.remove_writer.call_args[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test__sock_connect_tryagain(self):
 | 
					    def test__sock_connect_cb_tryagain(self):
 | 
				
			||||||
        f = asyncio.Future(loop=self.loop)
 | 
					        f = asyncio.Future(loop=self.loop)
 | 
				
			||||||
        sock = mock.Mock()
 | 
					        sock = mock.Mock()
 | 
				
			||||||
        sock.fileno.return_value = 10
 | 
					        sock.fileno.return_value = 10
 | 
				
			||||||
        sock.getsockopt.return_value = errno.EAGAIN
 | 
					        sock.getsockopt.return_value = errno.EAGAIN
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.loop.add_writer = mock.Mock()
 | 
					        # check that the exception is handled
 | 
				
			||||||
        self.loop.remove_writer = mock.Mock()
 | 
					        self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
 | 
					    def test__sock_connect_cb_exception(self):
 | 
				
			||||||
        self.assertEqual(
 | 
					 | 
				
			||||||
            (10, self.loop._sock_connect, f,
 | 
					 | 
				
			||||||
             True, sock, ('127.0.0.1', 8080)),
 | 
					 | 
				
			||||||
            self.loop.add_writer.call_args[0])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test__sock_connect_exception(self):
 | 
					 | 
				
			||||||
        f = asyncio.Future(loop=self.loop)
 | 
					        f = asyncio.Future(loop=self.loop)
 | 
				
			||||||
        sock = mock.Mock()
 | 
					        sock = mock.Mock()
 | 
				
			||||||
        sock.fileno.return_value = 10
 | 
					        sock.fileno.return_value = 10
 | 
				
			||||||
        sock.getsockopt.return_value = errno.ENOTCONN
 | 
					        sock.getsockopt.return_value = errno.ENOTCONN
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.loop.remove_writer = mock.Mock()
 | 
					        self.loop.remove_writer = mock.Mock()
 | 
				
			||||||
        self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
 | 
					        self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
 | 
				
			||||||
        self.assertIsInstance(f.exception(), OSError)
 | 
					        self.assertIsInstance(f.exception(), OSError)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_sock_accept(self):
 | 
					    def test_sock_accept(self):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue