mirror of
				https://github.com/python/cpython.git
				synced 2025-11-04 07:31:38 +00:00 
			
		
		
		
	Issue #11326: Add the missing connect_ex() implementation for SSL sockets,
and make it work for non-blocking connects.
This commit is contained in:
		
							parent
							
								
									2e7965e8b0
								
							
						
					
					
						commit
						e93bf7aed2
					
				
					 3 changed files with 68 additions and 8 deletions
				
			
		
							
								
								
									
										28
									
								
								Lib/ssl.py
									
										
									
									
									
								
							
							
						
						
									
										28
									
								
								Lib/ssl.py
									
										
									
									
									
								
							| 
						 | 
					@ -237,6 +237,7 @@ def __init__(self, sock=None, keyfile=None, certfile=None,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self._closed = False
 | 
					        self._closed = False
 | 
				
			||||||
        self._sslobj = None
 | 
					        self._sslobj = None
 | 
				
			||||||
 | 
					        self._connected = connected
 | 
				
			||||||
        if connected:
 | 
					        if connected:
 | 
				
			||||||
            # create the SSL object
 | 
					            # create the SSL object
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
| 
						 | 
					@ -430,23 +431,36 @@ def do_handshake(self, block=False):
 | 
				
			||||||
        finally:
 | 
					        finally:
 | 
				
			||||||
            self.settimeout(timeout)
 | 
					            self.settimeout(timeout)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def connect(self, addr):
 | 
					    def _real_connect(self, addr, return_errno):
 | 
				
			||||||
        """Connects to remote ADDR, and then wraps the connection in
 | 
					 | 
				
			||||||
        an SSL channel."""
 | 
					 | 
				
			||||||
        if self.server_side:
 | 
					        if self.server_side:
 | 
				
			||||||
            raise ValueError("can't connect in server-side mode")
 | 
					            raise ValueError("can't connect in server-side mode")
 | 
				
			||||||
        # Here we assume that the socket is client-side, and not
 | 
					        # Here we assume that the socket is client-side, and not
 | 
				
			||||||
        # connected at the time of the call.  We connect it, then wrap it.
 | 
					        # connected at the time of the call.  We connect it, then wrap it.
 | 
				
			||||||
        if self._sslobj:
 | 
					        if self._connected:
 | 
				
			||||||
            raise ValueError("attempt to connect already-connected SSLSocket!")
 | 
					            raise ValueError("attempt to connect already-connected SSLSocket!")
 | 
				
			||||||
        socket.connect(self, addr)
 | 
					 | 
				
			||||||
        self._sslobj = self.context._wrap_socket(self, False, self.server_hostname)
 | 
					        self._sslobj = self.context._wrap_socket(self, False, self.server_hostname)
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
 | 
					            socket.connect(self, addr)
 | 
				
			||||||
            if self.do_handshake_on_connect:
 | 
					            if self.do_handshake_on_connect:
 | 
				
			||||||
                self.do_handshake()
 | 
					                self.do_handshake()
 | 
				
			||||||
        except:
 | 
					        except socket_error as e:
 | 
				
			||||||
 | 
					            if return_errno:
 | 
				
			||||||
 | 
					                return e.errno
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
                self._sslobj = None
 | 
					                self._sslobj = None
 | 
				
			||||||
            raise
 | 
					                raise e
 | 
				
			||||||
 | 
					        self._connected = True
 | 
				
			||||||
 | 
					        return 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def connect(self, addr):
 | 
				
			||||||
 | 
					        """Connects to remote ADDR, and then wraps the connection in
 | 
				
			||||||
 | 
					        an SSL channel."""
 | 
				
			||||||
 | 
					        self._real_connect(addr, False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def connect_ex(self, addr):
 | 
				
			||||||
 | 
					        """Connects to remote ADDR, and then wraps the connection in
 | 
				
			||||||
 | 
					        an SSL channel."""
 | 
				
			||||||
 | 
					        return self._real_connect(addr, True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def accept(self):
 | 
					    def accept(self):
 | 
				
			||||||
        """Accepts a new connection from a remote client, and returns
 | 
					        """Accepts a new connection from a remote client, and returns
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -451,6 +451,49 @@ def test_connect(self):
 | 
				
			||||||
            finally:
 | 
					            finally:
 | 
				
			||||||
                s.close()
 | 
					                s.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_connect_ex(self):
 | 
				
			||||||
 | 
					        # Issue #11326: check connect_ex() implementation
 | 
				
			||||||
 | 
					        with support.transient_internet("svn.python.org"):
 | 
				
			||||||
 | 
					            s = ssl.wrap_socket(socket.socket(socket.AF_INET),
 | 
				
			||||||
 | 
					                                cert_reqs=ssl.CERT_REQUIRED,
 | 
				
			||||||
 | 
					                                ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                self.assertEqual(0, s.connect_ex(("svn.python.org", 443)))
 | 
				
			||||||
 | 
					                self.assertTrue(s.getpeercert())
 | 
				
			||||||
 | 
					            finally:
 | 
				
			||||||
 | 
					                s.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_non_blocking_connect_ex(self):
 | 
				
			||||||
 | 
					        # Issue #11326: non-blocking connect_ex() should allow handshake
 | 
				
			||||||
 | 
					        # to proceed after the socket gets ready.
 | 
				
			||||||
 | 
					        with support.transient_internet("svn.python.org"):
 | 
				
			||||||
 | 
					            s = ssl.wrap_socket(socket.socket(socket.AF_INET),
 | 
				
			||||||
 | 
					                                cert_reqs=ssl.CERT_REQUIRED,
 | 
				
			||||||
 | 
					                                ca_certs=SVN_PYTHON_ORG_ROOT_CERT,
 | 
				
			||||||
 | 
					                                do_handshake_on_connect=False)
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                s.setblocking(False)
 | 
				
			||||||
 | 
					                rc = s.connect_ex(('svn.python.org', 443))
 | 
				
			||||||
 | 
					                self.assertIn(rc, (0, errno.EINPROGRESS))
 | 
				
			||||||
 | 
					                # Wait for connect to finish
 | 
				
			||||||
 | 
					                select.select([], [s], [], 5.0)
 | 
				
			||||||
 | 
					                # Non-blocking handshake
 | 
				
			||||||
 | 
					                while True:
 | 
				
			||||||
 | 
					                    try:
 | 
				
			||||||
 | 
					                        s.do_handshake()
 | 
				
			||||||
 | 
					                        break
 | 
				
			||||||
 | 
					                    except ssl.SSLError as err:
 | 
				
			||||||
 | 
					                        if err.args[0] == ssl.SSL_ERROR_WANT_READ:
 | 
				
			||||||
 | 
					                            select.select([s], [], [], 5.0)
 | 
				
			||||||
 | 
					                        elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
 | 
				
			||||||
 | 
					                            select.select([], [s], [], 5.0)
 | 
				
			||||||
 | 
					                        else:
 | 
				
			||||||
 | 
					                            raise
 | 
				
			||||||
 | 
					                # SSL established
 | 
				
			||||||
 | 
					                self.assertTrue(s.getpeercert())
 | 
				
			||||||
 | 
					            finally:
 | 
				
			||||||
 | 
					                s.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_connect_with_context(self):
 | 
					    def test_connect_with_context(self):
 | 
				
			||||||
        with support.transient_internet("svn.python.org"):
 | 
					        with support.transient_internet("svn.python.org"):
 | 
				
			||||||
            # Same as test_connect, but with a separately created context
 | 
					            # Same as test_connect, but with a separately created context
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -35,6 +35,9 @@ Core and Builtins
 | 
				
			||||||
Library
 | 
					Library
 | 
				
			||||||
-------
 | 
					-------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- Issue #11326: Add the missing connect_ex() implementation for SSL sockets,
 | 
				
			||||||
 | 
					  and make it work for non-blocking connects.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- Issue #11297: Add collections.ChainMap().
 | 
					- Issue #11297: Add collections.ChainMap().
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- Issue #10755: Add the posix.fdlistdir() function.  Patch by Ross Lagerwall.
 | 
					- Issue #10755: Add the posix.fdlistdir() function.  Patch by Ross Lagerwall.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue