mirror of
				https://github.com/python/cpython.git
				synced 2025-10-31 13:41:24 +00:00 
			
		
		
		
	various SSL fixes; issues 1251, 3162, 3212
This commit is contained in:
		
							parent
							
								
									a27474c345
								
							
						
					
					
						commit
						934b16d0c2
					
				
					 5 changed files with 528 additions and 348 deletions
				
			
		|  | @ -54,7 +54,7 @@ Functions, Constants, and Exceptions | |||
|    network connection.  This error is a subtype of :exc:`socket.error`, which | ||||
|    in turn is a subtype of :exc:`IOError`. | ||||
| 
 | ||||
| .. function:: wrap_socket (sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version={see docs}, ca_certs=None) | ||||
| .. function:: wrap_socket (sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version={see docs}, ca_certs=None, do_handshake_on_connect=True, suppress_ragged_eofs=True) | ||||
| 
 | ||||
|    Takes an instance ``sock`` of :class:`socket.socket`, and returns an instance of :class:`ssl.SSLSocket`, a subtype | ||||
|    of :class:`socket.socket`, which wraps the underlying socket in an SSL context. | ||||
|  | @ -122,6 +122,18 @@ Functions, Constants, and Exceptions | |||
|    In some older versions of OpenSSL (for instance, 0.9.7l on OS X 10.4), | ||||
|    an SSLv2 client could not connect to an SSLv23 server. | ||||
| 
 | ||||
|    The parameter ``do_handshake_on_connect`` specifies whether to do the SSL | ||||
|    handshake automatically after doing a :meth:`socket.connect`, or whether the | ||||
|    application program will call it explicitly, by invoking the :meth:`SSLSocket.do_handshake` | ||||
|    method.  Calling :meth:`SSLSocket.do_handshake` explicitly gives the program control over | ||||
|    the blocking behavior of the socket I/O involved in the handshake. | ||||
| 
 | ||||
|    The parameter ``suppress_ragged_eofs`` specifies how the :meth:`SSLSocket.read` | ||||
|    method should signal unexpected EOF from the other end of the connection.  If specified | ||||
|    as :const:`True` (the default), it returns a normal EOF in response to unexpected | ||||
|    EOF errors raised from the underlying socket; if :const:`False`, it will raise | ||||
|    the exceptions back to the caller. | ||||
| 
 | ||||
| .. function:: RAND_status() | ||||
| 
 | ||||
|    Returns True if the SSL pseudo-random number generator has been | ||||
|  | @ -290,6 +302,25 @@ SSLSocket Objects | |||
|    number of secret bits being used.  If no connection has been | ||||
|    established, returns ``None``. | ||||
| 
 | ||||
| .. method:: SSLSocket.do_handshake() | ||||
| 
 | ||||
|    Perform a TLS/SSL handshake.  If this is used with a non-blocking socket, | ||||
|    it may raise :exc:`SSLError` with an ``arg[0]`` of :const:`SSL_ERROR_WANT_READ` | ||||
|    or :const:`SSL_ERROR_WANT_WRITE`, in which case it must be called again until it | ||||
|    completes successfully.  For example, to simulate the behavior of a blocking socket, | ||||
|    one might write:: | ||||
| 
 | ||||
|         while True: | ||||
|             try: | ||||
|                 s.do_handshake() | ||||
|                 break | ||||
|             except ssl.SSLError, err: | ||||
|                 if err.args[0] == ssl.SSL_ERROR_WANT_READ: | ||||
|                     select.select([s], [], []) | ||||
|                 elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: | ||||
|                     select.select([], [s], []) | ||||
|                 else: | ||||
|                     raise | ||||
| 
 | ||||
| .. index:: single: certificates | ||||
| 
 | ||||
|  | @ -367,6 +398,7 @@ certificate, you need to provide a "CA certs" file, filled with the certificate | |||
| chains for each issuer you are willing to trust.  Again, this file just | ||||
| contains these chains concatenated together.  For validation, Python will | ||||
| use the first chain it finds in the file which matches. | ||||
| 
 | ||||
| Some "standard" root certificates are available from various certification | ||||
| authorities: | ||||
| `CACert.org <http://www.cacert.org/index.php?id=3>`_, | ||||
|  |  | |||
							
								
								
									
										353
									
								
								Lib/ssl.py
									
										
									
									
									
								
							
							
						
						
									
										353
									
								
								Lib/ssl.py
									
										
									
									
									
								
							|  | @ -74,7 +74,7 @@ | |||
|      SSL_ERROR_EOF, \ | ||||
|      SSL_ERROR_INVALID_ERROR_CODE | ||||
| 
 | ||||
| from socket import socket | ||||
| from socket import socket, _fileobject | ||||
| from socket import getnameinfo as _getnameinfo | ||||
| import base64        # for DER-to-PEM translation | ||||
| 
 | ||||
|  | @ -86,8 +86,16 @@ class SSLSocket (socket): | |||
| 
 | ||||
|     def __init__(self, sock, keyfile=None, certfile=None, | ||||
|                  server_side=False, cert_reqs=CERT_NONE, | ||||
|                  ssl_version=PROTOCOL_SSLv23, ca_certs=None): | ||||
|                  ssl_version=PROTOCOL_SSLv23, ca_certs=None, | ||||
|                  do_handshake_on_connect=True, | ||||
|                  suppress_ragged_eofs=True): | ||||
|         socket.__init__(self, _sock=sock._sock) | ||||
|         # the initializer for socket trashes the methods (tsk, tsk), so... | ||||
|         self.send = lambda x, flags=0: SSLSocket.send(self, x, flags) | ||||
|         self.recv = lambda x, flags=0: SSLSocket.recv(self, x, flags) | ||||
|         self.sendto = lambda data, addr, flags=0: SSLSocket.sendto(self, data, addr, flags) | ||||
|         self.recvfrom = lambda addr, buflen, flags: SSLSocket.recvfrom(self, addr, buflen, flags) | ||||
| 
 | ||||
