mirror of
				https://github.com/python/cpython.git
				synced 2025-11-04 07:31:38 +00:00 
			
		
		
		
	get SSL support to work again
This commit is contained in:
		
							parent
							
								
									f83088aefe
								
							
						
					
					
						commit
						6e027dba93
					
				
					 3 changed files with 536 additions and 570 deletions
				
			
		
							
								
								
									
										461
									
								
								Lib/ssl.py
									
										
									
									
									
								
							
							
						
						
									
										461
									
								
								Lib/ssl.py
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -1,8 +1,6 @@
 | 
			
		|||
# Wrapper module for _ssl, providing some additional facilities
 | 
			
		||||
# implemented in Python.  Written by Bill Janssen.
 | 
			
		||||
 | 
			
		||||
raise ImportError("ssl.py is temporarily out of order")
 | 
			
		||||
 | 
			
		||||
"""\
 | 
			
		||||
This module provides some more Pythonic support for SSL.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -76,9 +74,11 @@
 | 
			
		|||
     SSL_ERROR_EOF, \
 | 
			
		||||
     SSL_ERROR_INVALID_ERROR_CODE
 | 
			
		||||
 | 
			
		||||
from socket import socket
 | 
			
		||||
from socket import socket, AF_INET, SOCK_STREAM, error
 | 
			
		||||
from socket import getnameinfo as _getnameinfo
 | 
			
		||||
from socket import error as socket_error
 | 
			
		||||
import base64        # for DER-to-PEM translation
 | 
			
		||||
_can_dup_socket = hasattr(socket, "dup")
 | 
			
		||||
 | 
			
		||||
class SSLSocket (socket):
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -86,10 +86,38 @@ class SSLSocket (socket):
 | 
			
		|||
    the underlying OS socket in an SSL context when necessary, and
 | 
			
		||||
    provides read and write methods over that channel."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, sock, keyfile=None, certfile=None,
 | 
			
		||||
    def __init__(self, sock=None, keyfile=None, certfile=None,
 | 
			
		||||
                 server_side=False, cert_reqs=CERT_NONE,
 | 
			
		||||
                 ssl_version=PROTOCOL_SSLv23, ca_certs=None):
 | 
			
		||||
        socket.__init__(self, _sock=sock._sock)
 | 
			
		||||
                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
 | 
			
		||||
                 do_handshake_on_connect=True,
 | 
			
		||||
                 family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
 | 
			
		||||
                 suppress_ragged_eofs=True):
 | 
			
		||||
 | 
			
		||||
        self._base = None
 | 
			
		||||
 | 
			
		||||
        if sock is not None:
 | 
			
		||||
            # copied this code from socket.accept()
 | 
			
		||||
            fd = sock.fileno()
 | 
			
		||||
            nfd = fd
 | 
			
		||||
            if _can_dup_socket:
 | 
			
		||||
                nfd = os.dup(fd)
 | 
			
		||||
            try:
 | 
			
		||||
                wrapper = socket.__init__(self, family=sock.family, type=sock.type, proto=sock.proto, fileno=nfd)
 | 
			
		||||
            except:
 | 
			
		||||
                if nfd != fd:
 | 
			
		||||
                    os.close(nfd)
 | 
			
		||||
            else:
 | 
			
		||||
                if fd != nfd:
 | 
			
		||||
                    sock.close()
 | 
			
		||||
                    sock = None
 | 
			
		||||
 | 
			
		||||
        elif fileno is not None:
 | 
			
		||||
            socket.__init__(self, fileno=fileno)
 | 
			
		||||
        else:
 | 
			
		||||
            socket.__init__(self, family=family, type=type, proto=proto)
 | 
			
		||||
 | 
			
		||||
        self._closed = False
 | 
			
		||||
 | 
			
		||||
        if certfile and not keyfile:
 | 
			
		||||
            keyfile = certfile
 | 
			
		||||
        # see if it's connected
 | 
			
		||||
| 
						 | 
				
			
			@ -100,27 +128,52 @@ def __init__(self, sock, keyfile=None, certfile=None,
 | 
			
		|||
            self._sslobj = None
 | 
			
		||||
        else:
 | 
			
		||||
            # yes, create the SSL object
 | 
			
		||||
            self._sslobj = _ssl.sslwrap(self._sock, server_side,
 | 
			
		||||
                                        keyfile, certfile,
 | 
			
		||||
                                        cert_reqs, ssl_version, ca_certs)
 | 
			
		||||
            try:
 | 
			
		||||
                self._sslobj = _ssl.sslwrap(self, server_side,
 | 
			
		||||
                                            keyfile, certfile,
 | 
			
		||||
                                            cert_reqs, ssl_version, ca_certs)
 | 
			
		||||
                if do_handshake_on_connect:
 | 
			
		||||
                    self.do_handshake()
 | 
			
		||||
            except socket_error as x:
 | 
			
		||||
                self.close()
 | 
			
		||||
                raise x
 | 
			
		||||
 | 
			
		||||
        self._base = sock
 | 
			
		||||
        self.keyfile = keyfile
 | 
			
		||||
        self.certfile = certfile
 | 
			
		||||
        self.cert_reqs = cert_reqs
 | 
			
		||||
        self.ssl_version = ssl_version
 | 
			
		||||
        self.ca_certs = ca_certs
 | 
			
		||||
        self.do_handshake_on_connect = do_handshake_on_connect
 | 
			
		||||
        self.suppress_ragged_eofs = suppress_ragged_eofs
 | 
			
		||||
 | 
			
		||||
    def read(self, len=1024):
 | 
			
		||||
    def _checkClosed(self, msg=None):
 | 
			
		||||
        # raise an exception here if you wish to check for spurious closes
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def read(self, len=1024, buffer=None):
 | 
			
		||||
 | 
			
		||||
        """Read up to LEN bytes and return them.
 | 
			
		||||
        Return zero-length string on EOF."""
 | 
			
		||||
 | 
			
		||||
        return self._sslobj.read(len)
 | 
			
		||||
        self._checkClosed()
 | 
			
		||||
        try:
 | 
			
		||||
            if buffer:
 | 
			
		||||
                return self._sslobj.read(buffer, len)
 | 
			
		||||
            else:
 | 
			
		||||
                return self._sslobj.read(len)
 | 
			
		||||
        except SSLError as x:
 | 
			
		||||
            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
 | 
			
		||||
                return b''
 | 
			
		||||
            else:
 | 
			
		||||
                raise
 | 
			
		||||
 | 
			
		||||
    def write(self, data):
 | 
			
		||||
 | 
			
		||||
        """Write DATA to the underlying SSL channel.  Returns
 | 
			
		||||
        number of bytes of DATA actually transmitted."""
 | 
			
		||||
 | 
			
		||||
        self._checkClosed()
 | 
			
		||||
        return self._sslobj.write(data)
 | 
			
		||||
 | 
			
		||||
    def getpeercert(self, binary_form=False):
 | 
			
		||||
| 
						 | 
				
			
			@ -130,26 +183,42 @@ def getpeercert(self, binary_form=False):
 | 
			
		|||
        Return None if no certificate was provided, {} if a
 | 
			
		||||
        certificate was provided, but not validated."""
 | 
			
		||||
 | 
			
		||||
        self._checkClosed()
 | 
			
		||||
        return self._sslobj.peer_certificate(binary_form)
 | 
			
		||||
 | 
			
		||||
    def cipher (self):
 | 
			
		||||
 | 
			
		||||
        self._checkClosed()
 | 
			
		||||
        if not self._sslobj:
 | 
			
		||||
            return None
 | 
			
		||||
        else:
 | 
			
		||||
            return self._sslobj.cipher()
 | 
			
		||||
 | 
			
		||||
    def send (self, data, flags=0):
 | 
			
		||||
 | 
			
		||||
        self._checkClosed()
 | 
			
		||||
        if self._sslobj:
 | 
			
		||||
            if flags != 0:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "non-zero flags not allowed in calls to send() on %s" %
 | 
			
		||||
                    self.__class__)
 | 
			
		||||
            return self._sslobj.write(data)
 | 
			
		||||
            while True:
 | 
			
		||||
                try:
 | 
			
		||||
                    v = self._sslobj.write(data)
 | 
			
		||||
                except SSLError as x:
 | 
			
		||||
                    if x.args[0] == SSL_ERROR_WANT_READ:
 | 
			
		||||
                        return 0
 | 
			
		||||
                    elif x.args[0] == SSL_ERROR_WANT_WRITE:
 | 
			
		||||
                        return 0
 | 
			
		||||
                    else:
 | 
			
		||||
                        raise
 | 
			
		||||
                else:
 | 
			
		||||
                    return v
 | 
			
		||||
        else:
 | 
			
		||||
            return socket.send(self, data, flags)
 | 
			
		||||
 | 
			
		||||
    def send_to (self, data, addr, flags=0):
 | 
			
		||||
        self._checkClosed()
 | 
			
		||||
        if self._sslobj:
 | 
			
		||||
            raise ValueError("send_to not allowed on instances of %s" %
 | 
			
		||||
                             self.__class__)
 | 
			
		||||
