mirror of
				https://github.com/python/cpython.git
				synced 2025-10-31 05:31:20 +00:00 
			
		
		
		
	 95bcbcb7c8
			
		
	
	
		95bcbcb7c8
		
			
		
	
	
	
	
		
			
			gh-118950: Fix SSLProtocol.connection_lost not being called when OSError is thrown (GH-118960)
(cherry picked from commit 3f24bde0b6)
Co-authored-by: Javad Shafique <javadshafique@hotmail.com>
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
		
	
			
		
			
				
	
	
		
			929 lines
		
	
	
	
		
			31 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			929 lines
		
	
	
	
		
			31 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Contains code from https://github.com/MagicStack/uvloop/tree/v0.16.0
 | |
| # SPDX-License-Identifier: PSF-2.0 AND (MIT OR Apache-2.0)
 | |
| # SPDX-FileCopyrightText: Copyright (c) 2015-2021 MagicStack Inc.  http://magic.io
 | |
| 
 | |
| import collections
 | |
| import enum
 | |
| import warnings
 | |
| try:
 | |
|     import ssl
 | |
| except ImportError:  # pragma: no cover
 | |
|     ssl = None
 | |
| 
 | |
| from . import constants
 | |
| from . import exceptions
 | |
| from . import protocols
 | |
| from . import transports
 | |
| from .log import logger
 | |
| 
 | |
if ssl is not None:
    # Exceptions that the handshake, read, write and shutdown paths of
    # SSLProtocol treat as "cannot make progress yet" rather than as fatal
    # errors.  Both are ssl.SSLError subclasses, so every try statement must
    # list this tuple before a plain ``except ssl.SSLError`` clause.
    # NOTE(review): SSLSyscallError is grouped with SSLWantReadError here;
    # confirm that treating it as retryable is intended.
    SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError)
 | |
| 
 | |
| 
 | |
class SSLProtocolState(enum.Enum):
    """States of the SSL layer sitting on top of the raw transport.

    SSLProtocol._set_state() only permits the forward chain
    UNWRAPPED -> DO_HANDSHAKE -> WRAPPED -> FLUSHING -> SHUTDOWN, plus a
    transition from any state back to UNWRAPPED.
    """
    UNWRAPPED = "UNWRAPPED"
    DO_HANDSHAKE = "DO_HANDSHAKE"
    WRAPPED = "WRAPPED"
    FLUSHING = "FLUSHING"
    SHUTDOWN = "SHUTDOWN"
 | |
| 
 | |
| 
 | |
class AppProtocolState(enum.Enum):
    """Which application-protocol callbacks SSLProtocol has delivered so far."""
    # This tracks the state of app protocol (https://git.io/fj59P):
    # NOTE(review): git.io short links are deprecated by GitHub; consider
    # replacing this with the full target URL.
    #
    #     INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST
    #
    # * cm: connection_made()
    # * dr: data_received()
    # * er: eof_received()
    # * cl: connection_lost()

    STATE_INIT = "STATE_INIT"
    STATE_CON_MADE = "STATE_CON_MADE"
    STATE_EOF = "STATE_EOF"
    STATE_CON_LOST = "STATE_CON_LOST"
 | |
| 
 | |
| 
 | |
def _create_transport_context(server_side, server_hostname):
    """Build the default client SSLContext used when none was supplied.

    Callers may pass ``ssl=True`` to get a secure default context; the
    context argument then arrives as None and this function fills it in
    with ``ssl.create_default_context()``.

    Raises:
        ValueError: on the server side, where a default context is not
            usable and an explicit SSLContext is required.
    """
    if server_side:
        raise ValueError('Server side SSL needs a valid SSLContext')

    context = ssl.create_default_context()
    if not server_hostname:
        # There is no name to match the peer certificate against.
        context.check_hostname = False
    return context
 | |
| 
 | |
| 
 | |
def add_flowcontrol_defaults(high, low, kb):
    """Resolve high/low water marks, filling in missing values.

    * neither given: high = ``kb`` KiB and low = high // 4
    * only low given: high = 4 * low
    * only high given: low = high // 4

    Returns:
        The ``(high, low)`` tuple.

    Raises:
        ValueError: unless ``high >= low >= 0``.
    """
    if high is not None:
        hi = high
    elif low is not None:
        hi = 4 * low
    else:
        hi = kb * 1024
    lo = hi // 4 if low is None else low

    if not hi >= lo >= 0:
        raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
                         (hi, lo))
    return hi, lo
 | |
| 
 | |
| 
 | |
class _SSLProtocolTransport(transports._FlowControlMixin,
                            transports.Transport):
    """Transport handed to the application protocol of an SSL connection.

    Almost every operation is forwarded to the owning SSLProtocol
    (``self._ssl_protocol``), which encrypts/decrypts the data and drives
    the underlying transport.
    """

    _start_tls_compatible = True
    _sendfile_compatible = constants._SendfileMode.FALLBACK

    def __init__(self, loop, ssl_protocol):
        self._loop = loop
        self._ssl_protocol = ssl_protocol
        self._closed = False

    def get_extra_info(self, name, default=None):
        """Get optional transport information."""
        return self._ssl_protocol._get_extra_info(name, default)

    def set_protocol(self, protocol):
        """Replace the application protocol that receives decrypted data."""
        self._ssl_protocol._set_app_protocol(protocol)

    def get_protocol(self):
        """Return the current application protocol."""
        return self._ssl_protocol._app_protocol

    def is_closing(self):
        """Return True if this transport was closed or the underlying
        transport is closing."""
        return self._closed or self._ssl_protocol._is_transport_closing()

    def close(self):
        """Close the transport.

        Buffered data will be flushed asynchronously.  No more data
        will be received.  After all buffered data is flushed, the
        protocol's connection_lost() method will (eventually) be called
        with None as its argument.
        """
        if not self._closed:
            self._closed = True
            self._ssl_protocol._start_shutdown()
        else:
            # A repeated close() drops the reference to the SSL protocol.
            # NOTE(review): presumably to break the transport <-> protocol
            # reference cycle; afterwards methods that dereference
            # self._ssl_protocol (e.g. is_reading()) raise AttributeError.
            self._ssl_protocol = None

    def __del__(self, _warnings=warnings):
        # Warn about transports garbage-collected without being closed.
        if not self._closed:
            self._closed = True
            _warnings.warn(
                "unclosed transport <asyncio._SSLProtocolTransport "
                "object>", ResourceWarning)

    def is_reading(self):
        """Return True unless the application has paused reading."""
        return not self._ssl_protocol._app_reading_paused

    def pause_reading(self):
        """Pause the receiving end.

        No data will be passed to the protocol's data_received()
        method until resume_reading() is called.
        """
        self._ssl_protocol._pause_reading()

    def resume_reading(self):
        """Resume the receiving end.

        Data received will once again be passed to the protocol's
        data_received() method.
        """
        self._ssl_protocol._resume_reading()

    def set_write_buffer_limits(self, high=None, low=None):
        """Set the high- and low-water limits for write flow control.

        These two values control when to call the protocol's
        pause_writing() and resume_writing() methods.  If specified,
        the low-water limit must be less than or equal to the
        high-water limit.  Neither value can be negative.

        The defaults are implementation-specific.  If only the
        high-water limit is given, the low-water limit defaults to an
        implementation-specific value less than or equal to the
        high-water limit.  Setting high to zero forces low to zero as
        well, and causes pause_writing() to be called whenever the
        buffer becomes non-empty.  Setting low to zero causes
        resume_writing() to be called only once the buffer is empty.
        Use of zero for either limit is generally sub-optimal as it
        reduces opportunities for doing I/O and computation
        concurrently.
        """
        self._ssl_protocol._set_write_buffer_limits(high, low)
        # Re-evaluate pause/resume immediately against the new limits.
        self._ssl_protocol._control_app_writing()

    def get_write_buffer_limits(self):
        """Return the write flow-control limits as ``(low, high)``."""
        return (self._ssl_protocol._outgoing_low_water,
                self._ssl_protocol._outgoing_high_water)

    def get_write_buffer_size(self):
        """Return the current size of the write buffers."""
        return self._ssl_protocol._get_write_buffer_size()

    def set_read_buffer_limits(self, high=None, low=None):
        """Set the high- and low-water limits for read flow control.

        These two values control when to call the upstream transport's
        pause_reading() and resume_reading() methods.  If specified,
        the low-water limit must be less than or equal to the
        high-water limit.  Neither value can be negative.

        The defaults are implementation-specific.  If only the
        high-water limit is given, the low-water limit defaults to an
        implementation-specific value less than or equal to the
        high-water limit.  Setting high to zero forces low to zero as
        well, and causes pause_reading() to be called whenever the
        buffer becomes non-empty.  Setting low to zero causes
        resume_reading() to be called only once the buffer is empty.
        Use of zero for either limit is generally sub-optimal as it
        reduces opportunities for doing I/O and computation
        concurrently.
        """
        self._ssl_protocol._set_read_buffer_limits(high, low)
        # Re-evaluate pause/resume immediately against the new limits.
        self._ssl_protocol._control_ssl_reading()

    def get_read_buffer_limits(self):
        """Return the read flow-control limits as ``(low, high)``."""
        return (self._ssl_protocol._incoming_low_water,
                self._ssl_protocol._incoming_high_water)

    def get_read_buffer_size(self):
        """Return the current size of the read buffer."""
        return self._ssl_protocol._get_read_buffer_size()

    @property
    def _protocol_paused(self):
        # Required for sendfile fallback pause_writing/resume_writing logic
        return self._ssl_protocol._app_writing_paused

    def write(self, data):
        """Write some data bytes to the transport.

        This does not block; it buffers the data and arranges for it
        to be sent out asynchronously.

        Raises:
            TypeError: if *data* is not bytes, bytearray or memoryview.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError(f"data: expecting a bytes-like instance, "
                            f"got {type(data).__name__}")
        if not data:
            return
        self._ssl_protocol._write_appdata((data,))

    def writelines(self, list_of_data):
        """Write a list (or any iterable) of data bytes to the transport.

        All items are handed to the SSL protocol in a single call and
        queued in order; nothing is concatenated here.  Unlike write(),
        this method does not type-check the items or skip empty ones.
        """
        self._ssl_protocol._write_appdata(list_of_data)

    def write_eof(self):
        """Close the write end after flushing buffered data.

        This raises :exc:`NotImplementedError` right now.
        """
        raise NotImplementedError

    def can_write_eof(self):
        """Return True if this transport supports write_eof(), False if not."""
        return False

    def abort(self):
        """Close the transport immediately.

        Buffered data will be lost.  No more data will be received.
        The protocol's connection_lost() method will (eventually) be
        called with None as its argument.
        """
        self._force_close(None)

    def _force_close(self, exc):
        """Mark the transport closed and abort the SSL protocol with *exc*.

        The abort is skipped when a repeated close() has already dropped
        the reference to the SSL protocol.
        """
        self._closed = True
        if self._ssl_protocol is not None:
            self._ssl_protocol._abort(exc)

    def _test__append_write_backlog(self, data):
        # for test only
        self._ssl_protocol._write_backlog.append(data)
        self._ssl_protocol._write_buffer_size += len(data)
 | |
| 
 | |
| 
 | |
| class SSLProtocol(protocols.BufferedProtocol):
 | |
    max_size = 256 * 1024   # Buffer size passed to read()

    # loop.time() at the start of the handshake; recorded only when the
    # loop is in debug mode (used to log the handshake duration).
    _handshake_start_time = None
    # Handle returned by loop.call_later() for the handshake timeout, or
    # None when no timeout is armed.
    _handshake_timeout_handle = None
    # Handle returned by loop.call_later() for the shutdown timeout, or
    # None when no timeout is armed.
    _shutdown_timeout_handle = None
 | |
| 
 | |
    def __init__(self, loop, app_protocol, sslcontext, waiter,
                 server_side=False, server_hostname=None,
                 call_connection_made=True,
                 ssl_handshake_timeout=None,
                 ssl_shutdown_timeout=None):
        """Set up the SSL state machine between a transport and *app_protocol*.

        Args:
            loop: event loop used for callbacks and timeouts.
            app_protocol: application protocol that receives decrypted data.
            sslcontext: SSLContext to wrap with; if falsy, a default client
                context is created via _create_transport_context().
            waiter: future resolved by _wakeup_waiter() (result or
                exception); may be None.
            server_side: whether this end acts as the TLS server.
            server_hostname: peer name for the client side; stored as None
                when server_side is true or the value is falsy.
            call_connection_made: if false, the app protocol is treated as
                already connected, so connection_made() is not called after
                the handshake.
            ssl_handshake_timeout: seconds allowed for the handshake; None
                means constants.SSL_HANDSHAKE_TIMEOUT.
            ssl_shutdown_timeout: seconds allowed for the SSL shutdown; None
                means constants.SSL_SHUTDOWN_TIMEOUT.

        Raises:
            RuntimeError: the ssl module is not available.
            ValueError: a timeout is not positive, or server_side is true
                and no sslcontext was given.
        """
        if ssl is None:
            raise RuntimeError("stdlib ssl module not available")

        # Receive buffer handed out by get_buffer() to the raw transport.
        self._ssl_buffer = bytearray(self.max_size)
        self._ssl_buffer_view = memoryview(self._ssl_buffer)

        if ssl_handshake_timeout is None:
            ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
        elif ssl_handshake_timeout <= 0:
            raise ValueError(
                f"ssl_handshake_timeout should be a positive number, "
                f"got {ssl_handshake_timeout}")
        if ssl_shutdown_timeout is None:
            ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT
        elif ssl_shutdown_timeout <= 0:
            raise ValueError(
                f"ssl_shutdown_timeout should be a positive number, "
                f"got {ssl_shutdown_timeout}")

        if not sslcontext:
            sslcontext = _create_transport_context(
                server_side, server_hostname)

        self._server_side = server_side
        if server_hostname and not server_side:
            self._server_hostname = server_hostname
        else:
            self._server_hostname = None
        self._sslcontext = sslcontext
        # SSL-specific extra info. More info are set when the handshake
        # completes.
        self._extra = dict(sslcontext=sslcontext)

        # App data write buffering
        self._write_backlog = collections.deque()
        self._write_buffer_size = 0

        self._waiter = waiter
        self._loop = loop
        self._set_app_protocol(app_protocol)
        self._app_transport = None
        self._app_transport_created = False
        # transport, ex: SelectorSocketTransport
        self._transport = None
        self._ssl_handshake_timeout = ssl_handshake_timeout
        self._ssl_shutdown_timeout = ssl_shutdown_timeout
        # SSL and state machine
        # Ciphertext flows in through _incoming and out through _outgoing;
        # the SSLObject below reads/writes these in-memory BIOs.
        self._incoming = ssl.MemoryBIO()
        self._outgoing = ssl.MemoryBIO()
        self._state = SSLProtocolState.UNWRAPPED
        self._conn_lost = 0  # Set when connection_lost called
        if call_connection_made:
            self._app_state = AppProtocolState.STATE_INIT
        else:
            self._app_state = AppProtocolState.STATE_CON_MADE
        self._sslobj = self._sslcontext.wrap_bio(
            self._incoming, self._outgoing,
            server_side=self._server_side,
            server_hostname=self._server_hostname)

        # Flow Control

        self._ssl_writing_paused = False

        self._app_reading_paused = False

        self._ssl_reading_paused = False
        self._incoming_high_water = 0
        self._incoming_low_water = 0
        self._set_read_buffer_limits()
        self._eof_received = False

        self._app_writing_paused = False
        self._outgoing_high_water = 0
        self._outgoing_low_water = 0
        self._set_write_buffer_limits()
        # Create the application transport eagerly.
        self._get_app_transport()
 | |
| 
 | |
|     def _set_app_protocol(self, app_protocol):
 | |
|         self._app_protocol = app_protocol
 | |
|         # Make fast hasattr check first
 | |
|         if (hasattr(app_protocol, 'get_buffer') and
 | |
|                 isinstance(app_protocol, protocols.BufferedProtocol)):
 | |
|             self._app_protocol_get_buffer = app_protocol.get_buffer
 | |
|             self._app_protocol_buffer_updated = app_protocol.buffer_updated
 | |
|             self._app_protocol_is_buffer = True
 | |
|         else:
 | |
|             self._app_protocol_is_buffer = False
 | |
| 
 | |
|     def _wakeup_waiter(self, exc=None):
 | |
|         if self._waiter is None:
 | |
|             return
 | |
|         if not self._waiter.cancelled():
 | |
|             if exc is not None:
 | |
|                 self._waiter.set_exception(exc)
 | |
|             else:
 | |
|                 self._waiter.set_result(None)
 | |
|         self._waiter = None
 | |
| 
 | |
|     def _get_app_transport(self):
 | |
|         if self._app_transport is None:
 | |
|             if self._app_transport_created:
 | |
|                 raise RuntimeError('Creating _SSLProtocolTransport twice')
 | |
|             self._app_transport = _SSLProtocolTransport(self._loop, self)
 | |
|             self._app_transport_created = True
 | |
|         return self._app_transport
 | |
| 
 | |
|     def _is_transport_closing(self):
 | |
|         return self._transport is not None and self._transport.is_closing()
 | |
| 
 | |
|     def connection_made(self, transport):
 | |
|         """Called when the low-level connection is made.
 | |
| 
 | |
|         Start the SSL handshake.
 | |
|         """
 | |
|         self._transport = transport
 | |
|         self._start_handshake()
 | |
| 
 | |
    def connection_lost(self, exc):
        """Called when the low-level connection is lost or closed.

        The argument is an exception object or None (the latter
        meaning a regular EOF is received or the connection was
        aborted or closed).

        Discards queued outgoing data and marks the app transport closed.
        The app protocol's connection_lost(exc) is scheduled with call_soon()
        only when the handshake is not in progress and the app state is
        STATE_CON_MADE or STATE_EOF.  The method then resets the SSL state to
        UNWRAPPED, drops the references to the transports and the app
        protocol, wakes the waiter with *exc*, and cancels any pending
        shutdown/handshake timeout.
        """
        # Unsent application data and already-encrypted output are useless
        # now that the connection is gone.
        self._write_backlog.clear()
        self._outgoing.read()
        # NOTE(review): _write_buffer_size is not reset together with the
        # backlog; confirm nothing reads it after this point.
        self._conn_lost += 1

        # Just mark the app transport as closed so that its __del__
        # doesn't emit the "unclosed transport" ResourceWarning.
        if self._app_transport is not None:
            self._app_transport._closed = True

        # During the handshake the app protocol is not notified here; the
        # failure reaches the caller through the waiter (see below).
        if self._state != SSLProtocolState.DO_HANDSHAKE:
            if (
                self._app_state == AppProtocolState.STATE_CON_MADE or
                self._app_state == AppProtocolState.STATE_EOF
            ):
                self._app_state = AppProtocolState.STATE_CON_LOST
                self._loop.call_soon(self._app_protocol.connection_lost, exc)
        self._set_state(SSLProtocolState.UNWRAPPED)
        self._transport = None
        self._app_transport = None
        self._app_protocol = None
        self._wakeup_waiter(exc)

        if self._shutdown_timeout_handle:
            self._shutdown_timeout_handle.cancel()
            self._shutdown_timeout_handle = None
        if self._handshake_timeout_handle:
            self._handshake_timeout_handle.cancel()
            self._handshake_timeout_handle = None
 | |
| 
 | |
|     def get_buffer(self, n):
 | |
|         want = n
 | |
|         if want <= 0 or want > self.max_size:
 | |
|             want = self.max_size
 | |
|         if len(self._ssl_buffer) < want:
 | |
|             self._ssl_buffer = bytearray(want)
 | |
|             self._ssl_buffer_view = memoryview(self._ssl_buffer)
 | |
|         return self._ssl_buffer_view
 | |
| 
 | |
|     def buffer_updated(self, nbytes):
 | |
|         self._incoming.write(self._ssl_buffer_view[:nbytes])
 | |
| 
 | |
|         if self._state == SSLProtocolState.DO_HANDSHAKE:
 | |
|             self._do_handshake()
 | |
| 
 | |
|         elif self._state == SSLProtocolState.WRAPPED:
 | |
|             self._do_read()
 | |
| 
 | |
|         elif self._state == SSLProtocolState.FLUSHING:
 | |
|             self._do_flush()
 | |
| 
 | |
|         elif self._state == SSLProtocolState.SHUTDOWN:
 | |
|             self._do_shutdown()
 | |
| 
 | |
    def eof_received(self):
        """Called when the other end of the low-level stream
        is half-closed.

        If this returns a false value (including None), the transport
        will close itself.  If it returns a true value, closing the
        transport is up to the protocol.

        Behaviour per state:

        * DO_HANDSHAKE: the handshake fails with ConnectionResetError.
        * WRAPPED: switch to FLUSHING.  If the app has paused reading,
          return True so the transport stays open; otherwise flush
          immediately.
        * FLUSHING: push out queued writes, then move to SHUTDOWN.
        * SHUTDOWN: continue the shutdown.

        Any exception closes the underlying transport and is re-raised.
        """
        # Recorded so that _do_shutdown() skips unwrap() from now on.
        self._eof_received = True
        try:
            if self._loop.get_debug():
                logger.debug("%r received EOF", self)

            if self._state == SSLProtocolState.DO_HANDSHAKE:
                self._on_handshake_complete(ConnectionResetError)

            elif self._state == SSLProtocolState.WRAPPED:
                self._set_state(SSLProtocolState.FLUSHING)
                if self._app_reading_paused:
                    # NOTE(review): presumably the flush is resumed when the
                    # app resumes reading; confirm in _resume_reading().
                    return True
                else:
                    self._do_flush()

            elif self._state == SSLProtocolState.FLUSHING:
                self._do_write()
                self._set_state(SSLProtocolState.SHUTDOWN)
                self._do_shutdown()

            elif self._state == SSLProtocolState.SHUTDOWN:
                self._do_shutdown()

        except Exception:
            self._transport.close()
            raise
 | |
| 
 | |
|     def _get_extra_info(self, name, default=None):
 | |
|         if name in self._extra:
 | |
|             return self._extra[name]
 | |
|         elif self._transport is not None:
 | |
|             return self._transport.get_extra_info(name, default)
 | |
|         else:
 | |
|             return default
 | |
| 
 | |
|     def _set_state(self, new_state):
 | |
|         allowed = False
 | |
| 
 | |
|         if new_state == SSLProtocolState.UNWRAPPED:
 | |
|             allowed = True
 | |
| 
 | |
|         elif (
 | |
|             self._state == SSLProtocolState.UNWRAPPED and
 | |
|             new_state == SSLProtocolState.DO_HANDSHAKE
 | |
|         ):
 | |
|             allowed = True
 | |
| 
 | |
|         elif (
 | |
|             self._state == SSLProtocolState.DO_HANDSHAKE and
 | |
|             new_state == SSLProtocolState.WRAPPED
 | |
|         ):
 | |
|             allowed = True
 | |
| 
 | |
|         elif (
 | |
|             self._state == SSLProtocolState.WRAPPED and
 | |
|             new_state == SSLProtocolState.FLUSHING
 | |
|         ):
 | |
|             allowed = True
 | |
| 
 | |
|         elif (
 | |
|             self._state == SSLProtocolState.FLUSHING and
 | |
|             new_state == SSLProtocolState.SHUTDOWN
 | |
|         ):
 | |
|             allowed = True
 | |
| 
 | |
|         if allowed:
 | |
|             self._state = new_state
 | |
| 
 | |
|         else:
 | |
|             raise RuntimeError(
 | |
|                 'cannot switch state from {} to {}'.format(
 | |
|                     self._state, new_state))
 | |
| 
 | |
|     # Handshake flow
 | |
| 
 | |
|     def _start_handshake(self):
 | |
|         if self._loop.get_debug():
 | |
|             logger.debug("%r starts SSL handshake", self)
 | |
|             self._handshake_start_time = self._loop.time()
 | |
|         else:
 | |
|             self._handshake_start_time = None
 | |
| 
 | |
|         self._set_state(SSLProtocolState.DO_HANDSHAKE)
 | |
| 
 | |
|         # start handshake timeout count down
 | |
|         self._handshake_timeout_handle = \
 | |
|             self._loop.call_later(self._ssl_handshake_timeout,
 | |
|                                   self._check_handshake_timeout)
 | |
| 
 | |
|         self._do_handshake()
 | |
| 
 | |
|     def _check_handshake_timeout(self):
 | |
|         if self._state == SSLProtocolState.DO_HANDSHAKE:
 | |
|             msg = (
 | |
|                 f"SSL handshake is taking longer than "
 | |
|                 f"{self._ssl_handshake_timeout} seconds: "
 | |
|                 f"aborting the connection"
 | |
|             )
 | |
|             self._fatal_error(ConnectionAbortedError(msg))
 | |
| 
 | |
    def _do_handshake(self):
        """Advance the SSL handshake using whatever input is buffered.

        An SSLAgainErrors exception means the handshake needs more data from
        the peer: the bytes produced so far are sent, and buffer_updated()
        calls back into this method once more input arrives.  Any other
        SSLError, or success, finishes via _on_handshake_complete().
        """
        try:
            self._sslobj.do_handshake()
        # Must precede ssl.SSLError: the SSLAgainErrors members subclass it.
        except SSLAgainErrors:
            self._process_outgoing()
        except ssl.SSLError as exc:
            self._on_handshake_complete(exc)
        else:
            self._on_handshake_complete(None)
 | |
| 
 | |
|     def _on_handshake_complete(self, handshake_exc):
 | |
|         if self._handshake_timeout_handle is not None:
 | |
|             self._handshake_timeout_handle.cancel()
 | |
|             self._handshake_timeout_handle = None
 | |
| 
 | |
|         sslobj = self._sslobj
 | |
|         try:
 | |
|             if handshake_exc is None:
 | |
|                 self._set_state(SSLProtocolState.WRAPPED)
 | |
|             else:
 | |
|                 raise handshake_exc
 | |
| 
 | |
|             peercert = sslobj.getpeercert()
 | |
|         except Exception as exc:
 | |
|             handshake_exc = None
 | |
|             self._set_state(SSLProtocolState.UNWRAPPED)
 | |
|             if isinstance(exc, ssl.CertificateError):
 | |
|                 msg = 'SSL handshake failed on verifying the certificate'
 | |
|             else:
 | |
|                 msg = 'SSL handshake failed'
 | |
|             self._fatal_error(exc, msg)
 | |
|             self._wakeup_waiter(exc)
 | |
|             return
 | |
| 
 | |
|         if self._loop.get_debug():
 | |
|             dt = self._loop.time() - self._handshake_start_time
 | |
|             logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)
 | |
| 
 | |
|         # Add extra info that becomes available after handshake.
 | |
|         self._extra.update(peercert=peercert,
 | |
|                            cipher=sslobj.cipher(),
 | |
|                            compression=sslobj.compression(),
 | |
|                            ssl_object=sslobj)
 | |
|         if self._app_state == AppProtocolState.STATE_INIT:
 | |
|             self._app_state = AppProtocolState.STATE_CON_MADE
 | |
|             self._app_protocol.connection_made(self._get_app_transport())
 | |
|         self._wakeup_waiter()
 | |
|         self._do_read()
 | |
| 
 | |
|     # Shutdown flow
 | |
| 
 | |
|     def _start_shutdown(self):
 | |
|         if (
 | |
|             self._state in (
 | |
|                 SSLProtocolState.FLUSHING,
 | |
|                 SSLProtocolState.SHUTDOWN,
 | |
|                 SSLProtocolState.UNWRAPPED
 | |
|             )
 | |
|         ):
 | |
|             return
 | |
|         if self._app_transport is not None:
 | |
|             self._app_transport._closed = True
 | |
|         if self._state == SSLProtocolState.DO_HANDSHAKE:
 | |
|             self._abort(None)
 | |
|         else:
 | |
|             self._set_state(SSLProtocolState.FLUSHING)
 | |
|             self._shutdown_timeout_handle = self._loop.call_later(
 | |
|                 self._ssl_shutdown_timeout,
 | |
|                 self._check_shutdown_timeout
 | |
|             )
 | |
|             self._do_flush()
 | |
| 
 | |
|     def _check_shutdown_timeout(self):
 | |
|         if (
 | |
|             self._state in (
 | |
|                 SSLProtocolState.FLUSHING,
 | |
|                 SSLProtocolState.SHUTDOWN
 | |
|             )
 | |
|         ):
 | |
|             self._transport._force_close(
 | |
|                 exceptions.TimeoutError('SSL shutdown timed out'))
 | |
| 
 | |
    def _do_flush(self):
        """Deliver any remaining decrypted data, then start the SSL shutdown.

        Expected to run in the FLUSHING state: _set_state() only allows
        SHUTDOWN to follow FLUSHING and raises RuntimeError otherwise.
        """
        self._do_read()
        self._set_state(SSLProtocolState.SHUTDOWN)
        self._do_shutdown()
 | |
| 
 | |
    def _do_shutdown(self):
        """Drive the SSL close_notify exchange.

        * If the peer has not sent EOF, unwrap() is attempted.
        * SSLAgainErrors means the exchange is incomplete: pending output is
          sent and more input is awaited.
        * Any other SSLError ends via _on_shutdown_complete(exc).
        * On success, or when EOF was already received (unwrap() is
          skipped), the final output is sent, _call_eof_received() runs and
          _on_shutdown_complete(None) is called.
        """
        try:
            if not self._eof_received:
                self._sslobj.unwrap()
        # Must precede ssl.SSLError: the SSLAgainErrors members subclass it.
        except SSLAgainErrors:
            self._process_outgoing()
        except ssl.SSLError as exc:
            self._on_shutdown_complete(exc)
        else:
            self._process_outgoing()
            self._call_eof_received()
            self._on_shutdown_complete(None)
 | |
| 
 | |
|     def _on_shutdown_complete(self, shutdown_exc):
 | |
|         if self._shutdown_timeout_handle is not None:
 | |
|             self._shutdown_timeout_handle.cancel()
 | |
|             self._shutdown_timeout_handle = None
 | |
| 
 | |
|         if shutdown_exc:
 | |
|             self._fatal_error(shutdown_exc)
 | |
|         else:
 | |
|             self._loop.call_soon(self._transport.close)
 | |
| 
 | |
|     def _abort(self, exc):
 | |
|         self._set_state(SSLProtocolState.UNWRAPPED)
 | |
|         if self._transport is not None:
 | |
|             self._transport._force_close(exc)
 | |
| 
 | |
|     # Outgoing flow
 | |
| 
 | |
|     def _write_appdata(self, list_of_data):
 | |
|         if (
 | |
|             self._state in (
 | |
|                 SSLProtocolState.FLUSHING,
 | |
|                 SSLProtocolState.SHUTDOWN,
 | |
|                 SSLProtocolState.UNWRAPPED
 | |
|             )
 | |
|         ):
 | |
|             if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
 | |
|                 logger.warning('SSL connection is closed')
 | |
|             self._conn_lost += 1
 | |
|             return
 | |
| 
 | |
|         for data in list_of_data:
 | |
|             self._write_backlog.append(data)
 | |
|             self._write_buffer_size += len(data)
 | |
| 
 | |
|         try:
 | |
|             if self._state == SSLProtocolState.WRAPPED:
 | |
|                 self._do_write()
 | |
| 
 | |
|         except Exception as ex:
 | |
|             self._fatal_error(ex, 'Fatal error on SSL protocol')
 | |
| 
 | |
|     def _do_write(self):
 | |
|         try:
 | |
|             while self._write_backlog:
 | |
|                 data = self._write_backlog[0]
 | |
|                 count = self._sslobj.write(data)
 | |
|                 data_len = len(data)
 | |
|                 if count < data_len:
 | |
|                     self._write_backlog[0] = data[count:]
 | |
|                     self._write_buffer_size -= count
 | |
|                 else:
 | |
|                     del self._write_backlog[0]
 | |
|                     self._write_buffer_size -= data_len
 | |
|         except SSLAgainErrors:
 | |
|             pass
 | |
|         self._process_outgoing()
 | |
| 
 | |
|     def _process_outgoing(self):
 | |
|         if not self._ssl_writing_paused:
 | |
|             data = self._outgoing.read()
 | |
|             if len(data):
 | |
|                 self._transport.write(data)
 | |
|         self._control_app_writing()
 | |
| 
 | |
|     # Incoming flow
 | |
| 
 | |
|     def _do_read(self):
 | |
|         if (
 | |
|             self._state not in (
 | |
|                 SSLProtocolState.WRAPPED,
 | |
|                 SSLProtocolState.FLUSHING,
 | |
|             )
 | |
|         ):
 | |
|             return
 | |
|         try:
 | |
|             if not self._app_reading_paused:
 | |
|                 if self._app_protocol_is_buffer:
 | |
|                     self._do_read__buffered()
 | |
|                 else:
 | |
|                     self._do_read__copied()
 | |
|                 if self._write_backlog:
 | |
|                     self._do_write()
 | |
|                 else:
 | |
|                     self._process_outgoing()
 | |
|             self._control_ssl_reading()
 | |
|         except Exception as ex:
 | |
|             self._fatal_error(ex, 'Fatal error on SSL protocol')
 | |
| 
 | |
|     def _do_read__buffered(self):
 | |
|         offset = 0
 | |
|         count = 1
 | |
| 
 | |
|         buf = self._app_protocol_get_buffer(self._get_read_buffer_size())
 | |
|         wants = len(buf)
 | |
| 
 | |
|         try:
 | |
|             count = self._sslobj.read(wants, buf)
 | |
| 
 | |
|             if count > 0:
 | |
|                 offset = count
 | |
|                 while offset < wants:
 | |
|                     count = self._sslobj.read(wants - offset, buf[offset:])
 | |
|                     if count > 0:
 | |
|                         offset += count
 | |
|                     else:
 | |
|                         break
 | |
|                 else:
 | |
|                     self._loop.call_soon(self._do_read)
 | |
|         except SSLAgainErrors:
 | |
|             pass
 | |
|         if offset > 0:
 | |
|             self._app_protocol_buffer_updated(offset)
 | |
|         if not count:
 | |
|             # close_notify
 | |
|             self._call_eof_received()
 | |
|             self._start_shutdown()
 | |
| 
 | |
|     def _do_read__copied(self):
 | |
|         chunk = b'1'
 | |
|         zero = True
 | |
|         one = False
 | |
| 
 | |
|         try:
 | |
|             while True:
 | |
|                 chunk = self._sslobj.read(self.max_size)
 | |
|                 if not chunk:
 | |
|                     break
 | |
|                 if zero:
 | |
|                     zero = False
 | |
|                     one = True
 | |
|                     first = chunk
 | |
|                 elif one:
 | |
|                     one = False
 | |
|                     data = [first, chunk]
 | |
|                 else:
 | |
|                     data.append(chunk)
 | |
|         except SSLAgainErrors:
 | |
|             pass
 | |
|         if one:
 | |
|             self._app_protocol.data_received(first)
 | |
|         elif not zero:
 | |
|             self._app_protocol.data_received(b''.join(data))
 | |
|         if not chunk:
 | |
|             # close_notify
 | |
|             self._call_eof_received()
 | |
|             self._start_shutdown()
 | |
| 
 | |
|     def _call_eof_received(self):
 | |
|         try:
 | |
|             if self._app_state == AppProtocolState.STATE_CON_MADE:
 | |
|                 self._app_state = AppProtocolState.STATE_EOF
 | |
|                 keep_open = self._app_protocol.eof_received()
 | |
|                 if keep_open:
 | |
|                     logger.warning('returning true from eof_received() '
 | |
|                                    'has no effect when using ssl')
 | |
|         except (KeyboardInterrupt, SystemExit):
 | |
|             raise
 | |
|         except BaseException as ex:
 | |
|             self._fatal_error(ex, 'Error calling eof_received()')
 | |
| 
 | |
|     # Flow control for writes from APP socket
 | |
| 
 | |
|     def _control_app_writing(self):
 | |
|         size = self._get_write_buffer_size()
 | |
|         if size >= self._outgoing_high_water and not self._app_writing_paused:
 | |
|             self._app_writing_paused = True
 | |
|             try:
 | |
|                 self._app_protocol.pause_writing()
 | |
|             except (KeyboardInterrupt, SystemExit):
 | |
|                 raise
 | |
|             except BaseException as exc:
 | |
|                 self._loop.call_exception_handler({
 | |
|                     'message': 'protocol.pause_writing() failed',
 | |
|                     'exception': exc,
 | |
|                     'transport': self._app_transport,
 | |
|                     'protocol': self,
 | |
|                 })
 | |
|         elif size <= self._outgoing_low_water and self._app_writing_paused:
 | |
|             self._app_writing_paused = False
 | |
|             try:
 | |
|                 self._app_protocol.resume_writing()
 | |
|             except (KeyboardInterrupt, SystemExit):
 | |
|                 raise
 | |
|             except BaseException as exc:
 | |
|                 self._loop.call_exception_handler({
 | |
|                     'message': 'protocol.resume_writing() failed',
 | |
|                     'exception': exc,
 | |
|                     'transport': self._app_transport,
 | |
|                     'protocol': self,
 | |
|                 })
 | |
| 
 | |
|     def _get_write_buffer_size(self):
 | |
|         return self._outgoing.pending + self._write_buffer_size
 | |
| 
 | |
|     def _set_write_buffer_limits(self, high=None, low=None):
 | |
|         high, low = add_flowcontrol_defaults(
 | |
|             high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE)
 | |
|         self._outgoing_high_water = high
 | |
|         self._outgoing_low_water = low
 | |
| 
 | |
|     # Flow control for reads to APP socket
 | |
| 
 | |
|     def _pause_reading(self):
 | |
|         self._app_reading_paused = True
 | |
| 
 | |
|     def _resume_reading(self):
 | |
|         if self._app_reading_paused:
 | |
|             self._app_reading_paused = False
 | |
| 
 | |
|             def resume():
 | |
|                 if self._state == SSLProtocolState.WRAPPED:
 | |
|                     self._do_read()
 | |
|                 elif self._state == SSLProtocolState.FLUSHING:
 | |
|                     self._do_flush()
 | |
|                 elif self._state == SSLProtocolState.SHUTDOWN:
 | |
|                     self._do_shutdown()
 | |
|             self._loop.call_soon(resume)
 | |
| 
 | |
|     # Flow control for reads from SSL socket
 | |
| 
 | |
|     def _control_ssl_reading(self):
 | |
|         size = self._get_read_buffer_size()
 | |
|         if size >= self._incoming_high_water and not self._ssl_reading_paused:
 | |
|             self._ssl_reading_paused = True
 | |
|             self._transport.pause_reading()
 | |
|         elif size <= self._incoming_low_water and self._ssl_reading_paused:
 | |
|             self._ssl_reading_paused = False
 | |
|             self._transport.resume_reading()
 | |
| 
 | |
|     def _set_read_buffer_limits(self, high=None, low=None):
 | |
|         high, low = add_flowcontrol_defaults(
 | |
|             high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ)
 | |
|         self._incoming_high_water = high
 | |
|         self._incoming_low_water = low
 | |
| 
 | |
|     def _get_read_buffer_size(self):
 | |
|         return self._incoming.pending
 | |
| 
 | |
|     # Flow control for writes to SSL socket
 | |
| 
 | |
|     def pause_writing(self):
 | |
|         """Called when the low-level transport's buffer goes over
 | |
|         the high-water mark.
 | |
|         """
 | |
|         assert not self._ssl_writing_paused
 | |
|         self._ssl_writing_paused = True
 | |
| 
 | |
|     def resume_writing(self):
 | |
|         """Called when the low-level transport's buffer drains below
 | |
|         the low-water mark.
 | |
|         """
 | |
|         assert self._ssl_writing_paused
 | |
|         self._ssl_writing_paused = False
 | |
|         self._process_outgoing()
 | |
| 
 | |
|     def _fatal_error(self, exc, message='Fatal error on transport'):
 | |
|         if self._transport:
 | |
|             self._transport._force_close(exc)
 | |
| 
 | |
|         if isinstance(exc, OSError):
 | |
|             if self._loop.get_debug():
 | |
|                 logger.debug("%r: %s", self, message, exc_info=True)
 | |
|         elif not isinstance(exc, exceptions.CancelledError):
 | |
|             self._loop.call_exception_handler({
 | |
|                 'message': message,
 | |
|                 'exception': exc,
 | |
|                 'transport': self._transport,
 | |
|                 'protocol': self,
 | |
|             })
 |