|         if certfile and not keyfile: | ||||
|             keyfile = certfile | ||||
|         # see if it's connected | ||||
|  | @ -101,18 +109,34 @@ def __init__(self, sock, keyfile=None, certfile=None, | |||
|             self._sslobj = _ssl.sslwrap(self._sock, server_side, | ||||
|                                         keyfile, certfile, | ||||
|                                         cert_reqs, ssl_version, ca_certs) | ||||
|             if do_handshake_on_connect: | ||||
|                 timeout = self.gettimeout() | ||||
|                 try: | ||||
|                     self.settimeout(None) | ||||
|                     self.do_handshake() | ||||
|                 finally: | ||||
|                     self.settimeout(timeout) | ||||
|         self.keyfile = keyfile | ||||
|         self.certfile = certfile | ||||
|         self.cert_reqs = cert_reqs | ||||
|         self.ssl_version = ssl_version | ||||
|         self.ca_certs = ca_certs | ||||
|         self.do_handshake_on_connect = do_handshake_on_connect | ||||
|         self.suppress_ragged_eofs = suppress_ragged_eofs | ||||
|         self._makefile_refs = 0 | ||||
| 
 | ||||
|     def read(self, len=1024): | ||||
| 
 | ||||
|         """Read up to LEN bytes and return them. | ||||
|         Return zero-length string on EOF.""" | ||||
| 
 | ||||
|         try: | ||||
|             return self._sslobj.read(len) | ||||
|         except SSLError, x: | ||||
|             if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: | ||||
|                 return '' | ||||
|             else: | ||||
|                 raise | ||||
| 
 | ||||
|     def write(self, data): | ||||
| 
 | ||||
|  | @ -143,16 +167,27 @@ def send (self, data, flags=0): | |||
|                 raise ValueError( | ||||
|                     "non-zero flags not allowed in calls to send() on %s" % | ||||
|                     self.__class__) | ||||
|             return self._sslobj.write(data) | ||||
|             while True: | ||||
|                 try: | ||||
|                     v = self._sslobj.write(data) | ||||
|                 except SSLError, x: | ||||
|                     if x.args[0] == SSL_ERROR_WANT_READ: | ||||
|                         return 0 | ||||
|                     elif x.args[0] == SSL_ERROR_WANT_WRITE: | ||||
|                         return 0 | ||||
|                     else: | ||||
|                         raise | ||||
|                 else: | ||||
|                     return v | ||||
|         else: | ||||
|             return socket.send(self, data, flags) | ||||
| 
 | ||||
|     def send_to (self, data, addr, flags=0): | ||||
|     def sendto (self, data, addr, flags=0): | ||||
|         if self._sslobj: | ||||
|             raise ValueError("send_to not allowed on instances of %s" % | ||||
|             raise ValueError("sendto not allowed on instances of %s" % | ||||
|                              self.__class__) | ||||
|         else: | ||||
|             return socket.send_to(self, data, addr, flags) | ||||
|             return socket.sendto(self, data, addr, flags) | ||||
| 
 | ||||
|     def sendall (self, data, flags=0): | ||||
|         if self._sslobj: | ||||
|  | @ -160,7 +195,12 @@ def sendall (self, data, flags=0): | |||
|                 raise ValueError( | ||||
|                     "non-zero flags not allowed in calls to sendall() on %s" % | ||||
|                     self.__class__) | ||||
|             return self._sslobj.write(data) | ||||
|             amount = len(data) | ||||
|             count = 0 | ||||
|             while (count < amount): | ||||
|                 v = self.send(data[count:]) | ||||
|                 count += v | ||||
|             return amount | ||||
|         else: | ||||
|             return socket.sendall(self, data, flags) | ||||
| 
 | ||||
|  | @ -170,16 +210,29 @@ def recv (self, buflen=1024, flags=0): | |||
|                 raise ValueError( | ||||
|                     "non-zero flags not allowed in calls to sendall() on %s" % | ||||
|                     self.__class__) | ||||
|             return self._sslobj.read(data, buflen) | ||||
|             while True: | ||||
|                 try: | ||||
|                     return self.read(buflen) | ||||
|                 except SSLError, x: | ||||
|                     if x.args[0] == SSL_ERROR_WANT_READ: | ||||
|                         continue | ||||
|                     else: | ||||
|                         raise x | ||||
|         else: | ||||
|             return socket.recv(self, buflen, flags) | ||||
| 
 | ||||
|     def recv_from (self, addr, buflen=1024, flags=0): | ||||
|     def recvfrom (self, addr, buflen=1024, flags=0): | ||||
|         if self._sslobj: | ||||
|             raise ValueError("recv_from not allowed on instances of %s" % | ||||
|             raise ValueError("recvfrom not allowed on instances of %s" % | ||||
|                              self.__class__) | ||||
|         else: | ||||
|             return socket.recv_from(self, addr, buflen, flags) | ||||
|             return socket.recvfrom(self, addr, buflen, flags) | ||||
| 
 | ||||
|     def pending (self): | ||||
|         if self._sslobj: | ||||
|             return self._sslobj.pending() | ||||
|         else: | ||||
|             return 0 | ||||
| 
 | ||||
|     def shutdown (self, how): | ||||
|         self._sslobj = None | ||||
|  | @ -189,6 +242,19 @@ def close(self): | |||
|         self._sslobj = None | ||||
|         socket.close(self) | ||||
| 
 | ||||
|     def close (self): | ||||
|         if self._makefile_refs < 1: | ||||
|             self._sslobj = None | ||||
|             socket.close(self) | ||||
|         else: | ||||
|             self._makefile_refs -= 1 | ||||
| 
 | ||||
|     def do_handshake (self): | ||||
| 
 | ||||
|         """Perform a TLS/SSL handshake.""" | ||||
| 
 | ||||
|         self._sslobj.do_handshake() | ||||
| 
 | ||||
|     def connect(self, addr): | ||||
| 
 | ||||
|         """Connects to remote ADDR, and then wraps the connection in | ||||
|  | @ -202,6 +268,8 @@ def connect(self, addr): | |||
|         self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile, | ||||
|                                     self.cert_reqs, self.ssl_version, | ||||
|                                     self.ca_certs) | ||||
|         if self.do_handshake_on_connect: | ||||
|             self.do_handshake() | ||||
| 
 | ||||
|     def accept(self): | ||||
| 
 | ||||
|  | @ -210,260 +278,39 @@ def accept(self): | |||
|         SSL channel, and the address of the remote client.""" | ||||
| 
 | ||||
|         newsock, addr = socket.accept(self) | ||||
|         return (SSLSocket(newsock, True, self.keyfile, self.certfile, | ||||
|                           self.cert_reqs, self.ssl_version, | ||||
|                           self.ca_certs), addr) | ||||
| 
 | ||||
|         return (SSLSocket(newsock, | ||||
|                           keyfile=self.keyfile, | ||||
|                           certfile=self.certfile, | ||||
|                           server_side=True, | ||||
|                           cert_reqs=self.cert_reqs, | ||||
|                           ssl_version=self.ssl_version, | ||||
|                           ca_certs=self.ca_certs, | ||||
|                           do_handshake_on_connect=self.do_handshake_on_connect, | ||||
|                           suppress_ragged_eofs=self.suppress_ragged_eofs), | ||||
|                 addr) | ||||
| 
 | ||||
|     def makefile(self, mode='r', bufsize=-1): | ||||
| 
 | ||||
|         """Ouch.  Need to make and return a file-like object that | ||||
|         works with the SSL connection.""" | ||||
| 
 | ||||
|         if self._sslobj: | ||||
|             return SSLFileStream(self._sslobj, mode, bufsize) | ||||
|         else: | ||||
|             return socket.makefile(self, mode, bufsize) | ||||
| 
 | ||||
| 
 | ||||
| class SSLFileStream: | ||||
| 
 | ||||
|     """A class to simulate a file stream on top of a socket. | ||||
|     Most of this is just lifted from the socket module, and | ||||
|     adjusted to work with an SSL stream instead of a socket.""" | ||||
| 
 | ||||
| 
 | ||||
|     default_bufsize = 8192 | ||||
|     name = "<SSL stream>" | ||||
| 
 | ||||
|     __slots__ = ["mode", "bufsize", "softspace", | ||||
|                  # "closed" is a property, see below | ||||
|                  "_sslobj", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf", | ||||
|                  "_close", "_fileno"] | ||||
| 
 | ||||
|     def __init__(self, sslobj, mode='rb', bufsize=-1, close=False): | ||||
|         self._sslobj = sslobj | ||||
|         self.mode = mode # Not actually used in this version | ||||
|         if bufsize < 0: | ||||
|             bufsize = self.default_bufsize | ||||
|         self.bufsize = bufsize | ||||
|         self.softspace = False | ||||
|         if bufsize == 0: | ||||
|             self._rbufsize = 1 | ||||
|         elif bufsize == 1: | ||||
|             self._rbufsize = self.default_bufsize | ||||
|         else: | ||||
|             self._rbufsize = bufsize | ||||
|         self._wbufsize = bufsize | ||||
|         self._rbuf = "" # A string | ||||
|         self._wbuf = [] # A list of strings | ||||
|         self._close = close | ||||
|         self._fileno = -1 | ||||
| 
 | ||||
|     def _getclosed(self): | ||||
|         return self._sslobj is None | ||||
|     closed = property(_getclosed, doc="True if the file is closed") | ||||
| 
 | ||||
|     def fileno(self): | ||||
|         return self._fileno | ||||
| 
 | ||||
|     def close(self): | ||||
|         try: | ||||
|             if self._sslobj: | ||||
|                 self.flush() | ||||
|         finally: | ||||
|             if self._close and self._sslobj: | ||||
|                 self._sslobj.close() | ||||
|             self._sslobj = None | ||||
| 
 | ||||
|     def __del__(self): | ||||
|         try: | ||||
|             self.close() | ||||
|         except: | ||||
|             # close() may fail if __init__ didn't complete | ||||
|             pass | ||||
| 
 | ||||
|     def flush(self): | ||||
|         if self._wbuf: | ||||
|             buffer = "".join(self._wbuf) | ||||
|             self._wbuf = [] | ||||
|             count = 0 | ||||
|             while (count < len(buffer)): | ||||
|                 written = self._sslobj.write(buffer) | ||||
|                 count += written | ||||
|                 buffer = buffer[written:] | ||||
| 
 | ||||
|     def write(self, data): | ||||
|         data = str(data) # XXX Should really reject non-string non-buffers | ||||
|         if not data: | ||||
|             return | ||||
|         self._wbuf.append(data) | ||||
|         if (self._wbufsize == 0 or | ||||
|             self._wbufsize == 1 and '\n' in data or | ||||
|             self._get_wbuf_len() >= self._wbufsize): | ||||
|             self.flush() | ||||
| 
 | ||||
|     def writelines(self, list): | ||||
|         # XXX We could do better here for very long lists | ||||
|         # XXX Should really reject non-string non-buffers | ||||
|         self._wbuf.extend(filter(None, map(str, list))) | ||||
|         if (self._wbufsize <= 1 or | ||||
|             self._get_wbuf_len() >= self._wbufsize): | ||||
|             self.flush() | ||||
| 
 | ||||
|     def _get_wbuf_len(self): | ||||
|         buf_len = 0 | ||||
|         for x in self._wbuf: | ||||
|             buf_len += len(x) | ||||
|         return buf_len | ||||
| 
 | ||||
|     def read(self, size=-1): | ||||
|         data = self._rbuf | ||||
|         if size < 0: | ||||
|             # Read until EOF | ||||
|             buffers = [] | ||||
|             if data: | ||||
|                 buffers.append(data) | ||||
|             self._rbuf = "" | ||||
|             if self._rbufsize <= 1: | ||||
|                 recv_size = self.default_bufsize | ||||
|             else: | ||||
|                 recv_size = self._rbufsize | ||||
|             while True: | ||||
|                 data = self._sslobj.read(recv_size) | ||||
|                 if not data: | ||||
|                     break | ||||
|                 buffers.append(data) | ||||
|             return "".join(buffers) | ||||
|         else: | ||||
|             # Read until size bytes or EOF seen, whichever comes first | ||||
|             buf_len = len(data) | ||||
|             if buf_len >= size: | ||||
|                 self._rbuf = data[size:] | ||||
|                 return data[:size] | ||||
|             buffers = [] | ||||
|             if data: | ||||
|                 buffers.append(data) | ||||
|             self._rbuf = "" | ||||
|             while True: | ||||
|                 left = size - buf_len | ||||
|                 recv_size = max(self._rbufsize, left) | ||||
|                 data = self._sslobj.read(recv_size) | ||||
|                 if not data: | ||||
|                     break | ||||
|                 buffers.append(data) | ||||
|                 n = len(data) | ||||
|                 if n >= left: | ||||
|                     self._rbuf = data[left:] | ||||
|                     buffers[-1] = data[:left] | ||||
|                     break | ||||
|                 buf_len += n | ||||
|             return "".join(buffers) | ||||
| 
 | ||||
|     def readline(self, size=-1): | ||||
|         data = self._rbuf | ||||
|         if size < 0: | ||||
|             # Read until \n or EOF, whichever comes first | ||||
|             if self._rbufsize <= 1: | ||||
|                 # Speed up unbuffered case | ||||
|                 assert data == "" | ||||
|                 buffers = [] | ||||
|                 while data != "\n": | ||||
|                     data = self._sslobj.read(1) | ||||
|                     if not data: | ||||
|                         break | ||||
|                     buffers.append(data) | ||||
|                 return "".join(buffers) | ||||
|             nl = data.find('\n') | ||||
|             if nl >= 0: | ||||
|                 nl += 1 | ||||
|                 self._rbuf = data[nl:] | ||||
|                 return data[:nl] | ||||
|             buffers = [] | ||||
|             if data: | ||||
|                 buffers.append(data) | ||||
|             self._rbuf = "" | ||||
|             while True: | ||||
|                 data = self._sslobj.read(self._rbufsize) | ||||
|                 if not data: | ||||
|                     break | ||||
|                 buffers.append(data) | ||||
|                 nl = data.find('\n') | ||||
|                 if nl >= 0: | ||||
|                     nl += 1 | ||||
|                     self._rbuf = data[nl:] | ||||
|                     buffers[-1] = data[:nl] | ||||
|                     break | ||||
|             return "".join(buffers) | ||||
|         else: | ||||
|             # Read until size bytes or \n or EOF seen, whichever comes first | ||||
|             nl = data.find('\n', 0, size) | ||||
|             if nl >= 0: | ||||
|                 nl += 1 | ||||
|                 self._rbuf = data[nl:] | ||||
|                 return data[:nl] | ||||
|             buf_len = len(data) | ||||
|             if buf_len >= size: | ||||
|                 self._rbuf = data[size:] | ||||
|                 return data[:size] | ||||
|             buffers = [] | ||||
|             if data: | ||||
|                 buffers.append(data) | ||||
|             self._rbuf = "" | ||||
|             while True: | ||||
|                 data = self._sslobj.read(self._rbufsize) | ||||
|                 if not data: | ||||
|                     break | ||||
|                 buffers.append(data) | ||||
|                 left = size - buf_len | ||||
|                 nl = data.find('\n', 0, left) | ||||
|                 if nl >= 0: | ||||
|                     nl += 1 | ||||
|                     self._rbuf = data[nl:] | ||||
|                     buffers[-1] = data[:nl] | ||||
|                     break | ||||
|                 n = len(data) | ||||
|                 if n >= left: | ||||
|                     self._rbuf = data[left:] | ||||
|                     buffers[-1] = data[:left] | ||||
|                     break | ||||
|                 buf_len += n | ||||
|             return "".join(buffers) | ||||
| 
 | ||||
|     def readlines(self, sizehint=0): | ||||
|         total = 0 | ||||
|         list = [] | ||||
|         while True: | ||||
|             line = self.readline() | ||||
|             if not line: | ||||
|                 break | ||||
|             list.append(line) | ||||
|             total += len(line) | ||||
|             if sizehint and total >= sizehint: | ||||
|                 break | ||||
|         return list | ||||
| 
 | ||||
|     # Iterator protocols | ||||
| 
 | ||||
|     def __iter__(self): | ||||
|         return self | ||||
| 
 | ||||
|     def next(self): | ||||
|         line = self.readline() | ||||
|         if not line: | ||||
|             raise StopIteration | ||||
|         return line | ||||
| 
 | ||||
|         self._makefile_refs += 1 | ||||
|         return _fileobject(self, mode, bufsize) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| def wrap_socket(sock, keyfile=None, certfile=None, | ||||
|                 server_side=False, cert_reqs=CERT_NONE, | ||||
|                 ssl_version=PROTOCOL_SSLv23, ca_certs=None): | ||||
|                 ssl_version=PROTOCOL_SSLv23, ca_certs=None, | ||||
|                 do_handshake_on_connect=True, | ||||
|                 suppress_ragged_eofs=True): | ||||
| 
 | ||||
|     return SSLSocket(sock, keyfile=keyfile, certfile=certfile, | ||||
|                      server_side=server_side, cert_reqs=cert_reqs, | ||||
|                      ssl_version=ssl_version, ca_certs=ca_certs) | ||||
|                      ssl_version=ssl_version, ca_certs=ca_certs, | ||||
|                      do_handshake_on_connect=do_handshake_on_connect, | ||||
|                      suppress_ragged_eofs=suppress_ragged_eofs) | ||||
| 
 | ||||
| 
 | ||||
| # some utility functions | ||||
| 
 | ||||
|  | @ -549,5 +396,7 @@ def sslwrap_simple (sock, keyfile=None, certfile=None): | |||
|     for compability with Python 2.5 and earlier.  Will disappear in | ||||
|     Python 3.0.""" | ||||
| 
 | ||||
|     return _ssl.sslwrap(sock._sock, 0, keyfile, certfile, CERT_NONE, | ||||
|     ssl_sock = _ssl.sslwrap(sock._sock, 0, keyfile, certfile, CERT_NONE, | ||||
|                             PROTOCOL_SSLv23, None) | ||||
|     ssl_sock.do_handshake() | ||||
|     return ssl_sock | ||||
|  |  | |||
|  | @ -3,7 +3,9 @@ | |||
| import sys | ||||
| import unittest | ||||
| from test import test_support | ||||
| import asyncore | ||||
| import socket | ||||
| import select | ||||
| import errno | ||||
| import subprocess | ||||
| import time | ||||
|  | @ -97,8 +99,7 @@ def testDERtoPEM(self): | |||
|         if (d1 != d2): | ||||
|             raise test_support.TestFailed("PEM-to-DER or DER-to-PEM translation failed") | ||||
| 
 | ||||
| 
 | ||||
| class NetworkTests(unittest.TestCase): | ||||
| class NetworkedTests(unittest.TestCase): | ||||
| 
 | ||||
|     def testConnect(self): | ||||
|         s = ssl.wrap_socket(socket.socket(socket.AF_INET), | ||||
|  | @ -130,6 +131,31 @@ def testConnect(self): | |||
|         finally: | ||||
|             s.close() | ||||
| 
 | ||||
| 
 | ||||
|     def testNonBlockingHandshake(self): | ||||
|         s = socket.socket(socket.AF_INET) | ||||
|         s.connect(("svn.python.org", 443)) | ||||
|         s.setblocking(False) | ||||
|         s = ssl.wrap_socket(s, | ||||
|                             cert_reqs=ssl.CERT_NONE, | ||||
|                             do_handshake_on_connect=False) | ||||
|         count = 0 | ||||
|         while True: | ||||
|             try: | ||||
|                 count += 1 | ||||
|                 s.do_handshake() | ||||
|                 break | ||||
|             except ssl.SSLError, err: | ||||
|                 if err.args[0] == ssl.SSL_ERROR_WANT_READ: | ||||
|                     select.select([s], [], []) | ||||
|                 elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: | ||||
|                     select.select([], [s], []) | ||||
|                 else: | ||||
|                     raise | ||||
|         s.close() | ||||
|         if test_support.verbose: | ||||
|             sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count) | ||||
| 
 | ||||
|     def testFetchServerCert(self): | ||||
| 
 | ||||
|         pem = ssl.get_server_certificate(("svn.python.org", 443)) | ||||
|  | @ -176,6 +202,18 @@ def __init__(self, server, connsock): | |||
|                 threading.Thread.__init__(self) | ||||
|                 self.setDaemon(True) | ||||
| 
 | ||||
|             def show_conn_details(self): | ||||
|                 if self.server.certreqs == ssl.CERT_REQUIRED: | ||||
|                     cert = self.sslconn.getpeercert() | ||||
|                     if test_support.verbose and self.server.chatty: | ||||
|                         sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n") | ||||
|                     cert_binary = self.sslconn.getpeercert(True) | ||||
|                     if test_support.verbose and self.server.chatty: | ||||
|                         sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n") | ||||
|                 cipher = self.sslconn.cipher() | ||||
|                 if test_support.verbose and self.server.chatty: | ||||
|                     sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n") | ||||
| 
 | ||||
|             def wrap_conn (self): | ||||
|                 try: | ||||
|                     self.sslconn = ssl.wrap_socket(self.sock, server_side=True, | ||||
|  | @ -187,6 +225,7 @@ def wrap_conn (self): | |||
|                     if self.server.chatty: | ||||
|                         handle_error("\n server:  bad connection attempt from " + | ||||
|                                      str(self.sock.getpeername()) + ":\n") | ||||
|                     self.close() | ||||
|                     if not self.server.expect_bad_connects: | ||||
|                         # here, we want to stop the server, because this shouldn't | ||||
|                         # happen in the context of our test case | ||||
|  | @ -197,16 +236,6 @@ def wrap_conn (self): | |||
|                     return False | ||||
| 
 | ||||
|                 else: | ||||
|                     if self.server.certreqs == ssl.CERT_REQUIRED: | ||||
|                         cert = self.sslconn.getpeercert() | ||||
|                         if test_support.verbose and self.server.chatty: | ||||
|                             sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n") | ||||
|                         cert_binary = self.sslconn.getpeercert(True) | ||||
|                         if test_support.verbose and self.server.chatty: | ||||
|                             sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n") | ||||
|                     cipher = self.sslconn.cipher() | ||||
|                     if test_support.verbose and self.server.chatty: | ||||
|                         sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n") | ||||
|                     return True | ||||
| 
 | ||||
|             def read(self): | ||||
|  | @ -225,13 +254,16 @@ def close(self): | |||
|                 if self.sslconn: | ||||
|                     self.sslconn.close() | ||||
|                 else: | ||||
|                     self.sock.close() | ||||
|                     self.sock._sock.close() | ||||
| 
 | ||||
|             def run (self): | ||||
|                 self.running = True | ||||
|                 if not self.server.starttls_server: | ||||
|                     if not self.wrap_conn(): | ||||
|                     if isinstance(self.sock, ssl.SSLSocket): | ||||
|                         self.sslconn = self.sock | ||||
|                     elif not self.wrap_conn(): | ||||
|                         return | ||||
|                     self.show_conn_details() | ||||
|                 while self.running: | ||||
|                     try: | ||||
|                         msg = self.read() | ||||
|  | @ -270,7 +302,9 @@ def run (self): | |||
| 
 | ||||
|         def __init__(self, certificate, ssl_version=None, | ||||
|                      certreqs=None, cacerts=None, expect_bad_connects=False, | ||||
|                      chatty=True, connectionchatty=False, starttls_server=False): | ||||
|                      chatty=True, connectionchatty=False, starttls_server=False, | ||||
|                      wrap_accepting_socket=False): | ||||
| 
 | ||||
|             if ssl_version is None: | ||||
|                 ssl_version = ssl.PROTOCOL_TLSv1 | ||||
|             if certreqs is None: | ||||
|  | @ -284,8 +318,16 @@ def __init__(self, certificate, ssl_version=None, | |||
|             self.connectionchatty = connectionchatty | ||||
|             self.starttls_server = starttls_server | ||||
|             self.sock = socket.socket() | ||||
|             self.port = test_support.bind_port(self.sock) | ||||
|             self.flag = None | ||||
|             if wrap_accepting_socket: | ||||
|                 self.sock = ssl.wrap_socket(self.sock, server_side=True, | ||||
|                                             certfile=self.certificate, | ||||
|                                             cert_reqs = self.certreqs, | ||||
|                                             ca_certs = self.cacerts, | ||||
|                                             ssl_version = self.protocol) | ||||
|                 if test_support.verbose and self.chatty: | ||||
|                     sys.stdout.write(' server:  wrapped server socket as %s\n' % str(self.sock)) | ||||
|             self.port = test_support.bind_port(self.sock) | ||||
|             self.active = False | ||||
|             threading.Thread.__init__(self) | ||||
|             self.setDaemon(False) | ||||
|  | @ -316,13 +358,86 @@ def run (self): | |||
|                 except: | ||||
|                     if self.chatty: | ||||
|                         handle_error("Test server failure:\n") | ||||
|             self.sock.close() | ||||
| 
 | ||||
|         def stop (self): | ||||
|             self.active = False | ||||
|             self.sock.close() | ||||
| 
 | ||||
|     class AsyncoreEchoServer(threading.Thread): | ||||
| 
 | ||||
|     class AsyncoreHTTPSServer(threading.Thread): | ||||
|         class EchoServer (asyncore.dispatcher): | ||||
| 
 | ||||
|             class ConnectionHandler (asyncore.dispatcher_with_send): | ||||
| 
 | ||||
|                 def __init__(self, conn, certfile): | ||||
|                     asyncore.dispatcher_with_send.__init__(self, conn) | ||||
|                     self.socket = ssl.wrap_socket(conn, server_side=True, | ||||
|                                                   certfile=certfile, | ||||
|                                                   do_handshake_on_connect=True) | ||||
| 
 | ||||
|                 def readable(self): | ||||
|                     if isinstance(self.socket, ssl.SSLSocket): | ||||
|                         while self.socket.pending() > 0: | ||||
|                             self.handle_read_event() | ||||
|                     return True | ||||
| 
 | ||||
|                 def handle_read(self): | ||||
|                     data = self.recv(1024) | ||||
|                     self.send(data.lower()) | ||||
| 
 | ||||
|                 def handle_close(self): | ||||
|                     if test_support.verbose: | ||||
|                         sys.stdout.write(" server:  closed connection %s\n" % self.socket) | ||||
| 
 | ||||
|                 def handle_error(self): | ||||
|                     raise | ||||
| 
 | ||||
|             def __init__(self, certfile): | ||||
|                 self.certfile = certfile | ||||
|                 asyncore.dispatcher.__init__(self) | ||||
|                 self.create_socket(socket.AF_INET, socket.SOCK_STREAM) | ||||
|                 self.port = test_support.bind_port(self.socket) | ||||
|                 self.listen(5) | ||||
| 
 | ||||
|             def handle_accept(self): | ||||
|                 sock_obj, addr = self.accept() | ||||
|                 if test_support.verbose: | ||||
|                     sys.stdout.write(" server:  new connection from %s:%s\n" %addr) | ||||
|                 self.ConnectionHandler(sock_obj, self.certfile) | ||||
| 
 | ||||
|             def handle_error(self): | ||||
|                 raise | ||||
| 
 | ||||
|         def __init__(self, certfile): | ||||
|             self.flag = None | ||||
|             self.active = False | ||||
|             self.server = self.EchoServer(certfile) | ||||
|             self.port = self.server.port | ||||
|             threading.Thread.__init__(self) | ||||
|             self.setDaemon(True) | ||||
| 
 | ||||
|         def __str__(self): | ||||
|             return "<%s %s>" % (self.__class__.__name__, self.server) | ||||
| 
 | ||||
|         def start (self, flag=None): | ||||
|             self.flag = flag | ||||
|             threading.Thread.start(self) | ||||
| 
 | ||||
|         def run (self): | ||||
|             self.active = True | ||||
|             if self.flag: | ||||
|                 self.flag.set() | ||||
|             while self.active: | ||||
|                 try: | ||||
|                     asyncore.loop(1) | ||||
|                 except: | ||||
|                     pass | ||||
| 
 | ||||
|         def stop (self): | ||||
|             self.active = False | ||||
|             self.server.close() | ||||
| 
 | ||||
|     class SocketServerHTTPSServer(threading.Thread): | ||||
| 
 | ||||
|         class HTTPSServer(HTTPServer): | ||||
| 
 | ||||
|  | @ -335,6 +450,12 @@ def __init__(self, server_address, RequestHandlerClass, certfile): | |||
|                 self.active_lock = threading.Lock() | ||||
|                 self.allow_reuse_address = True | ||||
| 
 | ||||
|             def __str__(self): | ||||
|                 return ('<%s %s:%s>' % | ||||
|                         (self.__class__.__name__, | ||||
|                          self.server_name, | ||||
|                          self.server_port)) | ||||
| 
 | ||||
|             def get_request (self): | ||||
|                 # override this to wrap socket with SSL | ||||
|                 sock, addr = self.socket.accept() | ||||
|  | @ -421,8 +542,8 @@ def log_message(self, format, *args): | |||
|                 # we override this to suppress logging unless "verbose" | ||||
| 
 | ||||
|                 if test_support.verbose: | ||||
|                     sys.stdout.write(" server (%s, %d, %s):\n   [%s] %s\n" % | ||||
|                                      (self.server.server_name, | ||||
|                     sys.stdout.write(" server (%s:%d %s):\n   [%s] %s\n" % | ||||
|                                      (self.server.server_address, | ||||
|                                       self.server.server_port, | ||||
|                                       self.request.cipher(), | ||||
|                                       self.log_date_time_string(), | ||||
|  | @ -440,9 +561,7 @@ def __init__(self, certfile): | |||
|             self.setDaemon(True) | ||||
| 
 | ||||
|         def __str__(self): | ||||
|             return '<%s %s:%d>' % (self.__class__.__name__, | ||||
|                                    self.server.server_name, | ||||
|                                    self.server.server_port) | ||||
|             return "<%s %s>" % (self.__class__.__name__, self.server) | ||||
| 
 | ||||
|         def start (self, flag=None): | ||||
|             self.flag = flag | ||||
|  | @ -487,14 +606,16 @@ def badCertTest (certfile): | |||
| 
 | ||||
|     def serverParamsTest (certfile, protocol, certreqs, cacertsfile, | ||||
|                           client_certfile, client_protocol=None, indata="FOO\n", | ||||
|                           chatty=True, connectionchatty=False): | ||||
|                           chatty=True, connectionchatty=False, | ||||
|                           wrap_accepting_socket=False): | ||||
| 
 | ||||
|         server = ThreadedEchoServer(certfile, | ||||
|                                     certreqs=certreqs, | ||||
|                                     ssl_version=protocol, | ||||
|                                     cacerts=cacertsfile, | ||||
|                                     chatty=chatty, | ||||
|                                     connectionchatty=connectionchatty) | ||||
|                                     connectionchatty=connectionchatty, | ||||
|                                     wrap_accepting_socket=wrap_accepting_socket) | ||||
|         flag = threading.Event() | ||||
|         server.start(flag) | ||||
|         # wait for it to start | ||||
|  | @ -572,7 +693,7 @@ def tryProtocolCombo (server_protocol, | |||
|                        ssl.get_protocol_name(server_protocol))) | ||||
| 
 | ||||
| 
 | ||||
|     class ConnectedTests(unittest.TestCase): | ||||
|     class ThreadedTests(unittest.TestCase): | ||||
| 
 | ||||
|         def testRudeShutdown(self): | ||||
| 
 | ||||
|  | @ -600,7 +721,7 @@ def connector(): | |||
|                 listener_gone.wait() | ||||
|                 try: | ||||
|                     ssl_sock = ssl.wrap_socket(s) | ||||
|                 except socket.sslerror: | ||||
|                 except IOError: | ||||
|                     pass | ||||
|                 else: | ||||
|                     raise test_support.TestFailed( | ||||
|  | @ -680,6 +801,9 @@ def testNULLcert(self): | |||
|         def testMalformedCert(self): | ||||
|             badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, | ||||
|                                      "badcert.pem")) | ||||
|         def testWrongCert(self): | ||||
|             badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, | ||||
|                                      "wrongcert.pem")) | ||||
|         def testMalformedKey(self): | ||||
|             badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, | ||||
|                                      "badkey.pem")) | ||||
|  | @ -796,9 +920,9 @@ def testSTARTTLS (self): | |||
|                 server.stop() | ||||
|                 server.join() | ||||
| 
 | ||||
|         def testAsyncore(self): | ||||
|         def testSocketServer(self): | ||||
| 
 | ||||
|             server = AsyncoreHTTPSServer(CERTFILE) | ||||
|             server = SocketServerHTTPSServer(CERTFILE) | ||||
|             flag = threading.Event() | ||||
|             server.start(flag) | ||||
|             # wait for it to start | ||||
|  | @ -810,8 +934,8 @@ def testAsyncore(self): | |||
|                 d1 = open(CERTFILE, 'rb').read() | ||||
|                 d2 = '' | ||||
|                 # now fetch the same data from the HTTPS server | ||||
|                 url = 'https://%s:%d/%s' % ( | ||||
|                     HOST, server.port, os.path.split(CERTFILE)[1]) | ||||
|                 url = 'https://127.0.0.1:%d/%s' % ( | ||||
|                     server.port, os.path.split(CERTFILE)[1]) | ||||
|                 f = urllib.urlopen(url) | ||||
|                 dlen = f.info().getheader("content-length") | ||||
|                 if dlen and (int(dlen) > 0): | ||||
|  | @ -834,6 +958,58 @@ def testAsyncore(self): | |||
|                 server.stop() | ||||
|                 server.join() | ||||
| 
 | ||||
|         def testWrappedAccept (self): | ||||
| 
 | ||||
|             if test_support.verbose: | ||||
|                 sys.stdout.write("\n") | ||||
|             serverParamsTest(CERTFILE, ssl.PROTOCOL_SSLv23, ssl.CERT_REQUIRED, | ||||
|                              CERTFILE, CERTFILE, ssl.PROTOCOL_SSLv23, | ||||
|                              chatty=True, connectionchatty=True, | ||||
|                              wrap_accepting_socket=True) | ||||
| 
 | ||||
| 
 | ||||
|         def testAsyncoreServer (self): | ||||
| 
 | ||||
|             indata = "TEST MESSAGE of mixed case\n" | ||||
| 
 | ||||
|             if test_support.verbose: | ||||
|                 sys.stdout.write("\n") | ||||
|             server = AsyncoreEchoServer(CERTFILE) | ||||
|             flag = threading.Event() | ||||
|             server.start(flag) | ||||
|             # wait for it to start | ||||
|             flag.wait() | ||||
|             # try to connect | ||||
|             try: | ||||
|                 try: | ||||
|                     s = ssl.wrap_socket(socket.socket()) | ||||
|                     s.connect(('127.0.0.1', server.port)) | ||||
|                 except ssl.SSLError, x: | ||||
|                     raise test_support.TestFailed("Unexpected SSL error:  " + str(x)) | ||||
|                 except Exception, x: | ||||
|                     raise test_support.TestFailed("Unexpected exception:  " + str(x)) | ||||
|                 else: | ||||
|                     if test_support.verbose: | ||||
|                         sys.stdout.write( | ||||
|                             " client:  sending %s...\n" % (repr(indata))) | ||||
|                     s.write(indata) | ||||
|                     outdata = s.read() | ||||
|                     if test_support.verbose: | ||||
|                         sys.stdout.write(" client:  read %s\n" % repr(outdata)) | ||||
|                     if outdata != indata.lower(): | ||||
|                         raise test_support.TestFailed( | ||||
|                             "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" | ||||
|                             % (outdata[:min(len(outdata),20)], len(outdata), | ||||
|                                indata[:min(len(indata),20)].lower(), len(indata))) | ||||
|                     s.write("over\n") | ||||
|                     if test_support.verbose: | ||||
|                         sys.stdout.write(" client:  closing connection.\n") | ||||
|                     s.close() | ||||
|             finally: | ||||
|                 server.stop() | ||||
|                 # wait for server thread to end | ||||
|                 server.join() | ||||
| 
 | ||||
| 
 | ||||
| def test_main(verbose=False): | ||||
|     if skip_expected: | ||||
|  | @ -850,15 +1026,19 @@ def test_main(verbose=False): | |||
|         not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT)): | ||||
|         raise test_support.TestFailed("Can't read certificate files!") | ||||
| 
 | ||||
|     TESTPORT = test_support.find_unused_port() | ||||
|     if not TESTPORT: | ||||
|         raise test_support.TestFailed("Can't find open port to test servers on!") | ||||
| 
 | ||||
|     tests = [BasicTests] | ||||
| 
 | ||||
|     if test_support.is_resource_enabled('network'): | ||||
|         tests.append(NetworkTests) | ||||
|         tests.append(NetworkedTests) | ||||
| 
 | ||||
|     if _have_threads: | ||||
|         thread_info = test_support.threading_setup() | ||||
|         if thread_info and test_support.is_resource_enabled('network'): | ||||
|             tests.append(ConnectedTests) | ||||
|             tests.append(ThreadedTests) | ||||
| 
 | ||||
|     test_support.run_unittest(*tests) | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										32
									
								
								Lib/test/wrongcert.pem
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								Lib/test/wrongcert.pem
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,32 @@ | |||
| -----BEGIN RSA PRIVATE KEY----- | ||||
| MIICXAIBAAKBgQC89ZNxjTgWgq7Z1g0tJ65w+k7lNAj5IgjLb155UkUrz0XsHDnH | ||||
| FlbsVUg2Xtk6+bo2UEYIzN7cIm5ImpmyW/2z0J1IDVDlvR2xJ659xrE0v5c2cB6T | ||||
| f9lnNTwpSoeK24Nd7Jwq4j9vk95fLrdqsBq0/KVlsCXeixS/CaqqduXfvwIDAQAB | ||||
| AoGAQFko4uyCgzfxr4Ezb4Mp5pN3Npqny5+Jey3r8EjSAX9Ogn+CNYgoBcdtFgbq | ||||
| 1yif/0sK7ohGBJU9FUCAwrqNBI9ZHB6rcy7dx+gULOmRBGckln1o5S1+smVdmOsW | ||||
| 7zUVLBVByKuNWqTYFlzfVd6s4iiXtAE2iHn3GCyYdlICwrECQQDhMQVxHd3EFbzg | ||||
| SFmJBTARlZ2GKA3c1g/h9/XbkEPQ9/RwI3vnjJ2RaSnjlfoLl8TOcf0uOGbOEyFe | ||||
| 19RvCLXjAkEA1s+UE5ziF+YVkW3WolDCQ2kQ5WG9+ccfNebfh6b67B7Ln5iG0Sbg | ||||
| ky9cjsO3jbMJQtlzAQnH1850oRD5Gi51dQJAIbHCDLDZU9Ok1TI+I2BhVuA6F666 | ||||
| lEZ7TeZaJSYq34OaUYUdrwG9OdqwZ9sy9LUav4ESzu2lhEQchCJrKMn23QJAReqs | ||||
| ZLHUeTjfXkVk7dHhWPWSlUZ6AhmIlA/AQ7Payg2/8wM/JkZEJEPvGVykms9iPUrv | ||||
| frADRr+hAGe43IewnQJBAJWKZllPgKuEBPwoEldHNS8nRu61D7HzxEzQ2xnfj+Nk | ||||
| 2fgf1MAzzTRsikfGENhVsVWeqOcijWb6g5gsyCmlRpc= | ||||
| -----END RSA PRIVATE KEY----- | ||||
| -----BEGIN CERTIFICATE----- | ||||
| MIICsDCCAhmgAwIBAgIJAOqYOYFJfEEoMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV | ||||
| BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX | ||||
| aWRnaXRzIFB0eSBMdGQwHhcNMDgwNjI2MTgxNTUyWhcNMDkwNjI2MTgxNTUyWjBF | ||||
| MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50 | ||||
| ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB | ||||
| gQC89ZNxjTgWgq7Z1g0tJ65w+k7lNAj5IgjLb155UkUrz0XsHDnHFlbsVUg2Xtk6 | ||||
| +bo2UEYIzN7cIm5ImpmyW/2z0J1IDVDlvR2xJ659xrE0v5c2cB6Tf9lnNTwpSoeK | ||||
| 24Nd7Jwq4j9vk95fLrdqsBq0/KVlsCXeixS/CaqqduXfvwIDAQABo4GnMIGkMB0G | ||||
| A1UdDgQWBBTctMtI3EO9OjLI0x9Zo2ifkwIiNjB1BgNVHSMEbjBsgBTctMtI3EO9 | ||||
| OjLI0x9Zo2ifkwIiNqFJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUt | ||||
| U3RhdGUxITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJAOqYOYFJ | ||||
| fEEoMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADgYEAQwa7jya/DfhaDn7E | ||||
| usPkpgIX8WCL2B1SqnRTXEZfBPPVq/cUmFGyEVRVATySRuMwi8PXbVcOhXXuocA+ | ||||
| 43W+iIsD9pXapCZhhOerCq18TC1dWK98vLUsoK8PMjB6e5H/O8bqojv0EeC+fyCw | ||||
| eSHj5jpC8iZKjCHBn+mAi4cQ514= | ||||
| -----END CERTIFICATE----- | ||||
							
								
								
									
										197
									
								
								Modules/_ssl.c
									
										
									
									
									
								
							
							
						
						
									
										197
									
								
								Modules/_ssl.c
									
										
									
									
									
								
							|  | @ -2,14 +2,15 @@ | |||
| 
 | ||||
|    SSL support based on patches by Brian E Gallew and Laszlo Kovacs. | ||||
|    Re-worked a bit by Bill Janssen to add server-side support and | ||||
|    certificate decoding. | ||||
|    certificate decoding.  Chris Stawarz contributed some non-blocking | ||||
|    patches. | ||||
| 
 | ||||
|    This module is imported by ssl.py. It should *not* be used | ||||
|    directly. | ||||
| 
 | ||||
|    XXX should partial writes be enabled, SSL_MODE_ENABLE_PARTIAL_WRITE? | ||||
| 
 | ||||
|    XXX what about SSL_MODE_AUTO_RETRY | ||||
|    XXX what about SSL_MODE_AUTO_RETRY? | ||||
| */ | ||||
| 
 | ||||
| #include "Python.h" | ||||
|  | @ -265,8 +266,6 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file, | |||
| 	PySSLObject *self; | ||||
| 	char *errstr = NULL; | ||||
| 	int ret; | ||||
| 	int err; | ||||
| 	int sockstate; | ||||
| 	int verification_mode; | ||||
| 
 | ||||
| 	self = PyObject_New(PySSLObject, &PySSL_Type); /* Create new object */ | ||||
|  | @ -388,57 +387,6 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file, | |||
| 		SSL_set_accept_state(self->ssl); | ||||
| 	PySSL_END_ALLOW_THREADS | ||||
| 
 | ||||
| 	/* Actually negotiate SSL connection */ | ||||
| 	/* XXX If SSL_connect() returns 0, it's also a failure. */ | ||||
| 	sockstate = 0; | ||||
| 	do { | ||||
| 		PySSL_BEGIN_ALLOW_THREADS | ||||
| 		if (socket_type == PY_SSL_CLIENT) | ||||
| 			ret = SSL_connect(self->ssl); | ||||
| 		else | ||||
| 			ret = SSL_accept(self->ssl); | ||||
| 		err = SSL_get_error(self->ssl, ret); | ||||
| 		PySSL_END_ALLOW_THREADS | ||||
| 		if(PyErr_CheckSignals()) { | ||||
| 			goto fail; | ||||
| 		} | ||||
| 		if (err == SSL_ERROR_WANT_READ) { | ||||
| 			sockstate = check_socket_and_wait_for_timeout(Sock, 0); | ||||
| 		} else if (err == SSL_ERROR_WANT_WRITE) { | ||||
| 			sockstate = check_socket_and_wait_for_timeout(Sock, 1); | ||||
| 		} else { | ||||
| 			sockstate = SOCKET_OPERATION_OK; | ||||
| 		} | ||||
| 		if (sockstate == SOCKET_HAS_TIMED_OUT) { | ||||
| 			PyErr_SetString(PySSLErrorObject, | ||||
| 				ERRSTR("The connect operation timed out")); | ||||
| 			goto fail; | ||||
| 		} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { | ||||
| 			PyErr_SetString(PySSLErrorObject, | ||||
| 				ERRSTR("Underlying socket has been closed.")); | ||||
| 			goto fail; | ||||
| 		} else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) { | ||||
| 			PyErr_SetString(PySSLErrorObject, | ||||
| 			  ERRSTR("Underlying socket too large for select().")); | ||||
| 			goto fail; | ||||
| 		} else if (sockstate == SOCKET_IS_NONBLOCKING) { | ||||
| 			break; | ||||
| 		} | ||||
| 	} while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); | ||||
| 	if (ret < 1) { | ||||
| 		PySSL_SetError(self, ret, __FILE__, __LINE__); | ||||
| 		goto fail; | ||||
| 	} | ||||
| 	self->ssl->debug = 1; | ||||
| 
 | ||||
| 	PySSL_BEGIN_ALLOW_THREADS | ||||
| 	if ((self->peer_cert = SSL_get_peer_certificate(self->ssl))) { | ||||
| 		X509_NAME_oneline(X509_get_subject_name(self->peer_cert), | ||||
| 				  self->server, X509_NAME_MAXLEN); | ||||
| 		X509_NAME_oneline(X509_get_issuer_name(self->peer_cert), | ||||
| 				  self->issuer, X509_NAME_MAXLEN); | ||||
| 	} | ||||
| 	PySSL_END_ALLOW_THREADS | ||||
| 	self->Socket = Sock; | ||||
| 	Py_INCREF(self->Socket); | ||||
| 	return self; | ||||
|  | @ -488,6 +436,65 @@ PyDoc_STRVAR(ssl_doc, | |||
| 
 | ||||
| /* SSL object methods */ | ||||
| 
 | ||||
| static PyObject *PySSL_SSLdo_handshake(PySSLObject *self) | ||||
| { | ||||
| 	int ret; | ||||
| 	int err; | ||||
| 	int sockstate; | ||||
| 
 | ||||
| 	/* Actually negotiate SSL connection */ | ||||
| 	/* XXX If SSL_do_handshake() returns 0, it's also a failure. */ | ||||
| 	sockstate = 0; | ||||
| 	do { | ||||
| 		PySSL_BEGIN_ALLOW_THREADS | ||||
| 		ret = SSL_do_handshake(self->ssl); | ||||
| 		err = SSL_get_error(self->ssl, ret); | ||||
| 		PySSL_END_ALLOW_THREADS | ||||
| 		if(PyErr_CheckSignals()) { | ||||
| 			return NULL; | ||||
| 		} | ||||
| 		if (err == SSL_ERROR_WANT_READ) { | ||||
| 			sockstate = check_socket_and_wait_for_timeout(self->Socket, 0); | ||||
| 		} else if (err == SSL_ERROR_WANT_WRITE) { | ||||
| 			sockstate = check_socket_and_wait_for_timeout(self->Socket, 1); | ||||
| 		} else { | ||||
| 			sockstate = SOCKET_OPERATION_OK; | ||||
| 		} | ||||
| 		if (sockstate == SOCKET_HAS_TIMED_OUT) { | ||||
| 			PyErr_SetString(PySSLErrorObject, | ||||
| 				ERRSTR("The handshake operation timed out")); | ||||
| 			return NULL; | ||||
| 		} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { | ||||
| 			PyErr_SetString(PySSLErrorObject, | ||||
| 				ERRSTR("Underlying socket has been closed.")); | ||||
| 			return NULL; | ||||
| 		} else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) { | ||||
| 			PyErr_SetString(PySSLErrorObject, | ||||
| 			  ERRSTR("Underlying socket too large for select().")); | ||||
| 			return NULL; | ||||
| 		} else if (sockstate == SOCKET_IS_NONBLOCKING) { | ||||
| 			break; | ||||
| 		} | ||||
| 	} while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); | ||||
| 	if (ret < 1) | ||||
| 		return PySSL_SetError(self, ret, __FILE__, __LINE__); | ||||
| 	self->ssl->debug = 1; | ||||
| 
 | ||||
| 	if (self->peer_cert) | ||||
| 		X509_free (self->peer_cert); | ||||
| 	PySSL_BEGIN_ALLOW_THREADS | ||||
| 	if ((self->peer_cert = SSL_get_peer_certificate(self->ssl))) { | ||||
| 		X509_NAME_oneline(X509_get_subject_name(self->peer_cert), | ||||
| 				  self->server, X509_NAME_MAXLEN); | ||||
| 		X509_NAME_oneline(X509_get_issuer_name(self->peer_cert), | ||||
| 				  self->issuer, X509_NAME_MAXLEN); | ||||
| 	} | ||||
| 	PySSL_END_ALLOW_THREADS | ||||
| 
 | ||||
| 	Py_INCREF(Py_None); | ||||
| 	return Py_None; | ||||
| } | ||||
| 
 | ||||
| static PyObject * | ||||
| PySSL_server(PySSLObject *self) | ||||
| { | ||||
|  | @ -1127,7 +1134,9 @@ check_socket_and_wait_for_timeout(PySocketSockObject *s, int writing) | |||
| 		rc = select(s->sock_fd+1, &fds, NULL, NULL, &tv); | ||||
| 	PySSL_END_ALLOW_THREADS | ||||
| 
 | ||||
| #ifdef HAVE_POLL | ||||
| normal_return: | ||||
| #endif | ||||
| 	/* Return SOCKET_TIMED_OUT on timeout, SOCKET_OPERATION_OK otherwise
 | ||||
| 	   (when we are able to write or when there's something to read) */ | ||||
| 	return rc == 0 ? SOCKET_HAS_TIMED_OUT : SOCKET_OPERATION_OK; | ||||
|  | @ -1140,10 +1149,16 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args) | |||
| 	int count; | ||||
| 	int sockstate; | ||||
| 	int err; | ||||
|         int nonblocking; | ||||
| 
 | ||||
| 	if (!PyArg_ParseTuple(args, "s#:write", &data, &count)) | ||||
| 		return NULL; | ||||
| 
 | ||||
|         /* just in case the blocking state of the socket has been changed */ | ||||
| 	nonblocking = (self->Socket->sock_timeout >= 0.0); | ||||
|         BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); | ||||
|         BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); | ||||
| 
 | ||||
| 	sockstate = check_socket_and_wait_for_timeout(self->Socket, 1); | ||||
| 	if (sockstate == SOCKET_HAS_TIMED_OUT) { | ||||
| 		PyErr_SetString(PySSLErrorObject, | ||||
|  | @ -1200,6 +1215,25 @@ PyDoc_STRVAR(PySSL_SSLwrite_doc, | |||
| Writes the string s into the SSL object.  Returns the number\n\ | ||||
| of bytes written."); | ||||
| 
 | ||||
| static PyObject *PySSL_SSLpending(PySSLObject *self) | ||||
| { | ||||
| 	int count = 0; | ||||
| 
 | ||||
| 	PySSL_BEGIN_ALLOW_THREADS | ||||
| 	count = SSL_pending(self->ssl); | ||||
| 	PySSL_END_ALLOW_THREADS | ||||
| 	if (count < 0) | ||||
| 		return PySSL_SetError(self, count, __FILE__, __LINE__); | ||||
| 	else | ||||
| 		return PyInt_FromLong(count); | ||||
| } | ||||
| 
 | ||||
| PyDoc_STRVAR(PySSL_SSLpending_doc, | ||||
| "pending() -> count\n\
 | ||||
| \n\ | ||||
| Returns the number of already decrypted bytes available for read,\n\ | ||||
| pending on the connection.\n"); | ||||
| 
 | ||||
| static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) | ||||
| { | ||||
| 	PyObject *buf; | ||||
|  | @ -1207,6 +1241,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) | |||
| 	int len = 1024; | ||||
| 	int sockstate; | ||||
| 	int err; | ||||
|         int nonblocking; | ||||
| 
 | ||||
| 	if (!PyArg_ParseTuple(args, "|i:read", &len)) | ||||
| 		return NULL; | ||||
|  | @ -1214,6 +1249,11 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) | |||
| 	if (!(buf = PyString_FromStringAndSize((char *) 0, len))) | ||||
| 		return NULL; | ||||
| 
 | ||||
|         /* just in case the blocking state of the socket has been changed */ | ||||
| 	nonblocking = (self->Socket->sock_timeout >= 0.0); | ||||
|         BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); | ||||
|         BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); | ||||
| 
 | ||||
| 	/* first check if there are bytes ready to be read */ | ||||
| 	PySSL_BEGIN_ALLOW_THREADS | ||||
| 	count = SSL_pending(self->ssl); | ||||
|  | @ -1232,11 +1272,20 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) | |||
| 			Py_DECREF(buf); | ||||
| 			return NULL; | ||||
| 		} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { | ||||
| 			if (SSL_get_shutdown(self->ssl) != | ||||
| 			    SSL_RECEIVED_SHUTDOWN) | ||||
| 			{ | ||||
|                             Py_DECREF(buf); | ||||
|                             PyErr_SetString(PySSLErrorObject, | ||||
|                               "Socket closed without SSL shutdown handshake"); | ||||
| 				return NULL; | ||||
| 			} else { | ||||
| 				/* should contain a zero-length string */ | ||||
| 				_PyString_Resize(&buf, 0); | ||||
| 				return buf; | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	do { | ||||
| 		err = 0; | ||||
| 		PySSL_BEGIN_ALLOW_THREADS | ||||
|  | @ -1285,16 +1334,54 @@ PyDoc_STRVAR(PySSL_SSLread_doc, | |||
| \n\ | ||||
| Read up to len bytes from the SSL socket."); | ||||
| 
 | ||||
| static PyObject *PySSL_SSLshutdown(PySSLObject *self) | ||||
| { | ||||
| 	int err; | ||||
| 
 | ||||
| 	/* Guard against closed socket */ | ||||
| 	if (self->Socket->sock_fd < 0) { | ||||
| 		PyErr_SetString(PySSLErrorObject, | ||||
| 				"Underlying socket has been closed."); | ||||
| 		return NULL; | ||||
| 	} | ||||
| 
 | ||||
| 	PySSL_BEGIN_ALLOW_THREADS | ||||
| 	err = SSL_shutdown(self->ssl); | ||||
| 	if (err == 0) { | ||||
| 		/* we need to call it again to finish the shutdown */ | ||||
| 		err = SSL_shutdown(self->ssl); | ||||
| 	} | ||||
| 	PySSL_END_ALLOW_THREADS | ||||
| 
 | ||||
| 	if (err < 0) | ||||
| 		return PySSL_SetError(self, err, __FILE__, __LINE__); | ||||
| 	else { | ||||
| 		Py_INCREF(self->Socket); | ||||
| 		return (PyObject *) (self->Socket); | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| PyDoc_STRVAR(PySSL_SSLshutdown_doc, | ||||
| "shutdown(s) -> socket\n\
 | ||||
| \n\ | ||||
| Does the SSL shutdown handshake with the remote end, and returns\n\ | ||||
| the underlying socket object."); | ||||
| 
 | ||||
| static PyMethodDef PySSLMethods[] = { | ||||
| 	{"do_handshake", (PyCFunction)PySSL_SSLdo_handshake, METH_NOARGS}, | ||||
| 	{"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS, | ||||
| 	 PySSL_SSLwrite_doc}, | ||||
| 	{"read", (PyCFunction)PySSL_SSLread, METH_VARARGS, | ||||
| 	 PySSL_SSLread_doc}, | ||||
| 	{"pending", (PyCFunction)PySSL_SSLpending, METH_NOARGS, | ||||
| 	 PySSL_SSLpending_doc}, | ||||
| 	{"server", (PyCFunction)PySSL_server, METH_NOARGS}, | ||||
| 	{"issuer", (PyCFunction)PySSL_issuer, METH_NOARGS}, | ||||
| 	{"peer_certificate", (PyCFunction)PySSL_peercert, METH_VARARGS, | ||||
| 	 PySSL_peercert_doc}, | ||||
| 	{"cipher", (PyCFunction)PySSL_cipher, METH_NOARGS}, | ||||
| 	{"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS, | ||||
|          PySSL_SSLshutdown_doc}, | ||||
| 	{NULL, NULL} | ||||
| }; | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Bill Janssen
						Bill Janssen