| 
						 | 
				
			
			@ -157,39 +226,95 @@ def send_to (self, data, addr, flags=0):
 | 
			
		|||
            return socket.send_to(self, data, addr, flags)
 | 
			
		||||
 | 
			
		||||
    def sendall (self, data, flags=0):
 | 
			
		||||
        self._checkClosed()
 | 
			
		||||
        if self._sslobj:
 | 
			
		||||
            if flags != 0:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "non-zero flags not allowed in calls to sendall() on %s" %
 | 
			
		||||
                    self.__class__)
 | 
			
		||||
            return self._sslobj.write(data)
 | 
			
		||||
            amount = len(data)
 | 
			
		||||
            count = 0
 | 
			
		||||
            while (count < amount):
 | 
			
		||||
                v = self.send(data[count:])
 | 
			
		||||
                count += v
 | 
			
		||||
            return amount
 | 
			
		||||
        else:
 | 
			
		||||
            return socket.sendall(self, data, flags)
 | 
			
		||||
 | 
			
		||||
    def recv (self, buflen=1024, flags=0):
 | 
			
		||||
        self._checkClosed()
 | 
			
		||||
        if self._sslobj:
 | 
			
		||||
            if flags != 0:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "non-zero flags not allowed in calls to sendall() on %s" %
 | 
			
		||||
                    "non-zero flags not allowed in calls to recv_into() on %s" %
 | 
			
		||||
                    self.__class__)
 | 
			
		||||
            return self._sslobj.read(data, buflen)
 | 
			
		||||
            while True:
 | 
			
		||||
                try:
 | 
			
		||||
                    return self.read(buflen)
 | 
			
		||||
                except SSLError as x:
 | 
			
		||||
                    if x.args[0] == SSL_ERROR_WANT_READ:
 | 
			
		||||
                        continue
 | 
			
		||||
                    else:
 | 
			
		||||
                        raise x
 | 
			
		||||
        else:
 | 
			
		||||
            return socket.recv(self, buflen, flags)
 | 
			
		||||
 | 
			
		||||
    def recv_into (self, buffer, nbytes=None, flags=0):
 | 
			
		||||
        self._checkClosed()
 | 
			
		||||
        if buffer and (nbytes is None):
 | 
			
		||||
            nbytes = len(buffer)
 | 
			
		||||
        elif nbytes is None:
 | 
			
		||||
            nbytes = 1024
 | 
			
		||||
        if self._sslobj:
 | 
			
		||||
            if flags != 0:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "non-zero flags not allowed in calls to recv_into() on %s" %
 | 
			
		||||
                    self.__class__)
 | 
			
		||||
            while True:
 | 
			
		||||
                try:
 | 
			
		||||
                    v = self.read(nbytes, buffer)
 | 
			
		||||
                    sys.stdout.flush()
 | 
			
		||||
                    return v
 | 
			
		||||
                except SSLError as x:
 | 
			
		||||
                    if x.args[0] == SSL_ERROR_WANT_READ:
 | 
			
		||||
                        continue
 | 
			
		||||
                    else:
 | 
			
		||||
                        raise x
 | 
			
		||||
        else:
 | 
			
		||||
            return socket.recv_into(self, buffer, nbytes, flags)
 | 
			
		||||
 | 
			
		||||
    def recv_from (self, addr, buflen=1024, flags=0):
 | 
			
		||||
        self._checkClosed()
 | 
			
		||||
        if self._sslobj:
 | 
			
		||||
            raise ValueError("recv_from not allowed on instances of %s" %
 | 
			
		||||
                             self.__class__)
 | 
			
		||||
        else:
 | 
			
		||||
            return socket.recv_from(self, addr, buflen, flags)
 | 
			
		||||
 | 
			
		||||
    def shutdown(self, how):
 | 
			
		||||
    def pending (self):
 | 
			
		||||
        self._checkClosed()
 | 
			
		||||
        if self._sslobj:
 | 
			
		||||
            return self._sslobj.pending()
 | 
			
		||||
        else:
 | 
			
		||||
            return 0
 | 
			
		||||
 | 
			
		||||
    def shutdown (self, how):
 | 
			
		||||
        self._checkClosed()
 | 
			
		||||
        self._sslobj = None
 | 
			
		||||
        socket.shutdown(self, how)
 | 
			
		||||
 | 
			
		||||
    def close(self):
 | 
			
		||||
    def _real_close (self):
 | 
			
		||||
        self._sslobj = None
 | 
			
		||||
        socket.close(self)
 | 
			
		||||
        # self._closed = True
 | 
			
		||||
        if self._base:
 | 
			
		||||
            self._base.close()
 | 
			
		||||
        socket._real_close(self)
 | 
			
		||||
 | 
			
		||||
    def do_handshake (self):
 | 
			
		||||
 | 
			
		||||
        """Perform a TLS/SSL handshake."""
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            self._sslobj.do_handshake()
 | 
			
		||||
        except:
 | 
			
		||||
            self._sslobj = None
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
    def connect(self, addr):
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -201,9 +326,11 @@ def connect(self, addr):
 | 
			
		|||
        if self._sslobj:
 | 
			
		||||
            raise ValueError("attempt to connect already-connected SSLSocket!")
 | 
			
		||||
        socket.connect(self, addr)
 | 
			
		||||
        self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
 | 
			
		||||
        self._sslobj = _ssl.sslwrap(self, False, self.keyfile, self.certfile,
 | 
			
		||||
                                    self.cert_reqs, self.ssl_version,
 | 
			
		||||
                                    self.ca_certs)
 | 
			
		||||
        if self.do_handshake_on_connect:
 | 
			
		||||
            self.do_handshake()
 | 
			
		||||
 | 
			
		||||
    def accept(self):
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -212,260 +339,24 @@ def accept(self):
 | 
			
		|||
        SSL channel, and the address of the remote client."""
 | 
			
		||||
 | 
			
		||||
        newsock, addr = socket.accept(self)
 | 
			
		||||
        return (SSLSocket(newsock, True, self.keyfile, self.certfile,
 | 
			
		||||
                          self.cert_reqs, self.ssl_version,
 | 
			
		||||
                          self.ca_certs), addr)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def makefile(self, mode='r', bufsize=-1):
 | 
			
		||||
 | 
			
		||||
        """Ouch.  Need to make and return a file-like object that
 | 
			
		||||
        works with the SSL connection."""
 | 
			
		||||
 | 
			
		||||
        if self._sslobj:
 | 
			
		||||
            return SSLFileStream(self._sslobj, mode, bufsize)
 | 
			
		||||
        else:
 | 
			
		||||
            return socket.makefile(self, mode, bufsize)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SSLFileStream:
 | 
			
		||||
 | 
			
		||||
    """A class to simulate a file stream on top of a socket.
 | 
			
		||||
    Most of this is just lifted from the socket module, and
 | 
			
		||||
    adjusted to work with an SSL stream instead of a socket."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    default_bufsize = 8192
 | 
			
		||||
    name = "<SSL stream>"
 | 
			
		||||
 | 
			
		||||
    __slots__ = ["mode", "bufsize", "softspace",
 | 
			
		||||
                 # "closed" is a property, see below
 | 
			
		||||
                 "_sslobj", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf",
 | 
			
		||||
                 "_close", "_fileno"]
 | 
			
		||||
 | 
			
		||||
    def __init__(self, sslobj, mode='rb', bufsize=-1, close=False):
 | 
			
		||||
        self._sslobj = sslobj
 | 
			
		||||
        self.mode = mode # Not actually used in this version
 | 
			
		||||
        if bufsize < 0:
 | 
			
		||||
            bufsize = self.default_bufsize
 | 
			
		||||
        self.bufsize = bufsize
 | 
			
		||||
        self.softspace = False
 | 
			
		||||
        if bufsize == 0:
 | 
			
		||||
            self._rbufsize = 1
 | 
			
		||||
        elif bufsize == 1:
 | 
			
		||||
            self._rbufsize = self.default_bufsize
 | 
			
		||||
        else:
 | 
			
		||||
            self._rbufsize = bufsize
 | 
			
		||||
        self._wbufsize = bufsize
 | 
			
		||||
        self._rbuf = "" # A string
 | 
			
		||||
        self._wbuf = [] # A list of strings
 | 
			
		||||
        self._close = close
 | 
			
		||||
        self._fileno = -1
 | 
			
		||||
 | 
			
		||||
    def _getclosed(self):
 | 
			
		||||
        return self._sslobj is None
 | 
			
		||||
    closed = property(_getclosed, doc="True if the file is closed")
 | 
			
		||||
 | 
			
		||||
    def fileno(self):
 | 
			
		||||
        return self._fileno
 | 
			
		||||
 | 
			
		||||
    def close(self):
 | 
			
		||||
        try:
 | 
			
		||||
            if self._sslobj:
 | 
			
		||||
                self.flush()
 | 
			
		||||
        finally:
 | 
			
		||||
            if self._close and self._sslobj:
 | 
			
		||||
                self._sslobj.close()
 | 
			
		||||
            self._sslobj = None
 | 
			
		||||
 | 
			
		||||
    def __del__(self):
 | 
			
		||||
        try:
 | 
			
		||||
            self.close()
 | 
			
		||||
        except:
 | 
			
		||||
            # close() may fail if __init__ didn't complete
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
    def flush(self):
 | 
			
		||||
        if self._wbuf:
 | 
			
		||||
            buffer = "".join(self._wbuf)
 | 
			
		||||
            self._wbuf = []
 | 
			
		||||
            count = 0
 | 
			
		||||
            while (count < len(buffer)):
 | 
			
		||||
                written = self._sslobj.write(buffer)
 | 
			
		||||
                count += written
 | 
			
		||||
                buffer = buffer[written:]
 | 
			
		||||
 | 
			
		||||
    def write(self, data):
 | 
			
		||||
        data = str(data) # XXX Should really reject non-string non-buffers
 | 
			
		||||
        if not data:
 | 
			
		||||
            return
 | 
			
		||||
        self._wbuf.append(data)
 | 
			
		||||
        if (self._wbufsize == 0 or
 | 
			
		||||
            self._wbufsize == 1 and '\n' in data or
 | 
			
		||||
            self._get_wbuf_len() >= self._wbufsize):
 | 
			
		||||
            self.flush()
 | 
			
		||||
 | 
			
		||||
    def writelines(self, list):
 | 
			
		||||
        # XXX We could do better here for very long lists
 | 
			
		||||
        # XXX Should really reject non-string non-buffers
 | 
			
		||||
        self._wbuf.extend(filter(None, map(str, list)))
 | 
			
		||||
        if (self._wbufsize <= 1 or
 | 
			
		||||
            self._get_wbuf_len() >= self._wbufsize):
 | 
			
		||||
            self.flush()
 | 
			
		||||
 | 
			
		||||
    def _get_wbuf_len(self):
 | 
			
		||||
        buf_len = 0
 | 
			
		||||
        for x in self._wbuf:
 | 
			
		||||
            buf_len += len(x)
 | 
			
		||||
        return buf_len
 | 
			
		||||
 | 
			
		||||
    def read(self, size=-1):
 | 
			
		||||
        data = self._rbuf
 | 
			
		||||
        if size < 0:
 | 
			
		||||
            # Read until EOF
 | 
			
		||||
            buffers = []
 | 
			
		||||
            if data:
 | 
			
		||||
                buffers.append(data)
 | 
			
		||||
            self._rbuf = ""
 | 
			
		||||
            if self._rbufsize <= 1:
 | 
			
		||||
                recv_size = self.default_bufsize
 | 
			
		||||
            else:
 | 
			
		||||
                recv_size = self._rbufsize
 | 
			
		||||
            while True:
 | 
			
		||||
                data = self._sslobj.read(recv_size)
 | 
			
		||||
                if not data:
 | 
			
		||||
                    break
 | 
			
		||||
                buffers.append(data)
 | 
			
		||||
            return "".join(buffers)
 | 
			
		||||
        else:
 | 
			
		||||
            # Read until size bytes or EOF seen, whichever comes first
 | 
			
		||||
            buf_len = len(data)
 | 
			
		||||
            if buf_len >= size:
 | 
			
		||||
                self._rbuf = data[size:]
 | 
			
		||||
                return data[:size]
 | 
			
		||||
            buffers = []
 | 
			
		||||
            if data:
 | 
			
		||||
                buffers.append(data)
 | 
			
		||||
            self._rbuf = ""
 | 
			
		||||
            while True:
 | 
			
		||||
                left = size - buf_len
 | 
			
		||||
                recv_size = max(self._rbufsize, left)
 | 
			
		||||
                data = self._sslobj.read(recv_size)
 | 
			
		||||
                if not data:
 | 
			
		||||
                    break
 | 
			
		||||
                buffers.append(data)
 | 
			
		||||
                n = len(data)
 | 
			
		||||
                if n >= left:
 | 
			
		||||
                    self._rbuf = data[left:]
 | 
			
		||||
                    buffers[-1] = data[:left]
 | 
			
		||||
                    break
 | 
			
		||||
                buf_len += n
 | 
			
		||||
            return "".join(buffers)
 | 
			
		||||
 | 
			
		||||
    def readline(self, size=-1):
 | 
			
		||||
        data = self._rbuf
 | 
			
		||||
        if size < 0:
 | 
			
		||||
            # Read until \n or EOF, whichever comes first
 | 
			
		||||
            if self._rbufsize <= 1:
 | 
			
		||||
                # Speed up unbuffered case
 | 
			
		||||
                assert data == ""
 | 
			
		||||
                buffers = []
 | 
			
		||||
                while data != "\n":
 | 
			
		||||
                    data = self._sslobj.read(1)
 | 
			
		||||
                    if not data:
 | 
			
		||||
                        break
 | 
			
		||||
                    buffers.append(data)
 | 
			
		||||
                return "".join(buffers)
 | 
			
		||||
            nl = data.find('\n')
 | 
			
		||||
            if nl >= 0:
 | 
			
		||||
                nl += 1
 | 
			
		||||
                self._rbuf = data[nl:]
 | 
			
		||||
                return data[:nl]
 | 
			
		||||
            buffers = []
 | 
			
		||||
            if data:
 | 
			
		||||
                buffers.append(data)
 | 
			
		||||
            self._rbuf = ""
 | 
			
		||||
            while True:
 | 
			
		||||
                data = self._sslobj.read(self._rbufsize)
 | 
			
		||||
                if not data:
 | 
			
		||||
                    break
 | 
			
		||||
                buffers.append(data)
 | 
			
		||||
                nl = data.find('\n')
 | 
			
		||||
                if nl >= 0:
 | 
			
		||||
                    nl += 1
 | 
			
		||||
                    self._rbuf = data[nl:]
 | 
			
		||||
                    buffers[-1] = data[:nl]
 | 
			
		||||
                    break
 | 
			
		||||
            return "".join(buffers)
 | 
			
		||||
        else:
 | 
			
		||||
            # Read until size bytes or \n or EOF seen, whichever comes first
 | 
			
		||||
            nl = data.find('\n', 0, size)
 | 
			
		||||
            if nl >= 0:
 | 
			
		||||
                nl += 1
 | 
			
		||||
                self._rbuf = data[nl:]
 | 
			
		||||
                return data[:nl]
 | 
			
		||||
            buf_len = len(data)
 | 
			
		||||
            if buf_len >= size:
 | 
			
		||||
                self._rbuf = data[size:]
 | 
			
		||||
                return data[:size]
 | 
			
		||||
            buffers = []
 | 
			
		||||
            if data:
 | 
			
		||||
                buffers.append(data)
 | 
			
		||||
            self._rbuf = ""
 | 
			
		||||
            while True:
 | 
			
		||||
                data = self._sslobj.read(self._rbufsize)
 | 
			
		||||
                if not data:
 | 
			
		||||
                    break
 | 
			
		||||
                buffers.append(data)
 | 
			
		||||
                left = size - buf_len
 | 
			
		||||
                nl = data.find('\n', 0, left)
 | 
			
		||||
                if nl >= 0:
 | 
			
		||||
                    nl += 1
 | 
			
		||||
                    self._rbuf = data[nl:]
 | 
			
		||||
                    buffers[-1] = data[:nl]
 | 
			
		||||
                    break
 | 
			
		||||
                n = len(data)
 | 
			
		||||
                if n >= left:
 | 
			
		||||
                    self._rbuf = data[left:]
 | 
			
		||||
                    buffers[-1] = data[:left]
 | 
			
		||||
                    break
 | 
			
		||||
                buf_len += n
 | 
			
		||||
            return "".join(buffers)
 | 
			
		||||
 | 
			
		||||
    def readlines(self, sizehint=0):
 | 
			
		||||
        total = 0
 | 
			
		||||
        list = []
 | 
			
		||||
        while True:
 | 
			
		||||
            line = self.readline()
 | 
			
		||||
            if not line:
 | 
			
		||||
                break
 | 
			
		||||
            list.append(line)
 | 
			
		||||
            total += len(line)
 | 
			
		||||
            if sizehint and total >= sizehint:
 | 
			
		||||
                break
 | 
			
		||||
        return list
 | 
			
		||||
 | 
			
		||||
    # Iterator protocols
 | 
			
		||||
 | 
			
		||||
    def __iter__(self):
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def next(self):
 | 
			
		||||
        line = self.readline()
 | 
			
		||||
        if not line:
 | 
			
		||||
            raise StopIteration
 | 
			
		||||
        return line
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        return (SSLSocket(sock=newsock,
 | 
			
		||||
                          keyfile=self.keyfile, certfile=self.certfile,
 | 
			
		||||
                          server_side=True,
 | 
			
		||||
                          cert_reqs=self.cert_reqs, ssl_version=self.ssl_version,
 | 
			
		||||
                          ca_certs=self.ca_certs,
 | 
			
		||||
                          do_handshake_on_connect=self.do_handshake_on_connect),
 | 
			
		||||
                addr)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def wrap_socket(sock, keyfile=None, certfile=None,
 | 
			
		||||
                server_side=False, cert_reqs=CERT_NONE,
 | 
			
		||||
                ssl_version=PROTOCOL_SSLv23, ca_certs=None):
 | 
			
		||||
                ssl_version=PROTOCOL_SSLv23, ca_certs=None,
 | 
			
		||||
                do_handshake_on_connect=True):
 | 
			
		||||
 | 
			
		||||
    return SSLSocket(sock, keyfile=keyfile, certfile=certfile,
 | 
			
		||||
    return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,
 | 
			
		||||
                     server_side=server_side, cert_reqs=cert_reqs,
 | 
			
		||||
                     ssl_version=ssl_version, ca_certs=ca_certs)
 | 
			
		||||
                     ssl_version=ssl_version, ca_certs=ca_certs,
 | 
			
		||||
                     do_handshake_on_connect=do_handshake_on_connect)
 | 
			
		||||
 | 
			
		||||
# some utility functions
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -486,16 +377,10 @@ def DER_cert_to_PEM_cert(der_cert_bytes):
 | 
			
		|||
    """Takes a certificate in binary DER format and returns the
 | 
			
		||||
    PEM version of it as a string."""
 | 
			
		||||
 | 
			
		||||
    if hasattr(base64, 'standard_b64encode'):
 | 
			
		||||
        # preferred because older API gets line-length wrong
 | 
			
		||||
        f = base64.standard_b64encode(der_cert_bytes)
 | 
			
		||||
        return (PEM_HEADER + '\n' +
 | 
			
		||||
                textwrap.fill(f, 64) +
 | 
			
		||||
                PEM_FOOTER + '\n')
 | 
			
		||||
    else:
 | 
			
		||||
        return (PEM_HEADER + '\n' +
 | 
			
		||||
                base64.encodestring(der_cert_bytes) +
 | 
			
		||||
                PEM_FOOTER + '\n')
 | 
			
		||||
    f = str(base64.standard_b64encode(der_cert_bytes), 'ASCII', 'strict')
 | 
			
		||||
    return (PEM_HEADER + '\n' +
 | 
			
		||||
            textwrap.fill(f, 64) + '\n' +
 | 
			
		||||
            PEM_FOOTER + '\n')
 | 
			
		||||
 | 
			
		||||
def PEM_cert_to_DER_cert(pem_cert_string):
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -509,7 +394,7 @@ def PEM_cert_to_DER_cert(pem_cert_string):
 | 
			
		|||
        raise ValueError("Invalid PEM encoding; must end with %s"
 | 
			
		||||
                         % PEM_FOOTER)
 | 
			
		||||
    d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]
 | 
			
		||||
    return base64.decodestring(d)
 | 
			
		||||
    return base64.decodestring(d.encode('ASCII', 'strict'))
 | 
			
		||||
 | 
			
		||||
def get_server_certificate (addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -541,15 +426,3 @@ def get_protocol_name (protocol_code):
 | 
			
		|||
        return "SSLv3"
 | 
			
		||||
    else:
 | 
			
		||||
        return "<unknown>"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# a replacement for the old socket.ssl function
 | 
			
		||||
 | 
			
		||||
def sslwrap_simple (sock, keyfile=None, certfile=None):
 | 
			
		||||
 | 
			
		||||
    """A replacement for the old socket.ssl function.  Designed
 | 
			
		||||
    for compability with Python 2.5 and earlier.  Will disappear in
 | 
			
		||||
    Python 3.0."""
 | 
			
		||||
 | 
			
		||||
    return _ssl.sslwrap(sock._sock, 0, keyfile, certfile, CERT_NONE,
 | 
			
		||||
                        PROTOCOL_SSLv23, None)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,6 +4,7 @@
 | 
			
		|||
import unittest
 | 
			
		||||
from test import test_support
 | 
			
		||||
import socket
 | 
			
		||||
import select
 | 
			
		||||
import errno
 | 
			
		||||
import subprocess
 | 
			
		||||
import time
 | 
			
		||||
| 
						 | 
				
			
			@ -36,27 +37,6 @@ def handle_error(prefix):
 | 
			
		|||
 | 
			
		||||
class BasicTests(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    def testSSLconnect(self):
 | 
			
		||||
        import os
 | 
			
		||||
        s = ssl.wrap_socket(socket.socket(socket.AF_INET),
 | 
			
		||||
                            cert_reqs=ssl.CERT_NONE)
 | 
			
		||||
        s.connect(("svn.python.org", 443))
 | 
			
		||||
        c = s.getpeercert()
 | 
			
		||||
        if c:
 | 
			
		||||
            raise test_support.TestFailed("Peer cert %s shouldn't be here!")
 | 
			
		||||
        s.close()
 | 
			
		||||
 | 
			
		||||
        # this should fail because we have no verification certs
 | 
			
		||||
        s = ssl.wrap_socket(socket.socket(socket.AF_INET),
 | 
			
		||||
                            cert_reqs=ssl.CERT_REQUIRED)
 | 
			
		||||
        try:
 | 
			
		||||
            s.connect(("svn.python.org", 443))
 | 
			
		||||
        except ssl.SSLError:
 | 
			
		||||
            pass
 | 
			
		||||
        finally:
 | 
			
		||||
            s.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def testCrucialConstants(self):
 | 
			
		||||
        ssl.PROTOCOL_SSLv2
 | 
			
		||||
        ssl.PROTOCOL_SSLv23
 | 
			
		||||
| 
						 | 
				
			
			@ -97,11 +77,31 @@ def testDERtoPEM(self):
 | 
			
		|||
        if (d1 != d2):
 | 
			
		||||
            raise test_support.TestFailed("PEM-to-DER or DER-to-PEM translation failed")
 | 
			
		||||
 | 
			
		||||
class NetworkedTests(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
class NetworkTests(unittest.TestCase):
 | 
			
		||||
    def testFetchServerCert(self):
 | 
			
		||||
 | 
			
		||||
        pem = ssl.get_server_certificate(("svn.python.org", 443))
 | 
			
		||||
        if not pem:
 | 
			
		||||
            raise test_support.TestFailed("No server certificate on svn.python.org:443!")
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE)
 | 
			
		||||
        except ssl.SSLError as x:
 | 
			
		||||
            #should fail
 | 
			
		||||
            if test_support.verbose:
 | 
			
		||||
                sys.stdout.write("%s\n" % x)
 | 
			
		||||
        else:
 | 
			
		||||
            raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem)
 | 
			
		||||
 | 
			
		||||
        pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
 | 
			
		||||
        if not pem:
 | 
			
		||||
            raise test_support.TestFailed("No server certificate on svn.python.org:443!")
 | 
			
		||||
        if test_support.verbose:
 | 
			
		||||
            sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
 | 
			
		||||
 | 
			
		||||
    def testConnect(self):
 | 
			
		||||
        import os
 | 
			
		||||
 | 
			
		||||
        s = ssl.wrap_socket(socket.socket(socket.AF_INET),
 | 
			
		||||
                            cert_reqs=ssl.CERT_NONE)
 | 
			
		||||
        s.connect(("svn.python.org", 443))
 | 
			
		||||
| 
						 | 
				
			
			@ -131,25 +131,29 @@ def testConnect(self):
 | 
			
		|||
        finally:
 | 
			
		||||
            s.close()
 | 
			
		||||
 | 
			
		||||
    def testFetchServerCert(self):
 | 
			
		||||
 | 
			
		||||
        pem = ssl.get_server_certificate(("svn.python.org", 443))
 | 
			
		||||
        if not pem:
 | 
			
		||||
            raise test_support.TestFailed("No server certificate on svn.python.org:443!")
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE)
 | 
			
		||||
        except ssl.SSLError:
 | 
			
		||||
            #should fail
 | 
			
		||||
            pass
 | 
			
		||||
        else:
 | 
			
		||||
            raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem)
 | 
			
		||||
 | 
			
		||||
        pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
 | 
			
		||||
        if not pem:
 | 
			
		||||
            raise test_support.TestFailed("No server certificate on svn.python.org:443!")
 | 
			
		||||
    def testNonBlockingHandshake(self):
 | 
			
		||||
        s = socket.socket(socket.AF_INET)
 | 
			
		||||
        s.connect(("svn.python.org", 443))
 | 
			
		||||
        s.setblocking(False)
 | 
			
		||||
        s = ssl.wrap_socket(s,
 | 
			
		||||
                            cert_reqs=ssl.CERT_NONE,
 | 
			
		||||
                            do_handshake_on_connect=False)
 | 
			
		||||
        count = 0
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                count += 1
 | 
			
		||||
                s.do_handshake()
 | 
			
		||||
                break
 | 
			
		||||
            except ssl.SSLError as err:
 | 
			
		||||
                if err.args[0] == ssl.SSL_ERROR_WANT_READ:
 | 
			
		||||
                    select.select([s], [], [])
 | 
			
		||||
                elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
 | 
			
		||||
                    select.select([], [s], [])
 | 
			
		||||
                else:
 | 
			
		||||
                    raise
 | 
			
		||||
        s.close()
 | 
			
		||||
        if test_support.verbose:
 | 
			
		||||
            sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
 | 
			
		||||
            sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
