mirror of
				https://github.com/python/cpython.git
				synced 2025-11-01 06:01:29 +00:00 
			
		
		
		
	asyncio: Add streams.start_server(), by Gustavo Carneiro.
This commit is contained in:
		
							parent
							
								
									4a9ee26750
								
							
						
					
					
						commit
						1540b16ff4
					
				
					 2 changed files with 117 additions and 2 deletions
				
			
		|  | @ -1,6 +1,8 @@ | |||
| """Stream-related things.""" | ||||
| 
 | ||||
| __all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection'] | ||||
| __all__ = ['StreamReader', 'StreamReaderProtocol', | ||||
|            'open_connection', 'start_server', | ||||
|            ] | ||||
| 
 | ||||
| import collections | ||||
| 
 | ||||
|  | @ -43,6 +45,42 @@ def open_connection(host=None, port=None, *, | |||
|     return reader, writer | ||||
| 
 | ||||
| 
 | ||||
| @tasks.coroutine | ||||
| def start_server(client_connected_cb, host=None, port=None, *, | ||||
|                  loop=None, limit=_DEFAULT_LIMIT, **kwds): | ||||
|     """Start a socket server, call back for each client connected. | ||||
| 
 | ||||
|     The first parameter, `client_connected_cb`, takes two parameters: | ||||
|     client_reader, client_writer.  client_reader is a StreamReader | ||||
|     object, while client_writer is a StreamWriter object.  This | ||||
|     parameter can either be a plain callback function or a coroutine; | ||||
|     if it is a coroutine, it will be automatically converted into a | ||||
|     Task. | ||||
| 
 | ||||
|     The rest of the arguments are all the usual arguments to | ||||
|     loop.create_server() except protocol_factory; most common are | ||||
|     positional host and port, with various optional keyword arguments | ||||
|     following.  The return value is the same as loop.create_server(). | ||||
| 
 | ||||
|     Additional optional keyword arguments are loop (to set the event loop | ||||
|     instance to use) and limit (to set the buffer limit passed to the | ||||
|     StreamReader). | ||||
| 
 | ||||
|     The return value is the same as loop.create_server(), i.e. a | ||||
|     Server object which can be used to stop the service. | ||||
|     """ | ||||
|     if loop is None: | ||||
|         loop = events.get_event_loop() | ||||
| 
 | ||||
|     def factory(): | ||||
|         reader = StreamReader(limit=limit, loop=loop) | ||||
|         protocol = StreamReaderProtocol(reader, client_connected_cb, | ||||
|                                         loop=loop) | ||||
|         return protocol | ||||
| 
 | ||||
|     return (yield from loop.create_server(factory, host, port, **kwds)) | ||||
| 
 | ||||
| 
 | ||||
| class StreamReaderProtocol(protocols.Protocol): | ||||
|     """Trivial helper class to adapt between Protocol and StreamReader. | ||||
| 
 | ||||
|  | @ -52,13 +90,24 @@ class StreamReaderProtocol(protocols.Protocol): | |||
|     call inappropriate methods of the protocol.) | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, stream_reader): | ||||
|     def __init__(self, stream_reader, client_connected_cb=None, loop=None): | ||||
|         self._stream_reader = stream_reader | ||||
|         self._stream_writer = None | ||||
|         self._drain_waiter = None | ||||
|         self._paused = False | ||||
|         self._client_connected_cb = client_connected_cb | ||||
|         self._loop = loop  # May be None; we may never need it. | ||||
| 
 | ||||
|     def connection_made(self, transport): | ||||
|         self._stream_reader.set_transport(transport) | ||||
|         if self._client_connected_cb is not None: | ||||
|             self._stream_writer = StreamWriter(transport, self, | ||||
|                                                self._stream_reader, | ||||
|                                                self._loop) | ||||
|             res = self._client_connected_cb(self._stream_reader, | ||||
|                                             self._stream_writer) | ||||
|             if tasks.iscoroutine(res): | ||||
|                 tasks.Task(res, loop=self._loop) | ||||
| 
 | ||||
|     def connection_lost(self, exc): | ||||
|         if exc is None: | ||||
|  |  | |||
|  | @ -359,6 +359,72 @@ def read_a_line(): | |||
|         test_utils.run_briefly(self.loop) | ||||
|         self.assertIs(stream._waiter, None) | ||||
| 
 | ||||
|     def test_start_server(self): | ||||
| 
 | ||||
|         class MyServer: | ||||
| 
 | ||||
|             def __init__(self, loop): | ||||
|                 self.server = None | ||||
|                 self.loop = loop | ||||
| 
 | ||||
|             @tasks.coroutine | ||||
|             def handle_client(self, client_reader, client_writer): | ||||
|                 data = yield from client_reader.readline() | ||||
|                 client_writer.write(data) | ||||
| 
 | ||||
|             def start(self): | ||||
|                 self.server = self.loop.run_until_complete( | ||||
|                     streams.start_server(self.handle_client, | ||||
|                                          '127.0.0.1', 12345, | ||||
|                                          loop=self.loop)) | ||||
| 
 | ||||
|             def handle_client_callback(self, client_reader, client_writer): | ||||
|                 task = tasks.Task(client_reader.readline(), loop=self.loop) | ||||
| 
 | ||||
|                 def done(task): | ||||
|                     client_writer.write(task.result()) | ||||
| 
 | ||||
|                 task.add_done_callback(done) | ||||
| 
 | ||||
|             def start_callback(self): | ||||
|                 self.server = self.loop.run_until_complete( | ||||
|                     streams.start_server(self.handle_client_callback, | ||||
|                                          '127.0.0.1', 12345, | ||||
|                                          loop=self.loop)) | ||||
| 
 | ||||
|             def stop(self): | ||||
|                 if self.server is not None: | ||||
|                     self.server.close() | ||||
|                     self.loop.run_until_complete(self.server.wait_closed()) | ||||
|                     self.server = None | ||||
| 
 | ||||
|         @tasks.coroutine | ||||
|         def client(): | ||||
|             reader, writer = yield from streams.open_connection( | ||||
|                 '127.0.0.1', 12345, loop=self.loop) | ||||
|             # send a line | ||||
|             writer.write(b"hello world!\n") | ||||
|             # read it back | ||||
|             msgback = yield from reader.readline() | ||||
|             writer.close() | ||||
|             return msgback | ||||
| 
 | ||||
|         # test the server variant with a coroutine as client handler | ||||
|         server = MyServer(self.loop) | ||||
|         server.start() | ||||
|         msg = self.loop.run_until_complete(tasks.Task(client(), | ||||
|                                                       loop=self.loop)) | ||||
|         server.stop() | ||||
|         self.assertEqual(msg, b"hello world!\n") | ||||
| 
 | ||||
|         # test the server variant with a callback as client handler | ||||
|         server = MyServer(self.loop) | ||||
|         server.start_callback() | ||||
|         msg = self.loop.run_until_complete(tasks.Task(client(), | ||||
|                                                       loop=self.loop)) | ||||
|         server.stop() | ||||
|         self.assertEqual(msg, b"hello world!\n") | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Guido van Rossum
						Guido van Rossum