| 
									
										
										
										
											2018-01-28 16:30:26 -05:00
										 |  |  | import asyncio | 
					
						
							|  |  |  | import unittest | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from test.test_asyncio import functional as func_tests | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-06-01 20:34:09 -07:00
										 |  |  | def tearDownModule(): | 
					
						
							|  |  |  |     asyncio.set_event_loop_policy(None) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-01-28 16:30:26 -05:00
										 |  |  | class ReceiveStuffProto(asyncio.BufferedProtocol): | 
					
						
							|  |  |  |     def __init__(self, cb, con_lost_fut): | 
					
						
							|  |  |  |         self.cb = cb | 
					
						
							|  |  |  |         self.con_lost_fut = con_lost_fut | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-05-28 14:31:28 -04:00
										 |  |  |     def get_buffer(self, sizehint): | 
					
						
							| 
									
										
										
										
											2018-01-28 16:30:26 -05:00
										 |  |  |         self.buffer = bytearray(100) | 
					
						
							|  |  |  |         return self.buffer | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def buffer_updated(self, nbytes): | 
					
						
							|  |  |  |         self.cb(self.buffer[:nbytes]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def connection_lost(self, exc): | 
					
						
							|  |  |  |         if exc is None: | 
					
						
							|  |  |  |             self.con_lost_fut.set_result(None) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             self.con_lost_fut.set_exception(exc) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class BaseTestBufferedProtocol(func_tests.FunctionalTestCaseMixin): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def new_loop(self): | 
					
						
							|  |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_buffered_proto_create_connection(self): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         NOISE = b'12345678+' * 1024 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         async def client(addr): | 
					
						
							|  |  |  |             data = b'' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def on_buf(buf): | 
					
						
							|  |  |  |                 nonlocal data | 
					
						
							|  |  |  |                 data += buf | 
					
						
							|  |  |  |                 if data == NOISE: | 
					
						
							|  |  |  |                     tr.write(b'1') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             conn_lost_fut = self.loop.create_future() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             tr, pr = await self.loop.create_connection( | 
					
						
							|  |  |  |                 lambda: ReceiveStuffProto(on_buf, conn_lost_fut), *addr) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             await conn_lost_fut | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         async def on_server_client(reader, writer): | 
					
						
							|  |  |  |             writer.write(NOISE) | 
					
						
							|  |  |  |             await reader.readexactly(1) | 
					
						
							|  |  |  |             writer.close() | 
					
						
							|  |  |  |             await writer.wait_closed() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-09-29 21:59:55 -07:00
										 |  |  |         srv = self.loop.run_until_complete( | 
					
						
							|  |  |  |             asyncio.start_server( | 
					
						
							|  |  |  |                 on_server_client, '127.0.0.1', 0)) | 
					
						
							| 
									
										
										
										
											2018-01-28 16:30:26 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |         addr = srv.sockets[0].getsockname() | 
					
						
							|  |  |  |         self.loop.run_until_complete( | 
					
						
							| 
									
										
										
										
											2018-10-02 13:53:06 -04:00
										 |  |  |             asyncio.wait_for(client(addr), 5)) | 
					
						
							| 
									
										
										
										
											2018-01-28 16:30:26 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |         srv.close() | 
					
						
							|  |  |  |         self.loop.run_until_complete(srv.wait_closed()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class BufferedProtocolSelectorTests(BaseTestBufferedProtocol, | 
					
						
							|  |  |  |                                     unittest.TestCase): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def new_loop(self): | 
					
						
							|  |  |  |         return asyncio.SelectorEventLoop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only') | 
					
						
							|  |  |  | class BufferedProtocolProactorTests(BaseTestBufferedProtocol, | 
					
						
							|  |  |  |                                     unittest.TestCase): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def new_loop(self): | 
					
						
							|  |  |  |         return asyncio.ProactorEventLoop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == '__main__': | 
					
						
							|  |  |  |     unittest.main() |