| 
						 | 
				
			
			@ -168,10 +172,11 @@ class ConnectionHandler(threading.Thread):
 | 
			
		|||
            with and without the SSL wrapper around the socket connection, so
 | 
			
		||||
            that we can test the STARTTLS functionality."""
 | 
			
		||||
 | 
			
		||||
            def __init__(self, server, connsock):
 | 
			
		||||
            def __init__(self, server, connsock, addr):
 | 
			
		||||
                self.server = server
 | 
			
		||||
                self.running = False
 | 
			
		||||
                self.sock = connsock
 | 
			
		||||
                self.addr = addr
 | 
			
		||||
                self.sock.setblocking(1)
 | 
			
		||||
                self.sslconn = None
 | 
			
		||||
                threading.Thread.__init__(self)
 | 
			
		||||
| 
						 | 
				
			
			@ -186,8 +191,7 @@ def wrap_conn (self):
 | 
			
		|||
                                                   cert_reqs=self.server.certreqs)
 | 
			
		||||
                except:
 | 
			
		||||
                    if self.server.chatty:
 | 
			
		||||
                        handle_error("\n server:  bad connection attempt from " +
 | 
			
		||||
                                     str(self.sock.getpeername()) + ":\n")
 | 
			
		||||
                        handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")
 | 
			
		||||
                    if not self.server.expect_bad_connects:
 | 
			
		||||
                        # here, we want to stop the server, because this shouldn't
 | 
			
		||||
                        # happen in the context of our test case
 | 
			
		||||
| 
						 | 
				
			
			@ -195,6 +199,7 @@ def wrap_conn (self):
 | 
			
		|||
                        # normally, we'd just stop here, but for the test
 | 
			
		||||
                        # harness, we want to stop the server
 | 
			
		||||
                        self.server.stop()
 | 
			
		||||
                    self.close()
 | 
			
		||||
                    return False
 | 
			
		||||
 | 
			
		||||
                else:
 | 
			
		||||
| 
						 | 
				
			
			@ -236,19 +241,21 @@ def run (self):
 | 
			
		|||
                while self.running:
 | 
			
		||||
                    try:
 | 
			
		||||
                        msg = self.read()
 | 
			
		||||
                        amsg = (msg and str(msg, 'ASCII', 'strict')) or ''
 | 
			
		||||
                        if not msg:
 | 
			
		||||
                            # eof, so quit this handler
 | 
			
		||||
                            self.running = False
 | 
			
		||||
                            self.close()
 | 
			
		||||
                        elif msg.strip() == 'over':
 | 
			
		||||
                        elif amsg.strip() == 'over':
 | 
			
		||||
                            if test_support.verbose and self.server.connectionchatty:
 | 
			
		||||
                                sys.stdout.write(" server: client closed connection\n")
 | 
			
		||||
                            self.close()
 | 
			
		||||
                            return
 | 
			
		||||
                        elif self.server.starttls_server and msg.strip() == 'STARTTLS':
 | 
			
		||||
                        elif (self.server.starttls_server and
 | 
			
		||||
                              amsg.strip() == 'STARTTLS'):
 | 
			
		||||
                            if test_support.verbose and self.server.connectionchatty:
 | 
			
		||||
                                sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
 | 
			
		||||
                            self.write("OK\n")
 | 
			
		||||
                            self.write("OK\n".encode("ASCII", "strict"))
 | 
			
		||||
                            if not self.wrap_conn():
 | 
			
		||||
                                return
 | 
			
		||||
                        else:
 | 
			
		||||
| 
						 | 
				
			
			@ -257,8 +264,8 @@ def run (self):
 | 
			
		|||
                                ctype = (self.sslconn and "encrypted") or "unencrypted"
 | 
			
		||||
                                sys.stdout.write(" server: read %s (%s), sending back %s (%s)...\n"
 | 
			
		||||
                                                 % (repr(msg), ctype, repr(msg.lower()), ctype))
 | 
			
		||||
                            self.write(msg.lower())
 | 
			
		||||
                    except ssl.SSLError:
 | 
			
		||||
                            self.write(amsg.lower().encode('ASCII', 'strict'))
 | 
			
		||||
                    except socket.error:
 | 
			
		||||
                        if self.server.chatty:
 | 
			
		||||
                            handle_error("Test server failure:\n")
 | 
			
		||||
                        self.close()
 | 
			
		||||
| 
						 | 
				
			
			@ -311,8 +318,8 @@ def run (self):
 | 
			
		|||
                    newconn, connaddr = self.sock.accept()
 | 
			
		||||
                    if test_support.verbose and self.chatty:
 | 
			
		||||
                        sys.stdout.write(' server:  new connection from '
 | 
			
		||||
                                         + str(connaddr) + '\n')
 | 
			
		||||
                    handler = self.ConnectionHandler(self, newconn)
 | 
			
		||||
                                         + repr(connaddr) + '\n')
 | 
			
		||||
                    handler = self.ConnectionHandler(self, newconn, connaddr)
 | 
			
		||||
                    handler.start()
 | 
			
		||||
                except socket.timeout:
 | 
			
		||||
                    pass
 | 
			
		||||
| 
						 | 
				
			
			@ -321,11 +328,10 @@ def run (self):
 | 
			
		|||
                except:
 | 
			
		||||
                    if self.chatty:
 | 
			
		||||
                        handle_error("Test server failure:\n")
 | 
			
		||||
            self.sock.close()
 | 
			
		||||
 | 
			
		||||
        def stop (self):
 | 
			
		||||
            self.active = False
 | 
			
		||||
            self.sock.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    class AsyncoreHTTPSServer(threading.Thread):
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -339,6 +345,12 @@ def __init__(self, server_address, RequestHandlerClass, certfile):
 | 
			
		|||
                self.active = False
 | 
			
		||||
                self.allow_reuse_address = True
 | 
			
		||||
 | 
			
		||||
            def __str__(self):
 | 
			
		||||
                return ('<%s %s:%s>' %
 | 
			
		||||
                        (self.__class__.__name__,
 | 
			
		||||
                         self.server_name,
 | 
			
		||||
                         self.server_port))
 | 
			
		||||
 | 
			
		||||
            def get_request (self):
 | 
			
		||||
                # override this to wrap socket with SSL
 | 
			
		||||
                sock, addr = self.socket.accept()
 | 
			
		||||
| 
						 | 
				
			
			@ -415,8 +427,8 @@ def log_message(self, format, *args):
 | 
			
		|||
                # we override this to suppress logging unless "verbose"
 | 
			
		||||
 | 
			
		||||
                if test_support.verbose:
 | 
			
		||||
                    sys.stdout.write(" server (%s, %d, %s):\n   [%s] %s\n" %
 | 
			
		||||
                                     (self.server.server_name,
 | 
			
		||||
                    sys.stdout.write(" server (%s:%d %s):\n   [%s] %s\n" %
 | 
			
		||||
                                     (self.server.server_address,
 | 
			
		||||
                                      self.server.server_port,
 | 
			
		||||
                                      self.request.cipher(),
 | 
			
		||||
                                      self.log_date_time_string(),
 | 
			
		||||
| 
						 | 
				
			
			@ -433,9 +445,7 @@ def __init__(self, port, certfile):
 | 
			
		|||
            self.setDaemon(True)
 | 
			
		||||
 | 
			
		||||
        def __str__(self):
 | 
			
		||||
            return '<%s %s:%d>' % (self.__class__.__name__,
 | 
			
		||||
                                   self.server.server_name,
 | 
			
		||||
                                   self.server.server_port)
 | 
			
		||||
            return "<%s %s>" % (self.__class__.__name__, self.server)
 | 
			
		||||
 | 
			
		||||
        def start (self, flag=None):
 | 
			
		||||
            self.flag = flag
 | 
			
		||||
| 
						 | 
				
			
			@ -456,7 +466,8 @@ def stop (self):
 | 
			
		|||
    def badCertTest (certfile):
 | 
			
		||||
        server = ThreadedEchoServer(TESTPORT, CERTFILE,
 | 
			
		||||
                                    certreqs=ssl.CERT_REQUIRED,
 | 
			
		||||
                                    cacerts=CERTFILE, chatty=False)
 | 
			
		||||
                                    cacerts=CERTFILE, chatty=False,
 | 
			
		||||
                                    connectionchatty=False)
 | 
			
		||||
        flag = threading.Event()
 | 
			
		||||
        server.start(flag)
 | 
			
		||||
        # wait for it to start
 | 
			
		||||
| 
						 | 
				
			
			@ -470,7 +481,7 @@ def badCertTest (certfile):
 | 
			
		|||
                s.connect(('127.0.0.1', TESTPORT))
 | 
			
		||||
            except ssl.SSLError as x:
 | 
			
		||||
                if test_support.verbose:
 | 
			
		||||
                    sys.stdout.write("\nSSLError is %s\n" % x[1])
 | 
			
		||||
                    sys.stdout.write("\nSSLError is %s\n" % x)
 | 
			
		||||
            else:
 | 
			
		||||
                raise test_support.TestFailed(
 | 
			
		||||
                    "Use of invalid cert should have failed!")
 | 
			
		||||
| 
						 | 
				
			
			@ -479,15 +490,16 @@ def badCertTest (certfile):
 | 
			
		|||
            server.join()
 | 
			
		||||
 | 
			
		||||
    def serverParamsTest (certfile, protocol, certreqs, cacertsfile,
 | 
			
		||||
                          client_certfile, client_protocol=None, indata="FOO\n",
 | 
			
		||||
                          chatty=True, connectionchatty=False):
 | 
			
		||||
                          client_certfile, client_protocol=None,
 | 
			
		||||
                          indata="FOO\n",
 | 
			
		||||
                          chatty=False, connectionchatty=False):
 | 
			
		||||
 | 
			
		||||
        server = ThreadedEchoServer(TESTPORT, certfile,
 | 
			
		||||
                                    certreqs=certreqs,
 | 
			
		||||
                                    ssl_version=protocol,
 | 
			
		||||
                                    cacerts=cacertsfile,
 | 
			
		||||
                                    chatty=chatty,
 | 
			
		||||
                                    connectionchatty=connectionchatty)
 | 
			
		||||
                                    connectionchatty=False)
 | 
			
		||||
        flag = threading.Event()
 | 
			
		||||
        server.start(flag)
 | 
			
		||||
        # wait for it to start
 | 
			
		||||
| 
						 | 
				
			
			@ -496,37 +508,37 @@ def serverParamsTest (certfile, protocol, certreqs, cacertsfile,
 | 
			
		|||
        if client_protocol is None:
 | 
			
		||||
            client_protocol = protocol
 | 
			
		||||
        try:
 | 
			
		||||
            try:
 | 
			
		||||
                s = ssl.wrap_socket(socket.socket(),
 | 
			
		||||
                                    certfile=client_certfile,
 | 
			
		||||
                                    ca_certs=cacertsfile,
 | 
			
		||||
                                    cert_reqs=certreqs,
 | 
			
		||||
                                    ssl_version=client_protocol)
 | 
			
		||||
                s.connect(('127.0.0.1', TESTPORT))
 | 
			
		||||
            except ssl.SSLError as x:
 | 
			
		||||
                raise test_support.TestFailed("Unexpected SSL error:  " + str(x))
 | 
			
		||||
            except Exception as x:
 | 
			
		||||
                raise test_support.TestFailed("Unexpected exception:  " + str(x))
 | 
			
		||||
            else:
 | 
			
		||||
                if connectionchatty:
 | 
			
		||||
                    if test_support.verbose:
 | 
			
		||||
                        sys.stdout.write(
 | 
			
		||||
                            " client:  sending %s...\n" % (repr(indata)))
 | 
			
		||||
                s.write(indata)
 | 
			
		||||
                outdata = s.read()
 | 
			
		||||
                if connectionchatty:
 | 
			
		||||
                    if test_support.verbose:
 | 
			
		||||
                        sys.stdout.write(" client:  read %s\n" % repr(outdata))
 | 
			
		||||
                if outdata != indata.lower():
 | 
			
		||||
                    raise test_support.TestFailed(
 | 
			
		||||
                        "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
 | 
			
		||||
                        % (outdata[:min(len(outdata),20)], len(outdata),
 | 
			
		||||
                           indata[:min(len(indata),20)].lower(), len(indata)))
 | 
			
		||||
                s.write("over\n")
 | 
			
		||||
                if connectionchatty:
 | 
			
		||||
                    if test_support.verbose:
 | 
			
		||||
                        sys.stdout.write(" client:  closing connection.\n")
 | 
			
		||||
                s.close()
 | 
			
		||||
            s = ssl.wrap_socket(socket.socket(),
 | 
			
		||||
                                certfile=client_certfile,
 | 
			
		||||
                                ca_certs=cacertsfile,
 | 
			
		||||
                                cert_reqs=certreqs,
 | 
			
		||||
                                ssl_version=client_protocol)
 | 
			
		||||
            s.connect(('127.0.0.1', TESTPORT))
 | 
			
		||||
        except ssl.SSLError as x:
 | 
			
		||||
            raise test_support.TestFailed("Unexpected SSL error:  " + str(x))
 | 
			
		||||
        except Exception as x:
 | 
			
		||||
            raise test_support.TestFailed("Unexpected exception:  " + str(x))
 | 
			
		||||
        else:
 | 
			
		||||
            if connectionchatty:
 | 
			
		||||
                if test_support.verbose:
 | 
			
		||||
                    sys.stdout.write(
 | 
			
		||||
                        " client:  sending %s...\n" % (repr(indata)))
 | 
			
		||||
            s.write(indata.encode('ASCII', 'strict'))
 | 
			
		||||
            outdata = s.read()
 | 
			
		||||
            if connectionchatty:
 | 
			
		||||
                if test_support.verbose:
 | 
			
		||||
                    sys.stdout.write(" client:  read %s\n" % repr(outdata))
 | 
			
		||||
            outdata = str(outdata, 'ASCII', 'strict')
 | 
			
		||||
            if outdata != indata.lower():
 | 
			
		||||
                raise test_support.TestFailed(
 | 
			
		||||
                    "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
 | 
			
		||||
                    % (repr(outdata[:min(len(outdata),20)]), len(outdata),
 | 
			
		||||
                       repr(indata[:min(len(indata),20)].lower()), len(indata)))
 | 
			
		||||
            s.write("over\n".encode("ASCII", "strict"))
 | 
			
		||||
            if connectionchatty:
 | 
			
		||||
                if test_support.verbose:
 | 
			
		||||
                    sys.stdout.write(" client:  closing connection.\n")
 | 
			
		||||
            s.close()
 | 
			
		||||
        finally:
 | 
			
		||||
            server.stop()
 | 
			
		||||
            server.join()
 | 
			
		||||
| 
						 | 
				
			
			@ -553,7 +565,8 @@ def tryProtocolCombo (server_protocol,
 | 
			
		|||
                              certtype))
 | 
			
		||||
        try:
 | 
			
		||||
            serverParamsTest(CERTFILE, server_protocol, certsreqs,
 | 
			
		||||
                             CERTFILE, CERTFILE, client_protocol, chatty=False)
 | 
			
		||||
                             CERTFILE, CERTFILE, client_protocol,
 | 
			
		||||
                             chatty=False, connectionchatty=False)
 | 
			
		||||
        except test_support.TestFailed:
 | 
			
		||||
            if expectedToWork:
 | 
			
		||||
                raise
 | 
			
		||||
| 
						 | 
				
			
			@ -565,47 +578,7 @@ def tryProtocolCombo (server_protocol,
 | 
			
		|||
                       ssl.get_protocol_name(server_protocol)))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    class ConnectedTests(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
        def testRudeShutdown(self):
 | 
			
		||||
 | 
			
		||||
            listener_ready = threading.Event()
 | 
			
		||||
            listener_gone = threading.Event()
 | 
			
		||||
 | 
			
		||||
            # `listener` runs in a thread.  It opens a socket listening on
 | 
			
		||||
            # PORT, and sits in an accept() until the main thread connects.
 | 
			
		||||
            # Then it rudely closes the socket, and sets Event `listener_gone`
 | 
			
		||||
            # to let the main thread know the socket is gone.
 | 
			
		||||
            def listener():
 | 
			
		||||
                s = socket.socket()
 | 
			
		||||
                if hasattr(socket, 'SO_REUSEADDR'):
 | 
			
		||||
                    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 | 
			
		||||
                if hasattr(socket, 'SO_REUSEPORT'):
 | 
			
		||||
                    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
 | 
			
		||||
                s.bind(('127.0.0.1', TESTPORT))
 | 
			
		||||
                s.listen(5)
 | 
			
		||||
                listener_ready.set()
 | 
			
		||||
                s.accept()
 | 
			
		||||
                s = None # reclaim the socket object, which also closes it
 | 
			
		||||
                listener_gone.set()
 | 
			
		||||
 | 
			
		||||
            def connector():
 | 
			
		||||
                listener_ready.wait()
 | 
			
		||||
                s = socket.socket()
 | 
			
		||||
                s.connect(('127.0.0.1', TESTPORT))
 | 
			
		||||
                listener_gone.wait()
 | 
			
		||||
                try:
 | 
			
		||||
                    ssl_sock = ssl.wrap_socket(s)
 | 
			
		||||
                except socket.sslerror:
 | 
			
		||||
                    pass
 | 
			
		||||
                else:
 | 
			
		||||
                    raise test_support.TestFailed(
 | 
			
		||||
                          'connecting to closed SSL socket should have failed')
 | 
			
		||||
 | 
			
		||||
            t = threading.Thread(target=listener)
 | 
			
		||||
            t.start()
 | 
			
		||||
            connector()
 | 
			
		||||
            t.join()
 | 
			
		||||
    class ThreadedTests(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
        def testEcho (self):
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -656,7 +629,7 @@ def testReadCert(self):
 | 
			
		|||
                    if test_support.verbose:
 | 
			
		||||
                        sys.stdout.write(pprint.pformat(cert) + '\n')
 | 
			
		||||
                        sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
 | 
			
		||||
                    if not cert.has_key('subject'):
 | 
			
		||||
                    if 'subject' not in cert:
 | 
			
		||||
                        raise test_support.TestFailed(
 | 
			
		||||
                            "No subject field in certificate: %s." %
 | 
			
		||||
                            pprint.pformat(cert))
 | 
			
		||||
| 
						 | 
				
			
			@ -680,6 +653,46 @@ def testMalformedKey(self):
 | 
			
		|||
            badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
 | 
			
		||||
                                     "badkey.pem"))
 | 
			
		||||
 | 
			
		||||
        def testRudeShutdown(self):
 | 
			
		||||
 | 
			
		||||
            listener_ready = threading.Event()
 | 
			
		||||
            listener_gone = threading.Event()
 | 
			
		||||
 | 
			
		||||
            # `listener` runs in a thread.  It opens a socket listening on
 | 
			
		||||
            # PORT, and sits in an accept() until the main thread connects.
 | 
			
		||||
            # Then it rudely closes the socket, and sets Event `listener_gone`
 | 
			
		||||
            # to let the main thread know the socket is gone.
 | 
			
		||||
            def listener():
 | 
			
		||||
                s = socket.socket()
 | 
			
		||||
                if hasattr(socket, 'SO_REUSEADDR'):
 | 
			
		||||
                    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 | 
			
		||||
                if hasattr(socket, 'SO_REUSEPORT'):
 | 
			
		||||
                    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
 | 
			
		||||
                s.bind(('127.0.0.1', TESTPORT))
 | 
			
		||||
                s.listen(5)
 | 
			
		||||
                listener_ready.set()
 | 
			
		||||
                s.accept()
 | 
			
		||||
                s = None # reclaim the socket object, which also closes it
 | 
			
		||||
                listener_gone.set()
 | 
			
		||||
 | 
			
		||||
            def connector():
 | 
			
		||||
                listener_ready.wait()
 | 
			
		||||
                s = socket.socket()
 | 
			
		||||
                s.connect(('127.0.0.1', TESTPORT))
 | 
			
		||||
                listener_gone.wait()
 | 
			
		||||
                try:
 | 
			
		||||
                    ssl_sock = ssl.wrap_socket(s)
 | 
			
		||||
                except IOError:
 | 
			
		||||
                    pass
 | 
			
		||||
                else:
 | 
			
		||||
                    raise test_support.TestFailed(
 | 
			
		||||
                          'connecting to closed SSL socket should have failed')
 | 
			
		||||
 | 
			
		||||
            t = threading.Thread(target=listener)
 | 
			
		||||
            t.start()
 | 
			
		||||
            connector()
 | 
			
		||||
            t.join()
 | 
			
		||||
 | 
			
		||||
        def testProtocolSSL2(self):
 | 
			
		||||
            if test_support.verbose:
 | 
			
		||||
                sys.stdout.write("\n")
 | 
			
		||||
| 
						 | 
				
			
			@ -759,39 +772,47 @@ def testSTARTTLS (self):
 | 
			
		|||
                    if test_support.verbose:
 | 
			
		||||
                        sys.stdout.write("\n")
 | 
			
		||||
                    for indata in msgs:
 | 
			
		||||
                        msg = indata.encode('ASCII', 'replace')
 | 
			
		||||
                        if test_support.verbose:
 | 
			
		||||
                            sys.stdout.write(
 | 
			
		||||
                                " client:  sending %s...\n" % repr(indata))
 | 
			
		||||
                                " client:  sending %s...\n" % repr(msg))
 | 
			
		||||
                        if wrapped:
 | 
			
		||||
                            conn.write(indata)
 | 
			
		||||
                            conn.write(msg)
 | 
			
		||||
                            outdata = conn.read()
 | 
			
		||||
                        else:
 | 
			
		||||
                            s.send(indata)
 | 
			
		||||
                            s.send(msg)
 | 
			
		||||
                            outdata = s.recv(1024)
 | 
			
		||||
                        if (indata == "STARTTLS" and
 | 
			
		||||
                            outdata.strip().lower().startswith("ok")):
 | 
			
		||||
                            str(outdata, 'ASCII', 'replace').strip().lower().startswith("ok")):
 | 
			
		||||
                            if test_support.verbose:
 | 
			
		||||
                                msg = str(outdata, 'ASCII', 'replace')
 | 
			
		||||
                                sys.stdout.write(
 | 
			
		||||
                                    " client:  read %s from server, starting TLS...\n"
 | 
			
		||||
                                    % repr(outdata))
 | 
			
		||||
                                    % repr(msg))
 | 
			
		||||
                            conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1)
 | 
			
		||||
 | 
			
		||||
                            wrapped = True
 | 
			
		||||
                        else:
 | 
			
		||||
                            if test_support.verbose:
 | 
			
		||||
                                msg = str(outdata, 'ASCII', 'replace')
 | 
			
		||||
                                sys.stdout.write(
 | 
			
		||||
                                    " client:  read %s from server\n" % repr(outdata))
 | 
			
		||||
                                    " client:  read %s from server\n" % repr(msg))
 | 
			
		||||
                    if test_support.verbose:
 | 
			
		||||
                        sys.stdout.write(" client:  closing connection.\n")
 | 
			
		||||
                    if wrapped:
 | 
			
		||||
                        conn.write("over\n")
 | 
			
		||||
                        conn.write("over\n".encode("ASCII", "strict"))
 | 
			
		||||
                    else:
 | 
			
		||||
                        s.send("over\n")
 | 
			
		||||
                if wrapped:
 | 
			
		||||
                    conn.close()
 | 
			
		||||
                else:
 | 
			
		||||
                    s.close()
 | 
			
		||||
            finally:
 | 
			
		||||
                server.stop()
 | 
			
		||||
                server.join()
 | 
			
		||||
 | 
			
		||||
    class AsyncoreTests(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
        def testAsyncore(self):
 | 
			
		||||
 | 
			
		||||
            server = AsyncoreHTTPSServer(TESTPORT, CERTFILE)
 | 
			
		||||
| 
						 | 
				
			
			@ -824,6 +845,8 @@ def testAsyncore(self):
 | 
			
		|||
                raise test_support.TestFailed(msg)
 | 
			
		||||
            else:
 | 
			
		||||
                if not (d1 == d2):
 | 
			
		||||
                    print("d1 is", len(d1), repr(d1))
 | 
			
		||||
                    print("d2 is", len(d2), repr(d2))
 | 
			
		||||
                    raise test_support.TestFailed(
 | 
			
		||||
                        "Couldn't fetch data from HTTPS server")
 | 
			
		||||
            finally:
 | 
			
		||||
| 
						 | 
				
			
			@ -863,6 +886,7 @@ def test_main(verbose=False):
 | 
			
		|||
    if (not os.path.exists(CERTFILE) or
 | 
			
		||||
        not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT)):
 | 
			
		||||
        raise test_support.TestFailed("Can't read certificate files!")
 | 
			
		||||
 | 
			
		||||
    TESTPORT = findtestsocket(10025, 12000)
 | 
			
		||||
    if not TESTPORT:
 | 
			
		||||
        raise test_support.TestFailed("Can't find open port to test servers on!")
 | 
			
		||||
| 
						 | 
				
			
			@ -870,12 +894,13 @@ def test_main(verbose=False):
 | 
			
		|||
    tests = [BasicTests]
 | 
			
		||||
 | 
			
		||||
    if test_support.is_resource_enabled('network'):
 | 
			
		||||
        tests.append(NetworkTests)
 | 
			
		||||
        tests.append(NetworkedTests)
 | 
			
		||||
 | 
			
		||||
    if _have_threads:
 | 
			
		||||
        thread_info = test_support.threading_setup()
 | 
			
		||||
        if thread_info and test_support.is_resource_enabled('network'):
 | 
			
		||||
            tests.append(ConnectedTests)
 | 
			
		||||
            tests.append(ThreadedTests)
 | 
			
		||||
            tests.append(AsyncoreTests)
 | 
			
		||||
 | 
			
		||||
    test_support.run_unittest(*tests)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										328
									
								
								Modules/_ssl.c
									
										
									
									
									
								
							
							
						
						
									
										328
									
								
								Modules/_ssl.c
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -2,14 +2,15 @@
 | 
			
		|||
 | 
			
		||||
   SSL support based on patches by Brian E Gallew and Laszlo Kovacs.
 | 
			
		||||
   Re-worked a bit by Bill Janssen to add server-side support and
 | 
			
		||||
   certificate decoding.
 | 
			
		||||
   certificate decoding.  Chris Stawarz contributed some non-blocking
 | 
			
		||||
   patches.
 | 
			
		||||
 | 
			
		||||
   This module is imported by ssl.py. It should *not* be used
 | 
			
		||||
   directly.
 | 
			
		||||
 | 
			
		||||
   XXX should partial writes be enabled, SSL_MODE_ENABLE_PARTIAL_WRITE?
 | 
			
		||||
 | 
			
		||||
   XXX what about SSL_MODE_AUTO_RETRY
 | 
			
		||||
   XXX what about SSL_MODE_AUTO_RETRY?
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
#include "Python.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -17,7 +18,7 @@
 | 
			
		|||
#ifdef WITH_THREAD
 | 
			
		||||
#include "pythread.h"
 | 
			
		||||
#define PySSL_BEGIN_ALLOW_THREADS { \
 | 
			
		||||
			PyThreadState *_save;  \
 | 
			
		||||
			PyThreadState *_save = NULL;  \
 | 
			
		||||
			if (_ssl_locks_count>0) {_save = PyEval_SaveThread();}
 | 
			
		||||
#define PySSL_BLOCK_THREADS	if (_ssl_locks_count>0){PyEval_RestoreThread(_save)};
 | 
			
		||||
#define PySSL_UNBLOCK_THREADS	if (_ssl_locks_count>0){_save = PyEval_SaveThread()};
 | 
			
		||||
| 
						 | 
				
			
			@ -114,8 +115,6 @@ typedef struct {
 | 
			
		|||
	SSL_CTX*	ctx;
 | 
			
		||||
	SSL*		ssl;
 | 
			
		||||
	X509*		peer_cert;
 | 
			
		||||
	char		server[X509_NAME_MAXLEN];
 | 
			
		||||
	char		issuer[X509_NAME_MAXLEN];
 | 
			
		||||
 | 
			
		||||
} PySSLObject;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -265,15 +264,11 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file,
 | 
			
		|||
	PySSLObject *self;
 | 
			
		||||
	char *errstr = NULL;
 | 
			
		||||
	int ret;
 | 
			
		||||
	int err;
 | 
			
		||||
	int sockstate;
 | 
			
		||||
	int verification_mode;
 | 
			
		||||
 | 
			
		||||
	self = PyObject_New(PySSLObject, &PySSL_Type); /* Create new object */
 | 
			
		||||
	if (self == NULL)
 | 
			
		||||
		return NULL;
 | 
			
		||||
	memset(self->server, '\0', sizeof(char) * X509_NAME_MAXLEN);
 | 
			
		||||
	memset(self->issuer, '\0', sizeof(char) * X509_NAME_MAXLEN);
 | 
			
		||||
	self->peer_cert = NULL;
 | 
			
		||||
	self->ssl = NULL;
 | 
			
		||||
	self->ctx = NULL;
 | 
			
		||||
| 
						 | 
				
			
			@ -388,57 +383,6 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file,
 | 
			
		|||
		SSL_set_accept_state(self->ssl);
 | 
			
		||||
	PySSL_END_ALLOW_THREADS
 | 
			
		||||
 | 
			
		||||
	/* Actually negotiate SSL connection */
 | 
			
		||||
	/* XXX If SSL_connect() returns 0, it's also a failure. */
 | 
			
		||||
	sockstate = 0;
 | 
			
		||||
	do {
 | 
			
		||||
		PySSL_BEGIN_ALLOW_THREADS
 | 
			
		||||
		if (socket_type == PY_SSL_CLIENT)
 | 
			
		||||
			ret = SSL_connect(self->ssl);
 | 
			
		||||
		else
 | 
			
		||||
			ret = SSL_accept(self->ssl);
 | 
			
		||||
		err = SSL_get_error(self->ssl, ret);
 | 
			
		||||
		PySSL_END_ALLOW_THREADS
 | 
			
		||||
		if(PyErr_CheckSignals()) {
 | 
			
		||||
			goto fail;
 | 
			
		||||
		}
 | 
			
		||||
		if (err == SSL_ERROR_WANT_READ) {
 | 
			
		||||
			sockstate = check_socket_and_wait_for_timeout(Sock, 0);
 | 
			
		||||
		} else if (err == SSL_ERROR_WANT_WRITE) {
 | 
			
		||||
			sockstate = check_socket_and_wait_for_timeout(Sock, 1);
 | 
			
		||||
		} else {
 | 
			
		||||
			sockstate = SOCKET_OPERATION_OK;
 | 
			
		||||
		}
 | 
			
		||||
		if (sockstate == SOCKET_HAS_TIMED_OUT) {
 | 
			
		||||
			PyErr_SetString(PySSLErrorObject,
 | 
			
		||||
				ERRSTR("The connect operation timed out"));
 | 
			
		||||
			goto fail;
 | 
			
		||||
		} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) {
 | 
			
		||||
			PyErr_SetString(PySSLErrorObject,
 | 
			
		||||
				ERRSTR("Underlying socket has been closed."));
 | 
			
		||||
			goto fail;
 | 
			
		||||
		} else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) {
 | 
			
		||||
			PyErr_SetString(PySSLErrorObject,
 | 
			
		||||
			  ERRSTR("Underlying socket too large for select()."));
 | 
			
		||||
			goto fail;
 | 
			
		||||
		} else if (sockstate == SOCKET_IS_NONBLOCKING) {
 | 
			
		||||
			break;
 | 
			
		||||
		}
 | 
			
		||||
	} while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE);
 | 
			
		||||
	if (ret < 1) {
 | 
			
		||||
		PySSL_SetError(self, ret, __FILE__, __LINE__);
 | 
			
		||||
		goto fail;
 | 
			
		||||
	}
 | 
			
		||||
	self->ssl->debug = 1;
 | 
			
		||||
 | 
			
		||||
	PySSL_BEGIN_ALLOW_THREADS
 | 
			
		||||
	if ((self->peer_cert = SSL_get_peer_certificate(self->ssl))) {
 | 
			
		||||
		X509_NAME_oneline(X509_get_subject_name(self->peer_cert),
 | 
			
		||||
				  self->server, X509_NAME_MAXLEN);
 | 
			
		||||
		X509_NAME_oneline(X509_get_issuer_name(self->peer_cert),
 | 
			
		||||
				  self->issuer, X509_NAME_MAXLEN);
 | 
			
		||||
	}
 | 
			
		||||
	PySSL_END_ALLOW_THREADS
 | 
			
		||||
	self->Socket = Sock;
 | 
			
		||||
	Py_INCREF(self->Socket);
 | 
			
		||||
	return self;
 | 
			
		||||
| 
						 | 
				
			
			@ -488,16 +432,58 @@ PyDoc_STRVAR(ssl_doc,
 | 
			
		|||
 | 
			
		||||
/* SSL object methods */
 | 
			
		||||
 | 
			
		||||
static PyObject *
 | 
			
		||||
PySSL_server(PySSLObject *self)
 | 
			
		||||
static PyObject *PySSL_SSLdo_handshake(PySSLObject *self)
 | 
			
		||||
{
 | 
			
		||||
	return PyUnicode_FromString(self->server);
 | 
			
		||||
}
 | 
			
		||||
	int ret;
 | 
			
		||||
	int err;
 | 
			
		||||
	int sockstate;
 | 
			
		||||
 | 
			
		||||
static PyObject *
 | 
			
		||||
PySSL_issuer(PySSLObject *self)
 | 
			
		||||
{
 | 
			
		||||
	return PyUnicode_FromString(self->issuer);
 | 
			
		||||
	/* Actually negotiate SSL connection */
 | 
			
		||||
	/* XXX If SSL_do_handshake() returns 0, it's also a failure. */
 | 
			
		||||
	sockstate = 0;
 | 
			
		||||
	do {
 | 
			
		||||
		PySSL_BEGIN_ALLOW_THREADS
 | 
			
		||||
		ret = SSL_do_handshake(self->ssl);
 | 
			
		||||
		err = SSL_get_error(self->ssl, ret);
 | 
			
		||||
		PySSL_END_ALLOW_THREADS
 | 
			
		||||
		if(PyErr_CheckSignals()) {
 | 
			
		||||
			return NULL;
 | 
			
		||||
		}
 | 
			
		||||
		if (err == SSL_ERROR_WANT_READ) {
 | 
			
		||||
			sockstate = check_socket_and_wait_for_timeout(self->Socket, 0);
 | 
			
		||||
		} else if (err == SSL_ERROR_WANT_WRITE) {
 | 
			
		||||
			sockstate = check_socket_and_wait_for_timeout(self->Socket, 1);
 | 
			
		||||
		} else {
 | 
			
		||||
			sockstate = SOCKET_OPERATION_OK;
 | 
			
		||||
		}
 | 
			
		||||
		if (sockstate == SOCKET_HAS_TIMED_OUT) {
 | 
			
		||||
			PyErr_SetString(PySSLErrorObject,
 | 
			
		||||
				ERRSTR("The handshake operation timed out"));
 | 
			
		||||
			return NULL;
 | 
			
		||||
		} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) {
 | 
			
		||||
			PyErr_SetString(PySSLErrorObject,
 | 
			
		||||
				ERRSTR("Underlying socket has been closed."));
 | 
			
		||||
			return NULL;
 | 
			
		||||
		} else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) {
 | 
			
		||||
			PyErr_SetString(PySSLErrorObject,
 | 
			
		||||
			  ERRSTR("Underlying socket too large for select()."));
 | 
			
		||||
			return NULL;
 | 
			
		||||
		} else if (sockstate == SOCKET_IS_NONBLOCKING) {
 | 
			
		||||
			break;
 | 
			
		||||
		}
 | 
			
		||||
	} while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE);
 | 
			
		||||
	if (ret < 1)
 | 
			
		||||
		return PySSL_SetError(self, ret, __FILE__, __LINE__);
 | 
			
		||||
	self->ssl->debug = 1;
 | 
			
		||||
 | 
			
		||||
	if (self->peer_cert)
 | 
			
		||||
		X509_free (self->peer_cert);
 | 
			
		||||
        PySSL_BEGIN_ALLOW_THREADS
 | 
			
		||||
	self->peer_cert = SSL_get_peer_certificate(self->ssl);
 | 
			
		||||
	PySSL_END_ALLOW_THREADS
 | 
			
		||||
 | 
			
		||||
	Py_INCREF(Py_None);
 | 
			
		||||
	return Py_None;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static PyObject *
 | 
			
		||||
| 
						 | 
				
			
			@ -515,7 +501,7 @@ _create_tuple_for_attribute (ASN1_OBJECT *name, ASN1_STRING *value) {
 | 
			
		|||
		_setSSLError(NULL, 0, __FILE__, __LINE__);
 | 
			
		||||
		goto fail;
 | 
			
		||||
	}
 | 
			
		||||
	name_obj = PyString_FromStringAndSize(namebuf, buflen);
 | 
			
		||||
	name_obj = PyUnicode_FromStringAndSize(namebuf, buflen);
 | 
			
		||||
	if (name_obj == NULL)
 | 
			
		||||
		goto fail;
 | 
			
		||||
	
 | 
			
		||||
| 
						 | 
				
			
			@ -681,21 +667,24 @@ _get_peer_alt_names (X509 *certificate) {
 | 
			
		|||
		/* now decode the altName */
 | 
			
		||||
		ext = X509_get_ext(certificate, i);
 | 
			
		||||
		if(!(method = X509V3_EXT_get(ext))) {
 | 
			
		||||
			PyErr_SetString(PySSLErrorObject,
 | 
			
		||||
					ERRSTR("No method for internalizing subjectAltName!"));
 | 
			
		||||
			PyErr_SetString
 | 
			
		||||
                          (PySSLErrorObject,
 | 
			
		||||
                           ERRSTR("No method for internalizing subjectAltName!"));
 | 
			
		||||
			goto fail;
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		p = ext->value->data;
 | 
			
		||||
		if(method->it)
 | 
			
		||||
			names = (GENERAL_NAMES*) (ASN1_item_d2i(NULL,
 | 
			
		||||
								&p,
 | 
			
		||||
								ext->value->length,
 | 
			
		||||
								ASN1_ITEM_ptr(method->it)));
 | 
			
		||||
			names = (GENERAL_NAMES*)
 | 
			
		||||
                          (ASN1_item_d2i(NULL,
 | 
			
		||||
                                         &p,
 | 
			
		||||
                                         ext->value->length,
 | 
			
		||||
                                         ASN1_ITEM_ptr(method->it)));
 | 
			
		||||
		else
 | 
			
		||||
			names = (GENERAL_NAMES*) (method->d2i(NULL,
 | 
			
		||||
							      &p,
 | 
			
		||||
							      ext->value->length));
 | 
			
		||||
			names = (GENERAL_NAMES*)
 | 
			
		||||
                          (method->d2i(NULL,
 | 
			
		||||
                                       &p,
 | 
			
		||||
                                       ext->value->length));
 | 
			
		||||
 | 
			
		||||
		for(j = 0; j < sk_GENERAL_NAME_num(names); j++) {
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -704,14 +693,15 @@ _get_peer_alt_names (X509 *certificate) {
 | 
			
		|||
			name = sk_GENERAL_NAME_value(names, j);
 | 
			
		||||
			if (name->type == GEN_DIRNAME) {
 | 
			
		||||
 | 
			
		||||
				/* we special-case DirName as a tuple of tuples of attributes */
 | 
			
		||||
				/* we special-case DirName as a tuple of
 | 
			
		||||
                                   tuples of attributes */
 | 
			
		||||
 | 
			
		||||
				t = PyTuple_New(2);
 | 
			
		||||
				if (t == NULL) {
 | 
			
		||||
					goto fail;
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				v = PyString_FromString("DirName");
 | 
			
		||||
				v = PyUnicode_FromString("DirName");
 | 
			
		||||
				if (v == NULL) {
 | 
			
		||||
					Py_DECREF(t);
 | 
			
		||||
					goto fail;
 | 
			
		||||
| 
						 | 
				
			
			@ -742,13 +732,14 @@ _get_peer_alt_names (X509 *certificate) {
 | 
			
		|||
				t = PyTuple_New(2);
 | 
			
		||||
				if (t == NULL)
 | 
			
		||||
					goto fail;
 | 
			
		||||
				v = PyString_FromStringAndSize(buf, (vptr - buf));
 | 
			
		||||
				v = PyUnicode_FromStringAndSize(buf, (vptr - buf));
 | 
			
		||||
				if (v == NULL) {
 | 
			
		||||
					Py_DECREF(t);
 | 
			
		||||
					goto fail;
 | 
			
		||||
				}
 | 
			
		||||
				PyTuple_SET_ITEM(t, 0, v);
 | 
			
		||||
				v = PyString_FromStringAndSize((vptr + 1), (len - (vptr - buf + 1)));
 | 
			
		||||
				v = PyUnicode_FromStringAndSize((vptr + 1),
 | 
			
		||||
                                                                (len - (vptr - buf + 1)));
 | 
			
		||||
				if (v == NULL) {
 | 
			
		||||
					Py_DECREF(t);
 | 
			
		||||
					goto fail;
 | 
			
		||||
| 
						 | 
				
			
			@ -849,7 +840,7 @@ _decode_certificate (X509 *certificate, int verbose) {
 | 
			
		|||
			_setSSLError(NULL, 0, __FILE__, __LINE__);
 | 
			
		||||
			goto fail1;
 | 
			
		||||
		}
 | 
			
		||||
		sn_obj = PyString_FromStringAndSize(buf, len);
 | 
			
		||||
		sn_obj = PyUnicode_FromStringAndSize(buf, len);
 | 
			
		||||
		if (sn_obj == NULL)
 | 
			
		||||
			goto fail1;
 | 
			
		||||
		if (PyDict_SetItemString(retval, "serialNumber", sn_obj) < 0) {
 | 
			
		||||
| 
						 | 
				
			
			@ -866,7 +857,7 @@ _decode_certificate (X509 *certificate, int verbose) {
 | 
			
		|||
			_setSSLError(NULL, 0, __FILE__, __LINE__);
 | 
			
		||||
			goto fail1;
 | 
			
		||||
		}
 | 
			
		||||
		pnotBefore = PyString_FromStringAndSize(buf, len);
 | 
			
		||||
		pnotBefore = PyUnicode_FromStringAndSize(buf, len);
 | 
			
		||||
		if (pnotBefore == NULL)
 | 
			
		||||
			goto fail1;
 | 
			
		||||
		if (PyDict_SetItemString(retval, "notBefore", pnotBefore) < 0) {
 | 
			
		||||
| 
						 | 
				
			
			@ -884,7 +875,7 @@ _decode_certificate (X509 *certificate, int verbose) {
 | 
			
		|||
		_setSSLError(NULL, 0, __FILE__, __LINE__);
 | 
			
		||||
		goto fail1;
 | 
			
		||||
	}
 | 
			
		||||
	pnotAfter = PyString_FromStringAndSize(buf, len);
 | 
			
		||||
	pnotAfter = PyUnicode_FromStringAndSize(buf, len);
 | 
			
		||||
	if (pnotAfter == NULL)
 | 
			
		||||
		goto fail1;
 | 
			
		||||
	if (PyDict_SetItemString(retval, "notAfter", pnotAfter) < 0) {
 | 
			
		||||
| 
						 | 
				
			
			@ -928,22 +919,26 @@ PySSL_test_decode_certificate (PyObject *mod, PyObject *args) {
 | 
			
		|||
	BIO *cert;
 | 
			
		||||
	int verbose = 1;
 | 
			
		||||
 | 
			
		||||
	if (!PyArg_ParseTuple(args, "s|i:test_decode_certificate", &filename, &verbose))
 | 
			
		||||
	if (!PyArg_ParseTuple(args, "s|i:test_decode_certificate",
 | 
			
		||||
                              &filename, &verbose))
 | 
			
		||||
		return NULL;
 | 
			
		||||
 | 
			
		||||
	if ((cert=BIO_new(BIO_s_file())) == NULL) {
 | 
			
		||||
		PyErr_SetString(PySSLErrorObject, "Can't malloc memory to read file");
 | 
			
		||||
		PyErr_SetString(PySSLErrorObject,
 | 
			
		||||
                                "Can't malloc memory to read file");
 | 
			
		||||
		goto fail0;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if (BIO_read_filename(cert,filename) <= 0) {
 | 
			
		||||
		PyErr_SetString(PySSLErrorObject, "Can't open file");
 | 
			
		||||
		PyErr_SetString(PySSLErrorObject,
 | 
			
		||||
                                "Can't open file");
 | 
			
		||||
		goto fail0;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	x = PEM_read_bio_X509_AUX(cert,NULL, NULL, NULL);
 | 
			
		||||
	if (x == NULL) {
 | 
			
		||||
		PyErr_SetString(PySSLErrorObject, "Error decoding PEM-encoded file");
 | 
			
		||||
		PyErr_SetString(PySSLErrorObject,
 | 
			
		||||
                                "Error decoding PEM-encoded file");
 | 
			
		||||
		goto fail0;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -981,7 +976,9 @@ PySSL_peercert(PySSLObject *self, PyObject *args)
 | 
			
		|||
			PySSL_SetError(self, len, __FILE__, __LINE__);
 | 
			
		||||
			return NULL;
 | 
			
		||||
		}
 | 
			
		||||
		retval = PyString_FromStringAndSize((const char *) bytes_buf, len);
 | 
			
		||||
                /* this is actually an immutable bytes sequence */
 | 
			
		||||
		retval = PyBytes_FromStringAndSize
 | 
			
		||||
                  ((const char *) bytes_buf, len);
 | 
			
		||||
		OPENSSL_free(bytes_buf);
 | 
			
		||||
		return retval;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1028,7 +1025,7 @@ static PyObject *PySSL_cipher (PySSLObject *self) {
 | 
			
		|||
	if (cipher_name == NULL) {
 | 
			
		||||
		PyTuple_SET_ITEM(retval, 0, Py_None);
 | 
			
		||||
	} else {
 | 
			
		||||
		v = PyString_FromString(cipher_name);
 | 
			
		||||
		v = PyUnicode_FromString(cipher_name);
 | 
			
		||||
		if (v == NULL)
 | 
			
		||||
			goto fail0;
 | 
			
		||||
		PyTuple_SET_ITEM(retval, 0, v);
 | 
			
		||||
| 
						 | 
				
			
			@ -1037,7 +1034,7 @@ static PyObject *PySSL_cipher (PySSLObject *self) {
 | 
			
		|||
	if (cipher_protocol == NULL) {
 | 
			
		||||
		PyTuple_SET_ITEM(retval, 1, Py_None);
 | 
			
		||||
	} else {
 | 
			
		||||
		v = PyString_FromString(cipher_protocol);
 | 
			
		||||
		v = PyUnicode_FromString(cipher_protocol);
 | 
			
		||||
		if (v == NULL)
 | 
			
		||||
			goto fail0;
 | 
			
		||||
		PyTuple_SET_ITEM(retval, 1, v);
 | 
			
		||||
| 
						 | 
				
			
			@ -1127,7 +1124,9 @@ check_socket_and_wait_for_timeout(PySocketSockObject *s, int writing)
 | 
			
		|||
		rc = select(s->sock_fd+1, &fds, NULL, NULL, &tv);
 | 
			
		||||
	PySSL_END_ALLOW_THREADS
 | 
			
		||||
 | 
			
		||||
#ifdef HAVE_POLL
 | 
			
		||||
normal_return:
 | 
			
		||||
#endif
 | 
			
		||||
	/* Return SOCKET_TIMED_OUT on timeout, SOCKET_OPERATION_OK otherwise
 | 
			
		||||
	   (when we are able to write or when there's something to read) */
 | 
			
		||||
	return rc == 0 ? SOCKET_HAS_TIMED_OUT : SOCKET_OPERATION_OK;
 | 
			
		||||
| 
						 | 
				
			
			@ -1140,10 +1139,16 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args)
 | 
			
		|||
	int count;
 | 
			
		||||
	int sockstate;
 | 
			
		||||
	int err;
 | 
			
		||||
        int nonblocking;
 | 
			
		||||
 | 
			
		||||
	if (!PyArg_ParseTuple(args, "s#:write", &data, &count))
 | 
			
		||||
	if (!PyArg_ParseTuple(args, "y#:write", &data, &count))
 | 
			
		||||
		return NULL;
 | 
			
		||||
 | 
			
		||||
        /* just in case the blocking state of the socket has been changed */
 | 
			
		||||
	nonblocking = (self->Socket->sock_timeout >= 0.0);
 | 
			
		||||
        BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
 | 
			
		||||
        BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
 | 
			
		||||
 | 
			
		||||
	sockstate = check_socket_and_wait_for_timeout(self->Socket, 1);
 | 
			
		||||
	if (sockstate == SOCKET_HAS_TIMED_OUT) {
 | 
			
		||||
		PyErr_SetString(PySSLErrorObject,
 | 
			
		||||
| 
						 | 
				
			
			@ -1200,19 +1205,58 @@ PyDoc_STRVAR(PySSL_SSLwrite_doc,
 | 
			
		|||
Writes the string s into the SSL object.  Returns the number\n\
 | 
			
		||||
of bytes written.");
 | 
			
		||||
 | 
			
		||||
static PyObject *PySSL_SSLpending(PySSLObject *self)
 | 
			
		||||
{
 | 
			
		||||
	int count = 0;
 | 
			
		||||
 | 
			
		||||
	PySSL_BEGIN_ALLOW_THREADS
 | 
			
		||||
	count = SSL_pending(self->ssl);
 | 
			
		||||
	PySSL_END_ALLOW_THREADS
 | 
			
		||||
	if (count < 0)
 | 
			
		||||
		return PySSL_SetError(self, count, __FILE__, __LINE__);
 | 
			
		||||
	else
 | 
			
		||||
		return PyInt_FromLong(count);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
PyDoc_STRVAR(PySSL_SSLpending_doc,
 | 
			
		||||
"pending() -> count\n\
 | 
			
		||||
\n\
 | 
			
		||||
Returns the number of already decrypted bytes available for read,\n\
 | 
			
		||||
pending on the connection.\n");
 | 
			
		||||
 | 
			
		||||
static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
 | 
			
		||||
{
 | 
			
		||||
	PyObject *buf;
 | 
			
		||||
	int count = 0;
 | 
			
		||||
	PyObject *buf = NULL;
 | 
			
		||||
	int buf_passed = 0;
 | 
			
		||||
	int count = -1;
 | 
			
		||||
	int len = 1024;
 | 
			
		||||
	int sockstate;
 | 
			
		||||
	int err;
 | 
			
		||||
        int nonblocking;
 | 
			
		||||
 | 
			
		||||
	if (!PyArg_ParseTuple(args, "|i:read", &len))
 | 
			
		||||
	if (!PyArg_ParseTuple(args, "|Oi:read", &buf, &count))
 | 
			
		||||
		return NULL;
 | 
			
		||||
 | 
			
		||||
	if (!(buf = PyBytes_FromStringAndSize((char *) 0, len)))
 | 
			
		||||
		return NULL;
 | 
			
		||||
        if ((buf == NULL) || (buf == Py_None)) {
 | 
			
		||||
		if (!(buf = PyBytes_FromStringAndSize((char *) 0, len)))
 | 
			
		||||
			return NULL;
 | 
			
		||||
        } else if (PyInt_Check(buf)) {
 | 
			
		||||
		len = PyInt_AS_LONG(buf);
 | 
			
		||||
		if (!(buf = PyBytes_FromStringAndSize((char *) 0, len)))
 | 
			
		||||
			return NULL;
 | 
			
		||||
	} else {
 | 
			
		||||
		if (!PyBytes_Check(buf))
 | 
			
		||||
			return NULL;
 | 
			
		||||
		len = PyBytes_Size(buf);
 | 
			
		||||
		if ((count > 0) && (count <= len))
 | 
			
		||||
			len = count;
 | 
			
		||||
		buf_passed = 1;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
        /* just in case the blocking state of the socket has been changed */
 | 
			
		||||
	nonblocking = (self->Socket->sock_timeout >= 0.0);
 | 
			
		||||
        BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
 | 
			
		||||
        BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
 | 
			
		||||
 | 
			
		||||
	/* first check if there are bytes ready to be read */
 | 
			
		||||
	PySSL_BEGIN_ALLOW_THREADS
 | 
			
		||||
| 
						 | 
				
			
			@ -1224,27 +1268,38 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
 | 
			
		|||
		if (sockstate == SOCKET_HAS_TIMED_OUT) {
 | 
			
		||||
			PyErr_SetString(PySSLErrorObject,
 | 
			
		||||
					"The read operation timed out");
 | 
			
		||||
			Py_DECREF(buf);
 | 
			
		||||
			if (!buf_passed) {
 | 
			
		||||
				Py_DECREF(buf);
 | 
			
		||||
			}
 | 
			
		||||
			return NULL;
 | 
			
		||||
		} else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) {
 | 
			
		||||
			PyErr_SetString(PySSLErrorObject,
 | 
			
		||||
				"Underlying socket too large for select().");
 | 
			
		||||
			if (!buf_passed) {
 | 
			
		||||
				Py_DECREF(buf);
 | 
			
		||||
			}
 | 
			
		||||
			Py_DECREF(buf);
 | 
			
		||||
			return NULL;
 | 
			
		||||
		} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) {
 | 
			
		||||
			/* should contain a zero-length string */
 | 
			
		||||
			_PyString_Resize(&buf, 0);
 | 
			
		||||
			return buf;
 | 
			
		||||
			if (!buf_passed) {
 | 
			
		||||
				PyBytes_Resize(buf, 0);
 | 
			
		||||
				return buf;
 | 
			
		||||
			} else {
 | 
			
		||||
				return PyInt_FromLong(0);
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	do {
 | 
			
		||||
		err = 0;
 | 
			
		||||
		PySSL_BEGIN_ALLOW_THREADS
 | 
			
		||||
		count = SSL_read(self->ssl, PyBytes_AS_STRING(buf), len);
 | 
			
		||||
		count = SSL_read(self->ssl, PyBytes_AsString(buf), len);
 | 
			
		||||
		err = SSL_get_error(self->ssl, count);
 | 
			
		||||
		PySSL_END_ALLOW_THREADS
 | 
			
		||||
		if(PyErr_CheckSignals()) {
 | 
			
		||||
			Py_DECREF(buf);
 | 
			
		||||
			if (!buf_passed) {
 | 
			
		||||
				Py_DECREF(buf);
 | 
			
		||||
			}
 | 
			
		||||
			return NULL;
 | 
			
		||||
		}
 | 
			
		||||
		if (err == SSL_ERROR_WANT_READ) {
 | 
			
		||||
| 
						 | 
				
			
			@ -1257,44 +1312,55 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
 | 
			
		|||
			   (SSL_get_shutdown(self->ssl) ==
 | 
			
		||||
			    SSL_RECEIVED_SHUTDOWN))
 | 
			
		||||
		{
 | 
			
		||||
			_PyString_Resize(&buf, 0);
 | 
			
		||||
			return buf;
 | 
			
		||||
			if (!buf_passed) {
 | 
			
		||||
				PyBytes_Resize(buf, 0);
 | 
			
		||||
				return buf;
 | 
			
		||||
			} else {
 | 
			
		||||
				return PyInt_FromLong(0);
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			sockstate = SOCKET_OPERATION_OK;
 | 
			
		||||
		}
 | 
			
		||||
		if (sockstate == SOCKET_HAS_TIMED_OUT) {
 | 
			
		||||
			PyErr_SetString(PySSLErrorObject,
 | 
			
		||||
					"The read operation timed out");
 | 
			
		||||
			Py_DECREF(buf);
 | 
			
		||||
			if (!buf_passed) {
 | 
			
		||||
				Py_DECREF(buf);
 | 
			
		||||
			}
 | 
			
		||||
			return NULL;
 | 
			
		||||
		} else if (sockstate == SOCKET_IS_NONBLOCKING) {
 | 
			
		||||
			break;
 | 
			
		||||
		}
 | 
			
		||||
	} while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE);
 | 
			
		||||
	if (count <= 0) {
 | 
			
		||||
		Py_DECREF(buf);
 | 
			
		||||
		if (!buf_passed) {
 | 
			
		||||
			Py_DECREF(buf);
 | 
			
		||||
		}
 | 
			
		||||
		return PySSL_SetError(self, count, __FILE__, __LINE__);
 | 
			
		||||
	}
 | 
			
		||||
	if (count != len)
 | 
			
		||||
		if (PyBytes_Resize(buf, count) < 0) {
 | 
			
		||||
                        Py_DECREF(buf);
 | 
			
		||||
                        return NULL;
 | 
			
		||||
                }
 | 
			
		||||
	return buf;
 | 
			
		||||
	if (!buf_passed) {
 | 
			
		||||
		if (count != len) {
 | 
			
		||||
			PyBytes_Resize(buf, count);
 | 
			
		||||
		}
 | 
			
		||||
		return buf;
 | 
			
		||||
	} else {
 | 
			
		||||
		return PyInt_FromLong(count);
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
PyDoc_STRVAR(PySSL_SSLread_doc,
 | 
			
		||||
"read([len]) -> bytes\n\
 | 
			
		||||
"read([len]) -> string\n\
 | 
			
		||||
\n\
 | 
			
		||||
Read up to len bytes from the SSL socket.");
 | 
			
		||||
 | 
			
		||||
static PyMethodDef PySSLMethods[] = {
 | 
			
		||||
	{"do_handshake", (PyCFunction)PySSL_SSLdo_handshake, METH_NOARGS},
 | 
			
		||||
	{"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS,
 | 
			
		||||
	 PySSL_SSLwrite_doc},
 | 
			
		||||
	{"read", (PyCFunction)PySSL_SSLread, METH_VARARGS,
 | 
			
		||||
	 PySSL_SSLread_doc},
 | 
			
		||||
	{"server", (PyCFunction)PySSL_server, METH_NOARGS},
 | 
			
		||||
	{"issuer", (PyCFunction)PySSL_issuer, METH_NOARGS},
 | 
			
		||||
	{"pending", (PyCFunction)PySSL_SSLpending, METH_NOARGS,
 | 
			
		||||
	 PySSL_SSLpending_doc},
 | 
			
		||||
	{"peer_certificate", (PyCFunction)PySSL_peercert, METH_VARARGS,
 | 
			
		||||
	 PySSL_peercert_doc},
 | 
			
		||||
	{"cipher", (PyCFunction)PySSL_cipher, METH_NOARGS},
 | 
			
		||||
| 
						 | 
				
			
			@ -1350,26 +1416,26 @@ bound on the entropy contained in string.  See RFC 1750.");
 | 
			
		|||
static PyObject *
 | 
			
		||||
PySSL_RAND_status(PyObject *self)
 | 
			
		||||
{
 | 
			
		||||
    return PyBool_FromLong(RAND_status());
 | 
			
		||||
    return PyInt_FromLong(RAND_status());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
PyDoc_STRVAR(PySSL_RAND_status_doc,
 | 
			
		||||
"RAND_status() -> 0 or 1\n\
 | 
			
		||||
\n\
 | 
			
		||||
Returns True if the OpenSSL PRNG has been seeded with enough data and\n\
 | 
			
		||||
False if not.  It is necessary to seed the PRNG with RAND_add()\n\
 | 
			
		||||
on some platforms before using the ssl() function.");
 | 
			
		||||
Returns 1 if the OpenSSL PRNG has been seeded with enough data and 0 if not.\n\
 | 
			
		||||
It is necessary to seed the PRNG with RAND_add() on some platforms before\n\
 | 
			
		||||
using the ssl() function.");
 | 
			
		||||
 | 
			
		||||
static PyObject *
 | 
			
		||||
PySSL_RAND_egd(PyObject *self, PyObject *arg)
 | 
			
		||||
{
 | 
			
		||||
    int bytes;
 | 
			
		||||
 | 
			
		||||
    if (!PyString_Check(arg))
 | 
			
		||||
    if (!PyUnicode_Check(arg))
 | 
			
		||||
	return PyErr_Format(PyExc_TypeError,
 | 
			
		||||
			    "RAND_egd() expected string, found %s",
 | 
			
		||||
			    Py_Type(arg)->tp_name);
 | 
			
		||||
    bytes = RAND_egd(PyString_AS_STRING(arg));
 | 
			
		||||
    bytes = RAND_egd(PyUnicode_AsString(arg));
 | 
			
		||||
    if (bytes == -1) {
 | 
			
		||||
	PyErr_SetString(PySSLErrorObject,
 | 
			
		||||
			"EGD connection failed or EGD did not return "
 | 
			
		||||
| 
						 | 
				
			
			@ -1418,16 +1484,17 @@ static unsigned long _ssl_thread_id_function (void) {
 | 
			
		|||
	return PyThread_get_thread_ident();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void _ssl_thread_locking_function (int mode, int n, const char *file, int line) {
 | 
			
		||||
static void _ssl_thread_locking_function
 | 
			
		||||
        (int mode, int n, const char *file, int line) {
 | 
			
		||||
	/* this function is needed to perform locking on shared data
 | 
			
		||||
	   structures. (Note that OpenSSL uses a number of global data
 | 
			
		||||
	   structures that will be implicitly shared whenever multiple threads
 | 
			
		||||
	   use OpenSSL.) Multi-threaded applications will crash at random if
 | 
			
		||||
	   it is not set.
 | 
			
		||||
	   structures that will be implicitly shared whenever multiple
 | 
			
		||||
	   threads use OpenSSL.) Multi-threaded applications will
 | 
			
		||||
	   crash at random if it is not set.
 | 
			
		||||
 | 
			
		||||
	   locking_function() must be able to handle up to CRYPTO_num_locks()
 | 
			
		||||
	   different mutex locks. It sets the n-th lock if mode & CRYPTO_LOCK, and
 | 
			
		||||
	   releases it otherwise.
 | 
			
		||||
	   locking_function() must be able to handle up to
 | 
			
		||||
	   CRYPTO_num_locks() different mutex locks. It sets the n-th
 | 
			
		||||
	   lock if mode & CRYPTO_LOCK, and releases it otherwise.
 | 
			
		||||
 | 
			
		||||
	   file and line are the file number of the function setting the
 | 
			
		||||
	   lock. They can be useful for debugging.
 | 
			
		||||
| 
						 | 
				
			
			@ -1454,7 +1521,8 @@ static int _setup_ssl_threads(void) {
 | 
			
		|||
			malloc(sizeof(PyThread_type_lock) * _ssl_locks_count);
 | 
			
		||||
		if (_ssl_locks == NULL)
 | 
			
		||||
			return 0;
 | 
			
		||||
		memset(_ssl_locks, 0, sizeof(PyThread_type_lock) * _ssl_locks_count);
 | 
			
		||||
		memset(_ssl_locks, 0,
 | 
			
		||||
                       sizeof(PyThread_type_lock) * _ssl_locks_count);
 | 
			
		||||
		for (i = 0;  i < _ssl_locks_count;  i++) {
 | 
			
		||||
			_ssl_locks[i] = PyThread_allocate_lock();
 | 
			
		||||
			if (_ssl_locks[i] == NULL) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue