mirror of
				https://github.com/python/cpython.git
				synced 2025-11-03 23:21:29 +00:00 
			
		
		
		
	clean up ssl.py; expose unwrap and add test for it
This commit is contained in:
		
							parent
							
								
									6aa2d1fec7
								
							
						
					
					
						commit
						40a0f66e95
					
				
					 3 changed files with 69 additions and 4 deletions
				
			
		
							
								
								
									
										10
									
								
								Lib/ssl.py
									
										
									
									
									
								
							
							
						
						
									
										10
									
								
								Lib/ssl.py
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -75,10 +75,10 @@
 | 
			
		|||
    SSL_ERROR_INVALID_ERROR_CODE,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
from socket import socket, AF_INET, SOCK_STREAM, error
 | 
			
		||||
from socket import getnameinfo as _getnameinfo
 | 
			
		||||
from socket import error as socket_error
 | 
			
		||||
from socket import dup as _dup
 | 
			
		||||
from socket import socket, AF_INET, SOCK_STREAM
 | 
			
		||||
import base64        # for DER-to-PEM translation
 | 
			
		||||
import traceback
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -296,6 +296,14 @@ def shutdown(self, how):
 | 
			
		|||
        self._sslobj = None
 | 
			
		||||
        socket.shutdown(self, how)
 | 
			
		||||
 | 
			
		||||
    def unwrap (self):
 | 
			
		||||
        if self._sslobj:
 | 
			
		||||
            s = self._sslobj.shutdown()
 | 
			
		||||
            self._sslobj = None
 | 
			
		||||
            return s
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("No SSL wrapper around " + str(self))
 | 
			
		||||
 | 
			
		||||
    def _real_close(self):
 | 
			
		||||
        self._sslobj = None
 | 
			
		||||
        # self._closed = True
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -279,6 +279,15 @@ def run (self):
 | 
			
		|||
                            self.write("OK\n".encode("ASCII", "strict"))
 | 
			
		||||
                            if not self.wrap_conn():
 | 
			
		||||
                                return
 | 
			
		||||
                        elif (self.server.starttls_server and self.sslconn
 | 
			
		||||
                              and amsg.strip() == 'ENDTLS'):
 | 
			
		||||
                            if support.verbose and self.server.connectionchatty:
 | 
			
		||||
                                sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
 | 
			
		||||
                            self.write("OK\n".encode("ASCII", "strict"))
 | 
			
		||||
                            self.sock = self.sslconn.unwrap()
 | 
			
		||||
                            self.sslconn = None
 | 
			
		||||
                            if support.verbose and self.server.connectionchatty:
 | 
			
		||||
                                sys.stdout.write(" server: connection is now unencrypted...\n")
 | 
			
		||||
                        else:
 | 
			
		||||
                            if (support.verbose and
 | 
			
		||||
                                self.server.connectionchatty):
 | 
			
		||||
| 
						 | 
				
			
			@ -868,7 +877,7 @@ def testProtocolTLS1(self):
 | 
			
		|||
 | 
			
		||||
        def testSTARTTLS (self):
 | 
			
		||||
 | 
			
		||||
            msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4")
 | 
			
		||||
            msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", "msg 5", "msg 6")
 | 
			
		||||
 | 
			
		||||
            server = ThreadedEchoServer(CERTFILE,
 | 
			
		||||
                                        ssl_version=ssl.PROTOCOL_TLSv1,
 | 
			
		||||
| 
						 | 
				
			
			@ -910,8 +919,16 @@ def testSTARTTLS (self):
 | 
			
		|||
                                    " client:  read %s from server, starting TLS...\n"
 | 
			
		||||
                                    % repr(msg))
 | 
			
		||||
                            conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1)
 | 
			
		||||
 | 
			
		||||
                            wrapped = True
 | 
			
		||||
                        elif (indata == "ENDTLS" and
 | 
			
		||||
                              str(outdata, 'ASCII', 'replace').strip().lower().startswith("ok")):
 | 
			
		||||
                            if support.verbose:
 | 
			
		||||
                                msg = str(outdata, 'ASCII', 'replace')
 | 
			
		||||
                                sys.stdout.write(
 | 
			
		||||
                                    " client:  read %s from server, ending TLS...\n"
 | 
			
		||||
                                    % repr(msg))
 | 
			
		||||
                            s = conn.unwrap()
 | 
			
		||||
                            wrapped = False
 | 
			
		||||
                        else:
 | 
			
		||||
                            if support.verbose:
 | 
			
		||||
                                msg = str(outdata, 'ASCII', 'replace')
 | 
			
		||||
| 
						 | 
				
			
			@ -922,7 +939,7 @@ def testSTARTTLS (self):
 | 
			
		|||
                    if wrapped:
 | 
			
		||||
                        conn.write("over\n".encode("ASCII", "strict"))
 | 
			
		||||
                    else:
 | 
			
		||||
                        s.send("over\n")
 | 
			
		||||
                        s.send("over\n".encode("ASCII", "strict"))
 | 
			
		||||
                if wrapped:
 | 
			
		||||
                    conn.close()
 | 
			
		||||
                else:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1370,6 +1370,42 @@ PyDoc_STRVAR(PySSL_SSLread_doc,
 | 
			
		|||
\n\
 | 
			
		||||
Read up to len bytes from the SSL socket.");
 | 
			
		||||
 | 
			
		||||
static PyObject *PySSL_SSLshutdown(PySSLObject *self)
 | 
			
		||||
{
 | 
			
		||||
	int err;
 | 
			
		||||
        PySocketSockObject *sock
 | 
			
		||||
          = (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
 | 
			
		||||
 | 
			
		||||
	/* Guard against closed socket */
 | 
			
		||||
        if ((((PyObject*)sock) == Py_None) || (sock->sock_fd < 0)) {
 | 
			
		||||
                _setSSLError("Underlying socket connection gone",
 | 
			
		||||
                             PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
 | 
			
		||||
                return NULL;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
	PySSL_BEGIN_ALLOW_THREADS
 | 
			
		||||
	err = SSL_shutdown(self->ssl);
 | 
			
		||||
	if (err == 0) {
 | 
			
		||||
		/* we need to call it again to finish the shutdown */
 | 
			
		||||
		err = SSL_shutdown(self->ssl);
 | 
			
		||||
	}
 | 
			
		||||
	PySSL_END_ALLOW_THREADS
 | 
			
		||||
 | 
			
		||||
	if (err < 0)
 | 
			
		||||
		return PySSL_SetError(self, err, __FILE__, __LINE__);
 | 
			
		||||
	else {
 | 
			
		||||
                Py_INCREF(sock);
 | 
			
		||||
                return (PyObject *) sock;
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
PyDoc_STRVAR(PySSL_SSLshutdown_doc,
 | 
			
		||||
"shutdown(s) -> socket\n\
 | 
			
		||||
\n\
 | 
			
		||||
Does the SSL shutdown handshake with the remote end, and returns\n\
 | 
			
		||||
the underlying socket object.");
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
static PyMethodDef PySSLMethods[] = {
 | 
			
		||||
	{"do_handshake", (PyCFunction)PySSL_SSLdo_handshake, METH_NOARGS},
 | 
			
		||||
	{"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS,
 | 
			
		||||
| 
						 | 
				
			
			@ -1381,6 +1417,8 @@ static PyMethodDef PySSLMethods[] = {
 | 
			
		|||
	{"peer_certificate", (PyCFunction)PySSL_peercert, METH_VARARGS,
 | 
			
		||||
	 PySSL_peercert_doc},
 | 
			
		||||
	{"cipher", (PyCFunction)PySSL_cipher, METH_NOARGS},
 | 
			
		||||
	{"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS,
 | 
			
		||||
	 PySSL_SSLshutdown_doc},
 | 
			
		||||
	{NULL, NULL}
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1480,6 +1518,8 @@ fails or if it does provide enough data to seed PRNG.");
 | 
			
		|||
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
/* List of functions exported by this module. */
 | 
			
		||||
 | 
			
		||||
static PyMethodDef PySSL_methods[] = {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue