[3.10] Raise TypeError if SSLSocket is passed to asyncio transport-based methods (GH-31442). (GH-31443)

(cherry picked from commit 1f9d4c93af)

Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
Andrew Svetlov 2022-02-20 14:45:02 +02:00 committed by GitHub
parent ea3e0421b0
commit dde048819f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 10 deletions

View file

@ -202,6 +202,11 @@ def _set_nodelay(sock):
pass pass
def _check_ssl_socket(sock):
if ssl is not None and isinstance(sock, ssl.SSLSocket):
raise TypeError("Socket cannot be of type SSLSocket")
class _SendfileFallbackProtocol(protocols.Protocol): class _SendfileFallbackProtocol(protocols.Protocol):
def __init__(self, transp): def __init__(self, transp):
if not isinstance(transp, transports._FlowControlMixin): if not isinstance(transp, transports._FlowControlMixin):
@ -863,6 +868,7 @@ async def sock_sendfile(self, sock, file, offset=0, count=None,
*, fallback=True): *, fallback=True):
if self._debug and sock.gettimeout() != 0: if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking") raise ValueError("the socket must be non-blocking")
_check_ssl_socket(sock)
self._check_sendfile_params(sock, file, offset, count) self._check_sendfile_params(sock, file, offset, count)
try: try:
return await self._sock_sendfile_native(sock, file, return await self._sock_sendfile_native(sock, file,
@ -1004,6 +1010,9 @@ async def create_connection(
raise ValueError( raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl') 'ssl_handshake_timeout is only meaningful with ssl')
if sock is not None:
_check_ssl_socket(sock)
if happy_eyeballs_delay is not None and interleave is None: if happy_eyeballs_delay is not None and interleave is None:
# If using happy eyeballs, default to interleave addresses by family # If using happy eyeballs, default to interleave addresses by family
interleave = 1 interleave = 1
@ -1437,6 +1446,9 @@ async def create_server(
raise ValueError( raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl') 'ssl_handshake_timeout is only meaningful with ssl')
if sock is not None:
_check_ssl_socket(sock)
if host is not None or port is not None: if host is not None or port is not None:
if sock is not None: if sock is not None:
raise ValueError( raise ValueError(
@ -1531,6 +1543,9 @@ async def connect_accepted_socket(
raise ValueError( raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl') 'ssl_handshake_timeout is only meaningful with ssl')
if sock is not None:
_check_ssl_socket(sock)
transport, protocol = await self._create_connection_transport( transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True, sock, protocol_factory, ssl, '', server_side=True,
ssl_handshake_timeout=ssl_handshake_timeout) ssl_handshake_timeout=ssl_handshake_timeout)

View file

@ -40,11 +40,6 @@ def _test_selector_event(selector, fd, event):
return bool(key.events & event) return bool(key.events & event)
def _check_ssl_socket(sock):
if ssl is not None and isinstance(sock, ssl.SSLSocket):
raise TypeError("Socket cannot be of type SSLSocket")
class BaseSelectorEventLoop(base_events.BaseEventLoop): class BaseSelectorEventLoop(base_events.BaseEventLoop):
"""Selector event loop. """Selector event loop.
@ -357,7 +352,7 @@ async def sock_recv(self, sock, n):
The maximum amount of data to be received at once is specified by The maximum amount of data to be received at once is specified by
nbytes. nbytes.
""" """
_check_ssl_socket(sock) base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0: if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking") raise ValueError("the socket must be non-blocking")
try: try:
@ -398,7 +393,7 @@ async def sock_recv_into(self, sock, buf):
The received data is written into *buf* (a writable buffer). The received data is written into *buf* (a writable buffer).
The return value is the number of bytes written. The return value is the number of bytes written.
""" """
_check_ssl_socket(sock) base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0: if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking") raise ValueError("the socket must be non-blocking")
try: try:
@ -439,7 +434,7 @@ async def sock_sendall(self, sock, data):
raised, and there is no way to determine how much data, if any, was raised, and there is no way to determine how much data, if any, was
successfully processed by the receiving end of the connection. successfully processed by the receiving end of the connection.
""" """
_check_ssl_socket(sock) base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0: if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking") raise ValueError("the socket must be non-blocking")
try: try:
@ -488,7 +483,7 @@ async def sock_connect(self, sock, address):
This method is a coroutine. This method is a coroutine.
""" """
_check_ssl_socket(sock) base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0: if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking") raise ValueError("the socket must be non-blocking")
@ -553,7 +548,7 @@ async def sock_accept(self, sock):
object usable to send and receive data on the connection, and address object usable to send and receive data on the connection, and address
is the address bound to the socket on the other end of the connection. is the address bound to the socket on the other end of the connection.
""" """
_check_ssl_socket(sock) base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0: if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking") raise ValueError("the socket must be non-blocking")
fut = self.create_future() fut = self.create_future()

View file

@ -0,0 +1,2 @@
Raise :exc:`TypeError` if :class:`ssl.SSLSocket` is passed to
transport-based APIs.