| """Stream-related things."""
 | |
| 
 | |
| __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
 | |
|            'open_connection', 'start_server',
 | |
|            'IncompleteReadError',
 | |
|            ]
 | |
| 
 | |
| import socket
 | |
| 
 | |
| if hasattr(socket, 'AF_UNIX'):
 | |
|     __all__.extend(['open_unix_connection', 'start_unix_server'])
 | |
| 
 | |
| from . import events
 | |
| from . import futures
 | |
| from . import protocols
 | |
| from . import tasks
 | |
| 
 | |
| 
 | |
| _DEFAULT_LIMIT = 2**16
 | |
| 
 | |
| 
 | |
| class IncompleteReadError(EOFError):
 | |
|     """
 | |
|     Incomplete read error. Attributes:
 | |
| 
 | |
|     - partial: read bytes string before the end of stream was reached
 | |
|     - expected: total number of expected bytes
 | |
|     """
 | |
|     def __init__(self, partial, expected):
 | |
|         EOFError.__init__(self, "%s bytes read on a total of %s expected bytes"
 | |
|                                 % (len(partial), expected))
 | |
|         self.partial = partial
 | |
|         self.expected = expected
 | |
| 
 | |
| 
 | |
@tasks.coroutine
def open_connection(host=None, port=None, *,
                    loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """A wrapper for create_connection() returning a (reader, writer) pair.

    The reader returned is a StreamReader instance; the writer is a
    StreamWriter instance.

    The arguments are all the usual arguments to create_connection()
    except protocol_factory; most common are positional host and port,
    with various optional keyword arguments following.

    Additional optional keyword arguments are loop (to set the event loop
    instance to use) and limit (to set the buffer limit passed to the
    StreamReader).

    (If you want to customize the StreamReader and/or
    StreamReaderProtocol classes, just copy the code -- there's
    really nothing special here except some convenience.)
    """
    if loop is None:
        loop = events.get_event_loop()
    reader = StreamReader(limit=limit, loop=loop)
    protocol = StreamReaderProtocol(reader, loop=loop)
    transport, _ = yield from loop.create_connection(
        lambda: protocol, host, port, **kwds)
    writer = StreamWriter(transport, protocol, reader, loop)
    return reader, writer


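# Illustrative sketch, not part of the original module: a minimal client
# coroutine built on open_connection().  The host, port and request bytes
# below are placeholders.
@tasks.coroutine
def _example_client(loop=None):
    reader, writer = yield from open_connection('127.0.0.1', 8888, loop=loop)
    writer.write(b'HEAD / HTTP/1.0\r\n\r\n')
    yield from writer.drain()
    status = yield from reader.readline()
    writer.close()
    return status

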
@tasks.coroutine
def start_server(client_connected_cb, host=None, port=None, *,
                 loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server, call back for each client connected.

    The first parameter, `client_connected_cb`, takes two parameters:
    client_reader, client_writer.  client_reader is a StreamReader
    object, while client_writer is a StreamWriter object.  This
    parameter can either be a plain callback function or a coroutine;
    if it is a coroutine, it will be automatically converted into a
    Task.

    The rest of the arguments are all the usual arguments to
    loop.create_server() except protocol_factory; most common are
    positional host and port, with various optional keyword arguments
    following.

    Additional optional keyword arguments are loop (to set the event loop
    instance to use) and limit (to set the buffer limit passed to the
    StreamReader).

    The return value is the same as loop.create_server(), i.e. a
    Server object which can be used to stop the service.
    """
    if loop is None:
        loop = events.get_event_loop()

    def factory():
        reader = StreamReader(limit=limit, loop=loop)
        protocol = StreamReaderProtocol(reader, client_connected_cb,
                                        loop=loop)
        return protocol

    return (yield from loop.create_server(factory, host, port, **kwds))


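# Illustrative sketch, not part of the original module: an echo server built
# on start_server().  handle_echo() is the client_connected_cb; because it is
# a coroutine, connection_made() wraps it in a Task automatically.  Host and
# port are placeholders.
@tasks.coroutine
def _example_echo_server(loop=None):
    @tasks.coroutine
    def handle_echo(client_reader, client_writer):
        data = yield from client_reader.readline()
        client_writer.write(data)            # echo one line back
        yield from client_writer.drain()
        client_writer.close()

    # The result is the Server object from loop.create_server(); call its
    # close() method to stop accepting connections.
    return (yield from start_server(handle_echo, '127.0.0.1', 8888, loop=loop))

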
if hasattr(socket, 'AF_UNIX'):
    # UNIX Domain Sockets are supported on this platform

    @tasks.coroutine
    def open_unix_connection(path=None, *,
                             loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `open_connection` but works with UNIX Domain Sockets."""
        if loop is None:
            loop = events.get_event_loop()
        reader = StreamReader(limit=limit, loop=loop)
        protocol = StreamReaderProtocol(reader, loop=loop)
        transport, _ = yield from loop.create_unix_connection(
            lambda: protocol, path, **kwds)
        writer = StreamWriter(transport, protocol, reader, loop)
        return reader, writer


    @tasks.coroutine
    def start_unix_server(client_connected_cb, path=None, *,
                          loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `start_server` but works with UNIX Domain Sockets."""
        if loop is None:
            loop = events.get_event_loop()

        def factory():
            reader = StreamReader(limit=limit, loop=loop)
            protocol = StreamReaderProtocol(reader, client_connected_cb,
                                            loop=loop)
            return protocol

        return (yield from loop.create_unix_server(factory, path, **kwds))


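# Illustrative sketch, not part of the original module: the UNIX-socket
# variant of the client example above.  It assumes a platform with AF_UNIX
# and a server already listening on the (made-up) path.
@tasks.coroutine
def _example_unix_client(path='/tmp/example.sock', loop=None):
    reader, writer = yield from open_unix_connection(path, loop=loop)
    writer.write(b'ping\n')
    yield from writer.drain()
    reply = yield from reader.readline()
    writer.close()
    return reply

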
class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_writing() and connection_lost().  If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must check for error conditions and then call
    _make_drain_waiter(), which will return either () or a Future
    depending on the paused state.
    """

    def __init__(self, loop=None):
        self._loop = loop  # May be None; we may never need it.
        self._paused = False
        self._drain_waiter = None

    def pause_writing(self):
        assert not self._paused
        self._paused = True

    def resume_writing(self):
        assert self._paused
        self._paused = False
        waiter = self._drain_waiter
        if waiter is not None:
            self._drain_waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def connection_lost(self, exc):
        # Wake up the writer if currently paused.
        if not self._paused:
            return
        waiter = self._drain_waiter
        if waiter is None:
            return
        self._drain_waiter = None
        if waiter.done():
            return
        if exc is None:
            waiter.set_result(None)
        else:
            waiter.set_exception(exc)

    def _make_drain_waiter(self):
        if not self._paused:
            return ()
        waiter = self._drain_waiter
        assert waiter is None or waiter.cancelled()
        waiter = futures.Future(loop=self._loop)
        self._drain_waiter = waiter
        return waiter


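# Illustrative sketch, not part of the original module: the drain-waiter
# contract described in the docstring above.  While writing is not paused,
# _make_drain_waiter() returns () and there is nothing to wait for; once
# pause_writing() has been called it hands out a Future that a later
# resume_writing() completes.
def _demo_flow_control(loop):
    proto = FlowControlMixin(loop=loop)
    assert proto._make_drain_waiter() == ()  # not paused: nothing to wait for
    proto.pause_writing()                    # transport reports a full buffer
    waiter = proto._make_drain_waiter()      # paused: a Future is handed out
    proto.resume_writing()                   # buffer drained: Future completes
    return waiter.done()                     # True

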
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Helper class to adapt between Protocol and StreamReader.

    (This is a helper class instead of making StreamReader itself a
    Protocol subclass, because the StreamReader has other potential
    uses, and to prevent the user of the StreamReader from accidentally
    calling inappropriate methods of the protocol.)
    """

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        super().__init__(loop=loop)
        self._stream_reader = stream_reader
        self._stream_writer = None
        self._client_connected_cb = client_connected_cb

    def connection_made(self, transport):
        self._stream_reader.set_transport(transport)
        if self._client_connected_cb is not None:
            self._stream_writer = StreamWriter(transport, self,
                                               self._stream_reader,
                                               self._loop)
            res = self._client_connected_cb(self._stream_reader,
                                            self._stream_writer)
            if tasks.iscoroutine(res):
                tasks.Task(res, loop=self._loop)

    def connection_lost(self, exc):
        if exc is None:
            self._stream_reader.feed_eof()
        else:
            self._stream_reader.set_exception(exc)
        super().connection_lost(exc)

    def data_received(self, data):
        self._stream_reader.feed_data(data)

    def eof_received(self):
        self._stream_reader.feed_eof()


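# Illustrative sketch, not part of the original module: a StreamReader can be
# fed directly with feed_data()/feed_eof(), which is exactly what
# StreamReaderProtocol does from data_received()/eof_received().  Reads that
# the buffer can already satisfy complete without any transport involved.
def _demo_feed_and_read(loop):
    reader = StreamReader(loop=loop)
    reader.feed_data(b'hello\nworld')
    reader.feed_eof()
    line = loop.run_until_complete(reader.readline())  # b'hello\n'
    rest = loop.run_until_complete(reader.read())      # b'world'
    return line, rest

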
class StreamWriter:
    """Wraps a Transport.

    This exposes write(), writelines(), [can_]write_eof(),
    get_extra_info() and close().  It adds drain() which returns an
    optional Future on which you can wait for flow control.  It also
    adds a transport property which references the Transport
    directly.
    """

    def __init__(self, transport, protocol, reader, loop):
        self._transport = transport
        self._protocol = protocol
        self._reader = reader
        self._loop = loop

    @property
    def transport(self):
        return self._transport

    def write(self, data):
        self._transport.write(data)

    def writelines(self, data):
        self._transport.writelines(data)

    def write_eof(self):
        return self._transport.write_eof()

    def can_write_eof(self):
        return self._transport.can_write_eof()

    def close(self):
        return self._transport.close()

    def get_extra_info(self, name, default=None):
        return self._transport.get_extra_info(name, default)

    def drain(self):
        """This method has an unusual return value.

        The intended use is to write

          w.write(data)
          yield from w.drain()

        When there's nothing to wait for, drain() returns (), and the
        yield-from continues immediately.  When the transport buffer
        is full (the protocol is paused), drain() creates and returns
        a Future and the yield-from will block until that Future is
        completed, which will happen when the buffer is (partially)
        drained and the protocol is resumed.
        """
        if self._reader is not None and self._reader._exception is not None:
            raise self._reader._exception
        if self._transport._conn_lost:  # Uses private variable.
            raise ConnectionResetError('Connection lost')
        return self._protocol._make_drain_waiter()


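# Illustrative sketch, not part of the original module: the write()/drain()
# pattern from the drain() docstring, used here to send a large payload in
# chunks without overrunning the transport's write buffer.  The chunk size is
# arbitrary.
@tasks.coroutine
def _example_send_all(writer, payload, chunk_size=4096):
    for start in range(0, len(payload), chunk_size):
        writer.write(payload[start:start + chunk_size])
        # drain() returns () or a Future; yield from handles both.
        yield from writer.drain()
    if writer.can_write_eof():
        writer.write_eof()

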
class StreamReader:

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.
        self._limit = limit
        if loop is None:
            loop = events.get_event_loop()
        self._loop = loop
        self._buffer = bytearray()
        self._eof = False  # Whether we're done.
        self._waiter = None  # A future.
        self._exception = None
        self._transport = None
        self._paused = False

    def exception(self):
        return self._exception

    def set_exception(self, exc):
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def set_transport(self, transport):
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        self._eof = True
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(True)

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(False)

        if (self._transport is not None and
            not self._paused and
            len(self._buffer) > 2*self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    def _create_waiter(self, func_name):
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine. Running two read coroutines at the same time
        # would have an unexpected behaviour. It would not be possible to know
        # which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError('%s() called while another coroutine is '
                               'already waiting for incoming data' % func_name)
        return futures.Future(loop=self._loop)

    @tasks.coroutine
    def readline(self):
        if self._exception is not None:
            raise self._exception

        line = bytearray()
        not_enough = True

        while not_enough:
            while self._buffer and not_enough:
                ichar = self._buffer.find(b'\n')
                if ichar < 0:
                    line.extend(self._buffer)
                    self._buffer.clear()
                else:
                    ichar += 1
                    line.extend(self._buffer[:ichar])
                    del self._buffer[:ichar]
                    not_enough = False

                if len(line) > self._limit:
                    self._maybe_resume_transport()
                    raise ValueError('Line is too long')

            if self._eof:
                break

            if not_enough:
                self._waiter = self._create_waiter('readline')
                try:
                    yield from self._waiter
                finally:
                    self._waiter = None

        self._maybe_resume_transport()
        return bytes(line)

    @tasks.coroutine
    def read(self, n=-1):
        if self._exception is not None:
            raise self._exception

        if not n:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.read(self._limit) until EOF.
            blocks = []
            while True:
                block = yield from self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)
        else:
            if not self._buffer and not self._eof:
                self._waiter = self._create_waiter('read')
                try:
                    yield from self._waiter
                finally:
                    self._waiter = None

        if n < 0 or len(self._buffer) <= n:
            data = bytes(self._buffer)
            self._buffer.clear()
        else:
            # n > 0 and len(self._buffer) > n
            data = bytes(self._buffer[:n])
            del self._buffer[:n]

        self._maybe_resume_transport()
        return data

    @tasks.coroutine
    def readexactly(self, n):
        if self._exception is not None:
            raise self._exception

        # There used to be "optimized" code here.  It created its own
        # Future and waited until self._buffer had at least the n
        # bytes, then called read(n).  Unfortunately, this could pause
        # the transport if the argument was larger than the pause
        # limit (which is twice self._limit).  So now we just read()
        # into a local buffer.

        blocks = []
        while n > 0:
            block = yield from self.read(n)
            if not block:
                partial = b''.join(blocks)
                raise IncompleteReadError(partial, len(partial) + n)
            blocks.append(block)
            n -= len(block)

        return b''.join(blocks)
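

# Illustrative sketch, not part of the original module: a length-prefixed
# frame read built on readexactly(), showing how a truncated stream surfaces
# as IncompleteReadError with the partial data attached.  The 4-byte header
# format is made up for the example.
@tasks.coroutine
def _example_read_frame(reader):
    header = yield from reader.readexactly(4)
    length = int.from_bytes(header, 'big')
    try:
        return (yield from reader.readexactly(length))
    except IncompleteReadError as exc:
        # The peer closed early: exc.partial holds the bytes that did arrive,
        # exc.expected the total readexactly() was asked for.
        raise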
