mirror of
https://github.com/python/cpython.git
synced 2025-12-08 06:10:17 +00:00
[3.10] Raise TypeError if SSLSocket is passed to asyncio transport-based methods (GH-31442). (GH-31443)
(cherry picked from commit 1f9d4c93af)
Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
parent
ea3e0421b0
commit
dde048819f
3 changed files with 22 additions and 10 deletions
|
|
@ -202,6 +202,11 @@ def _set_nodelay(sock):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _check_ssl_socket(sock):
|
||||||
|
if ssl is not None and isinstance(sock, ssl.SSLSocket):
|
||||||
|
raise TypeError("Socket cannot be of type SSLSocket")
|
||||||
|
|
||||||
|
|
||||||
class _SendfileFallbackProtocol(protocols.Protocol):
|
class _SendfileFallbackProtocol(protocols.Protocol):
|
||||||
def __init__(self, transp):
|
def __init__(self, transp):
|
||||||
if not isinstance(transp, transports._FlowControlMixin):
|
if not isinstance(transp, transports._FlowControlMixin):
|
||||||
|
|
@ -863,6 +868,7 @@ async def sock_sendfile(self, sock, file, offset=0, count=None,
|
||||||
*, fallback=True):
|
*, fallback=True):
|
||||||
if self._debug and sock.gettimeout() != 0:
|
if self._debug and sock.gettimeout() != 0:
|
||||||
raise ValueError("the socket must be non-blocking")
|
raise ValueError("the socket must be non-blocking")
|
||||||
|
_check_ssl_socket(sock)
|
||||||
self._check_sendfile_params(sock, file, offset, count)
|
self._check_sendfile_params(sock, file, offset, count)
|
||||||
try:
|
try:
|
||||||
return await self._sock_sendfile_native(sock, file,
|
return await self._sock_sendfile_native(sock, file,
|
||||||
|
|
@ -1004,6 +1010,9 @@ async def create_connection(
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'ssl_handshake_timeout is only meaningful with ssl')
|
'ssl_handshake_timeout is only meaningful with ssl')
|
||||||
|
|
||||||
|
if sock is not None:
|
||||||
|
_check_ssl_socket(sock)
|
||||||
|
|
||||||
if happy_eyeballs_delay is not None and interleave is None:
|
if happy_eyeballs_delay is not None and interleave is None:
|
||||||
# If using happy eyeballs, default to interleave addresses by family
|
# If using happy eyeballs, default to interleave addresses by family
|
||||||
interleave = 1
|
interleave = 1
|
||||||
|
|
@ -1437,6 +1446,9 @@ async def create_server(
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'ssl_handshake_timeout is only meaningful with ssl')
|
'ssl_handshake_timeout is only meaningful with ssl')
|
||||||
|
|
||||||
|
if sock is not None:
|
||||||
|
_check_ssl_socket(sock)
|
||||||
|
|
||||||
if host is not None or port is not None:
|
if host is not None or port is not None:
|
||||||
if sock is not None:
|
if sock is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
@ -1531,6 +1543,9 @@ async def connect_accepted_socket(
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'ssl_handshake_timeout is only meaningful with ssl')
|
'ssl_handshake_timeout is only meaningful with ssl')
|
||||||
|
|
||||||
|
if sock is not None:
|
||||||
|
_check_ssl_socket(sock)
|
||||||
|
|
||||||
transport, protocol = await self._create_connection_transport(
|
transport, protocol = await self._create_connection_transport(
|
||||||
sock, protocol_factory, ssl, '', server_side=True,
|
sock, protocol_factory, ssl, '', server_side=True,
|
||||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||||
|
|
|
||||||
|
|
@ -40,11 +40,6 @@ def _test_selector_event(selector, fd, event):
|
||||||
return bool(key.events & event)
|
return bool(key.events & event)
|
||||||
|
|
||||||
|
|
||||||
def _check_ssl_socket(sock):
|
|
||||||
if ssl is not None and isinstance(sock, ssl.SSLSocket):
|
|
||||||
raise TypeError("Socket cannot be of type SSLSocket")
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
||||||
"""Selector event loop.
|
"""Selector event loop.
|
||||||
|
|
||||||
|
|
@ -357,7 +352,7 @@ async def sock_recv(self, sock, n):
|
||||||
The maximum amount of data to be received at once is specified by
|
The maximum amount of data to be received at once is specified by
|
||||||
nbytes.
|
nbytes.
|
||||||
"""
|
"""
|
||||||
_check_ssl_socket(sock)
|
base_events._check_ssl_socket(sock)
|
||||||
if self._debug and sock.gettimeout() != 0:
|
if self._debug and sock.gettimeout() != 0:
|
||||||
raise ValueError("the socket must be non-blocking")
|
raise ValueError("the socket must be non-blocking")
|
||||||
try:
|
try:
|
||||||
|
|
@ -398,7 +393,7 @@ async def sock_recv_into(self, sock, buf):
|
||||||
The received data is written into *buf* (a writable buffer).
|
The received data is written into *buf* (a writable buffer).
|
||||||
The return value is the number of bytes written.
|
The return value is the number of bytes written.
|
||||||
"""
|
"""
|
||||||
_check_ssl_socket(sock)
|
base_events._check_ssl_socket(sock)
|
||||||
if self._debug and sock.gettimeout() != 0:
|
if self._debug and sock.gettimeout() != 0:
|
||||||
raise ValueError("the socket must be non-blocking")
|
raise ValueError("the socket must be non-blocking")
|
||||||
try:
|
try:
|
||||||
|
|
@ -439,7 +434,7 @@ async def sock_sendall(self, sock, data):
|
||||||
raised, and there is no way to determine how much data, if any, was
|
raised, and there is no way to determine how much data, if any, was
|
||||||
successfully processed by the receiving end of the connection.
|
successfully processed by the receiving end of the connection.
|
||||||
"""
|
"""
|
||||||
_check_ssl_socket(sock)
|
base_events._check_ssl_socket(sock)
|
||||||
if self._debug and sock.gettimeout() != 0:
|
if self._debug and sock.gettimeout() != 0:
|
||||||
raise ValueError("the socket must be non-blocking")
|
raise ValueError("the socket must be non-blocking")
|
||||||
try:
|
try:
|
||||||
|
|
@ -488,7 +483,7 @@ async def sock_connect(self, sock, address):
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
_check_ssl_socket(sock)
|
base_events._check_ssl_socket(sock)
|
||||||
if self._debug and sock.gettimeout() != 0:
|
if self._debug and sock.gettimeout() != 0:
|
||||||
raise ValueError("the socket must be non-blocking")
|
raise ValueError("the socket must be non-blocking")
|
||||||
|
|
||||||
|
|
@ -553,7 +548,7 @@ async def sock_accept(self, sock):
|
||||||
object usable to send and receive data on the connection, and address
|
object usable to send and receive data on the connection, and address
|
||||||
is the address bound to the socket on the other end of the connection.
|
is the address bound to the socket on the other end of the connection.
|
||||||
"""
|
"""
|
||||||
_check_ssl_socket(sock)
|
base_events._check_ssl_socket(sock)
|
||||||
if self._debug and sock.gettimeout() != 0:
|
if self._debug and sock.gettimeout() != 0:
|
||||||
raise ValueError("the socket must be non-blocking")
|
raise ValueError("the socket must be non-blocking")
|
||||||
fut = self.create_future()
|
fut = self.create_future()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
Raise :exc:`TypeError` if :class:`ssl.SSLSocket` is passed to
|
||||||
|
transport-based APIs.
|
||||||
Loading…
Add table
Add a link
Reference in a new issue