| 
									
										
										
										
											2018-10-09 07:52:57 +03:00
										 |  |  | """Tests for sendfile functionality.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import asyncio | 
					
						
							|  |  |  | import os | 
					
						
							|  |  |  | import socket | 
					
						
							|  |  |  | import sys | 
					
						
							|  |  |  | import tempfile | 
					
						
							|  |  |  | import unittest | 
					
						
							|  |  |  | from asyncio import base_events | 
					
						
							|  |  |  | from asyncio import constants | 
					
						
							|  |  |  | from unittest import mock | 
					
						
							|  |  |  | from test import support | 
					
						
							| 
									
										
										
										
											2020-04-25 10:06:29 +03:00
										 |  |  | from test.support import socket_helper | 
					
						
							| 
									
										
										
										
											2018-10-09 07:52:57 +03:00
										 |  |  | from test.test_asyncio import utils as test_utils | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | try: | 
					
						
							|  |  |  |     import ssl | 
					
						
							|  |  |  | except ImportError: | 
					
						
							|  |  |  |     ssl = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-01-07 23:55:57 +01:00
										 |  |  | def tearDownModule(): | 
					
						
							|  |  |  |     asyncio.set_event_loop_policy(None) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-10-09 07:52:57 +03:00
										 |  |  | class MySendfileProto(asyncio.Protocol): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, loop=None, close_after=0): | 
					
						
							|  |  |  |         self.transport = None | 
					
						
							|  |  |  |         self.state = 'INITIAL' | 
					
						
							|  |  |  |         self.nbytes = 0 | 
					
						
							|  |  |  |         if loop is not None: | 
					
						
							|  |  |  |             self.connected = loop.create_future() | 
					
						
							|  |  |  |             self.done = loop.create_future() | 
					
						
							|  |  |  |         self.data = bytearray() | 
					
						
							|  |  |  |         self.close_after = close_after | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def connection_made(self, transport): | 
					
						
							|  |  |  |         self.transport = transport | 
					
						
							|  |  |  |         assert self.state == 'INITIAL', self.state | 
					
						
							|  |  |  |         self.state = 'CONNECTED' | 
					
						
							|  |  |  |         if self.connected: | 
					
						
							|  |  |  |             self.connected.set_result(None) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def eof_received(self): | 
					
						
							|  |  |  |         assert self.state == 'CONNECTED', self.state | 
					
						
							|  |  |  |         self.state = 'EOF' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def connection_lost(self, exc): | 
					
						
							|  |  |  |         assert self.state in ('CONNECTED', 'EOF'), self.state | 
					
						
							|  |  |  |         self.state = 'CLOSED' | 
					
						
							|  |  |  |         if self.done: | 
					
						
							|  |  |  |             self.done.set_result(None) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def data_received(self, data): | 
					
						
							|  |  |  |         assert self.state == 'CONNECTED', self.state | 
					
						
							|  |  |  |         self.nbytes += len(data) | 
					
						
							|  |  |  |         self.data.extend(data) | 
					
						
							|  |  |  |         super().data_received(data) | 
					
						
							|  |  |  |         if self.close_after and self.nbytes >= self.close_after: | 
					
						
							|  |  |  |             self.transport.close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class MyProto(asyncio.Protocol): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, loop): | 
					
						
							|  |  |  |         self.started = False | 
					
						
							|  |  |  |         self.closed = False | 
					
						
							|  |  |  |         self.data = bytearray() | 
					
						
							|  |  |  |         self.fut = loop.create_future() | 
					
						
							|  |  |  |         self.transport = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def connection_made(self, transport): | 
					
						
							|  |  |  |         self.started = True | 
					
						
							|  |  |  |         self.transport = transport | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def data_received(self, data): | 
					
						
							|  |  |  |         self.data.extend(data) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def connection_lost(self, exc): | 
					
						
							|  |  |  |         self.closed = True | 
					
						
							|  |  |  |         self.fut.set_result(None) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def wait_closed(self): | 
					
						
							|  |  |  |         await self.fut | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SendfileBase: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-06-15 14:05:08 +03:00
										 |  |  |       # 128 KiB plus small unaligned to buffer chunk | 
					
						
							|  |  |  |     DATA = b"SendfileBaseData" * (1024 * 8 + 1) | 
					
						
							| 
									
										
										
										
											2018-10-09 07:52:57 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Reduce socket buffer size to test on relative small data sets. | 
					
						
							|  |  |  |     BUF_SIZE = 4 * 1024   # 4 KiB | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def create_event_loop(self): | 
					
						
							|  |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def setUpClass(cls): | 
					
						
							|  |  |  |         with open(support.TESTFN, 'wb') as fp: | 
					
						
							|  |  |  |             fp.write(cls.DATA) | 
					
						
							|  |  |  |         super().setUpClass() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def tearDownClass(cls): | 
					
						
							|  |  |  |         support.unlink(support.TESTFN) | 
					
						
							|  |  |  |         super().tearDownClass() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def setUp(self): | 
					
						
							|  |  |  |         self.file = open(support.TESTFN, 'rb') | 
					
						
							|  |  |  |         self.addCleanup(self.file.close) | 
					
						
							|  |  |  |         self.loop = self.create_event_loop() | 
					
						
							|  |  |  |         self.set_event_loop(self.loop) | 
					
						
							|  |  |  |         super().setUp() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def tearDown(self): | 
					
						
							|  |  |  |         # just in case if we have transport close callbacks | 
					
						
							|  |  |  |         if not self.loop.is_closed(): | 
					
						
							|  |  |  |             test_utils.run_briefly(self.loop) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.doCleanups() | 
					
						
							|  |  |  |         support.gc_collect() | 
					
						
							|  |  |  |         super().tearDown() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def run_loop(self, coro): | 
					
						
							|  |  |  |         return self.loop.run_until_complete(coro) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SockSendfileMixin(SendfileBase): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def setUpClass(cls): | 
					
						
							|  |  |  |         cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE | 
					
						
							|  |  |  |         constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16 | 
					
						
							|  |  |  |         super().setUpClass() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def tearDownClass(cls): | 
					
						
							|  |  |  |         constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize | 
					
						
							|  |  |  |         super().tearDownClass() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def make_socket(self, cleanup=True): | 
					
						
							|  |  |  |         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | 
					
						
							|  |  |  |         sock.setblocking(False) | 
					
						
							|  |  |  |         if cleanup: | 
					
						
							|  |  |  |             self.addCleanup(sock.close) | 
					
						
							|  |  |  |         return sock | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def reduce_receive_buffer_size(self, sock): | 
					
						
							|  |  |  |         # Reduce receive socket buffer size to test on relative | 
					
						
							|  |  |  |         # small data sets. | 
					
						
							|  |  |  |         sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def reduce_send_buffer_size(self, sock, transport=None): | 
					
						
							|  |  |  |         # Reduce send socket buffer size to test on relative small data sets. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # On macOS, SO_SNDBUF is reset by connect(). So this method | 
					
						
							|  |  |  |         # should be called after the socket is connected. | 
					
						
							|  |  |  |         sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if transport is not None: | 
					
						
							|  |  |  |             transport.set_write_buffer_limits(high=self.BUF_SIZE) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def prepare_socksendfile(self): | 
					
						
							|  |  |  |         proto = MyProto(self.loop) | 
					
						
							| 
									
										
										
										
											2020-04-25 10:06:29 +03:00
										 |  |  |         port = socket_helper.find_unused_port() | 
					
						
							| 
									
										
										
										
											2018-10-09 07:52:57 +03:00
										 |  |  |         srv_sock = self.make_socket(cleanup=False) | 
					
						
							| 
									
										
										
										
											2020-04-25 10:06:29 +03:00
										 |  |  |         srv_sock.bind((socket_helper.HOST, port)) | 
					
						
							| 
									
										
										
										
											2018-10-09 07:52:57 +03:00
										 |  |  |         server = self.run_loop(self.loop.create_server( | 
					
						
							|  |  |  |             lambda: proto, sock=srv_sock)) | 
					
						
							|  |  |  |         self.reduce_receive_buffer_size(srv_sock) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         sock = self.make_socket() | 
					
						
							|  |  |  |         self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port))) | 
					
						
							|  |  |  |         self.reduce_send_buffer_size(sock) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def cleanup(): | 
					
						
							|  |  |  |             if proto.transport is not None: | 
					
						
							|  |  |  |                 # can be None if the task was cancelled before | 
					
						
							|  |  |  |                 # connection_made callback | 
					
						
							|  |  |  |                 proto.transport.close() | 
					
						
							|  |  |  |                 self.run_loop(proto.wait_closed()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             server.close() | 
					
						
							|  |  |  |             self.run_loop(server.wait_closed()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.addCleanup(cleanup) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return sock, proto | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sock_sendfile_success(self): | 
					
						
							|  |  |  |         sock, proto = self.prepare_socksendfile() | 
					
						
							|  |  |  |         ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) | 
					
						
							|  |  |  |         sock.close() | 
					
						
							|  |  |  |         self.run_loop(proto.wait_closed()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.assertEqual(ret, len(self.DATA)) | 
					
						
							|  |  |  |         self.assertEqual(proto.data, self.DATA) | 
					
						
							|  |  |  |         self.assertEqual(self.file.tell(), len(self.DATA)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sock_sendfile_with_offset_and_count(self): | 
					
						
							|  |  |  |         sock, proto = self.prepare_socksendfile() | 
					
						
							|  |  |  |         ret = self.run_loop(self.loop.sock_sendfile(sock, self.file, | 
					
						
							|  |  |  |                                                     1000, 2000)) | 
					
						
							|  |  |  |         sock.close() | 
					
						
							|  |  |  |         self.run_loop(proto.wait_closed()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.assertEqual(proto.data, self.DATA[1000:3000]) | 
					
						
							|  |  |  |         self.assertEqual(self.file.tell(), 3000) | 
					
						
							|  |  |  |         self.assertEqual(ret, 2000) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sock_sendfile_zero_size(self): | 
					
						
							|  |  |  |         sock, proto = self.prepare_socksendfile() | 
					
						
							|  |  |  |         with tempfile.TemporaryFile() as f: | 
					
						
							|  |  |  |             ret = self.run_loop(self.loop.sock_sendfile(sock, f, | 
					
						
							|  |  |  |                                                         0, None)) | 
					
						
							|  |  |  |         sock.close() | 
					
						
							|  |  |  |         self.run_loop(proto.wait_closed()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.assertEqual(ret, 0) | 
					
						
							|  |  |  |         self.assertEqual(self.file.tell(), 0) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sock_sendfile_mix_with_regular_send(self): | 
					
						
							|  |  |  |         buf = b"mix_regular_send" * (4 * 1024)  # 64 KiB | 
					
						
							|  |  |  |         sock, proto = self.prepare_socksendfile() | 
					
						
							|  |  |  |         self.run_loop(self.loop.sock_sendall(sock, buf)) | 
					
						
							|  |  |  |         ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) | 
					
						
							|  |  |  |         self.run_loop(self.loop.sock_sendall(sock, buf)) | 
					
						
							|  |  |  |         sock.close() | 
					
						
							|  |  |  |         self.run_loop(proto.wait_closed()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.assertEqual(ret, len(self.DATA)) | 
					
						
							|  |  |  |         expected = buf + self.DATA + buf | 
					
						
							|  |  |  |         self.assertEqual(proto.data, expected) | 
					
						
							|  |  |  |         self.assertEqual(self.file.tell(), len(self.DATA)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SendfileMixin(SendfileBase): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Note: sendfile via SSL transport is equal to sendfile fallback | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def prepare_sendfile(self, *, is_ssl=False, close_after=0): | 
					
						
							| 
									
										
										
										
											2020-04-25 10:06:29 +03:00
										 |  |  |         port = socket_helper.find_unused_port() | 
					
						
							| 
									
										
										
										
											2018-10-09 07:52:57 +03:00
										 |  |  |         srv_proto = MySendfileProto(loop=self.loop, | 
					
						
							|  |  |  |                                     close_after=close_after) | 
					
						
							|  |  |  |         if is_ssl: | 
					
						
							|  |  |  |             if not ssl: | 
					
						
							|  |  |  |                 self.skipTest("No ssl module") | 
					
						
							|  |  |  |             srv_ctx = test_utils.simple_server_sslcontext() | 
					
						
							|  |  |  |             cli_ctx = test_utils.simple_client_sslcontext() | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             srv_ctx = None | 
					
						
							|  |  |  |             cli_ctx = None | 
					
						
							|  |  |  |         srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | 
					
						
							| 
									
										
										
										
											2020-04-25 10:06:29 +03:00
										 |  |  |         srv_sock.bind((socket_helper.HOST, port)) | 
					
						
							| 
									
										
										
										
											2018-10-09 07:52:57 +03:00
										 |  |  |         server = self.run_loop(self.loop.create_server( | 
					
						
							|  |  |  |             lambda: srv_proto, sock=srv_sock, ssl=srv_ctx)) | 
					
						
							|  |  |  |         self.reduce_receive_buffer_size(srv_sock) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if is_ssl: | 
					
						
							| 
									
										
										
										
											2020-04-25 10:06:29 +03:00
										 |  |  |             server_hostname = socket_helper.HOST | 
					
						
							| 
									
										
										
										
											2018-10-09 07:52:57 +03:00
										 |  |  |         else: | 
					
						
							|  |  |  |             server_hostname = None | 
					
						
							|  |  |  |         cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | 
					
						
							| 
									
										
										
										
											2020-04-25 10:06:29 +03:00
										 |  |  |         cli_sock.connect((socket_helper.HOST, port)) | 
					
						
							| 
									
										
										
										
											2018-10-09 07:52:57 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         cli_proto = MySendfileProto(loop=self.loop) | 
					
						
							|  |  |  |         tr, pr = self.run_loop(self.loop.create_connection( | 
					
						
							|  |  |  |             lambda: cli_proto, sock=cli_sock, | 
					
						
							|  |  |  |             ssl=cli_ctx, server_hostname=server_hostname)) | 
					
						
							|  |  |  |         self.reduce_send_buffer_size(cli_sock, transport=tr) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def cleanup(): | 
					
						
							|  |  |  |             srv_proto.transport.close() | 
					
						
							|  |  |  |             cli_proto.transport.close() | 
					
						
							|  |  |  |             self.run_loop(srv_proto.done) | 
					
						
							|  |  |  |             self.run_loop(cli_proto.done) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             server.close() | 
					
						
							|  |  |  |             self.run_loop(server.wait_closed()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.addCleanup(cleanup) | 
					
						
							|  |  |  |         return srv_proto, cli_proto | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported") | 
					
						
							|  |  |  |     def test_sendfile_not_supported(self): | 
					
						
							|  |  |  |         tr, pr = self.run_loop( | 
					
						
							|  |  |  |             self.loop.create_datagram_endpoint( | 
					
						
							|  |  |  |                 asyncio.DatagramProtocol, | 
					
						
							|  |  |  |                 family=socket.AF_INET)) | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             with self.assertRaisesRegex(RuntimeError, "not supported"): | 
					
						
							|  |  |  |                 self.run_loop( | 
					
						
							|  |  |  |                     self.loop.sendfile(tr, self.file)) | 
					
						
							|  |  |  |             self.assertEqual(0, self.file.tell()) | 
					
						
							|  |  |  |         finally: | 
					
						
							|  |  |  |             # don't use self.addCleanup because it produces resource warning | 
					
						
							|  |  |  |             tr.close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sendfile(self): | 
					
						
							|  |  |  |         srv_proto, cli_proto = self.prepare_sendfile() | 
					
						
							|  |  |  |         ret = self.run_loop( | 
					
						
							|  |  |  |             self.loop.sendfile(cli_proto.transport, self.file)) | 
					
						
							|  |  |  |         cli_proto.transport.close() | 
					
						
							|  |  |  |         self.run_loop(srv_proto.done) | 
					
						
							|  |  |  |         self.assertEqual(ret, len(self.DATA)) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.nbytes, len(self.DATA)) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.data, self.DATA) | 
					
						
							|  |  |  |         self.assertEqual(self.file.tell(), len(self.DATA)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sendfile_force_fallback(self): | 
					
						
							|  |  |  |         srv_proto, cli_proto = self.prepare_sendfile() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def sendfile_native(transp, file, offset, count): | 
					
						
							|  |  |  |             # to raise SendfileNotAvailableError | 
					
						
							|  |  |  |             return base_events.BaseEventLoop._sendfile_native( | 
					
						
							|  |  |  |                 self.loop, transp, file, offset, count) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.loop._sendfile_native = sendfile_native | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ret = self.run_loop( | 
					
						
							|  |  |  |             self.loop.sendfile(cli_proto.transport, self.file)) | 
					
						
							|  |  |  |         cli_proto.transport.close() | 
					
						
							|  |  |  |         self.run_loop(srv_proto.done) | 
					
						
							|  |  |  |         self.assertEqual(ret, len(self.DATA)) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.nbytes, len(self.DATA)) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.data, self.DATA) | 
					
						
							|  |  |  |         self.assertEqual(self.file.tell(), len(self.DATA)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sendfile_force_unsupported_native(self): | 
					
						
							|  |  |  |         if sys.platform == 'win32': | 
					
						
							|  |  |  |             if isinstance(self.loop, asyncio.ProactorEventLoop): | 
					
						
							|  |  |  |                 self.skipTest("Fails on proactor event loop") | 
					
						
							|  |  |  |         srv_proto, cli_proto = self.prepare_sendfile() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def sendfile_native(transp, file, offset, count): | 
					
						
							|  |  |  |             # to raise SendfileNotAvailableError | 
					
						
							|  |  |  |             return base_events.BaseEventLoop._sendfile_native( | 
					
						
							|  |  |  |                 self.loop, transp, file, offset, count) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.loop._sendfile_native = sendfile_native | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         with self.assertRaisesRegex(asyncio.SendfileNotAvailableError, | 
					
						
							|  |  |  |                                     "not supported"): | 
					
						
							|  |  |  |             self.run_loop( | 
					
						
							|  |  |  |                 self.loop.sendfile(cli_proto.transport, self.file, | 
					
						
							|  |  |  |                                    fallback=False)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         cli_proto.transport.close() | 
					
						
							|  |  |  |         self.run_loop(srv_proto.done) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.nbytes, 0) | 
					
						
							|  |  |  |         self.assertEqual(self.file.tell(), 0) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sendfile_ssl(self): | 
					
						
							|  |  |  |         srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) | 
					
						
							|  |  |  |         ret = self.run_loop( | 
					
						
							|  |  |  |             self.loop.sendfile(cli_proto.transport, self.file)) | 
					
						
							|  |  |  |         cli_proto.transport.close() | 
					
						
							|  |  |  |         self.run_loop(srv_proto.done) | 
					
						
							|  |  |  |         self.assertEqual(ret, len(self.DATA)) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.nbytes, len(self.DATA)) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.data, self.DATA) | 
					
						
							|  |  |  |         self.assertEqual(self.file.tell(), len(self.DATA)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sendfile_for_closing_transp(self): | 
					
						
							|  |  |  |         srv_proto, cli_proto = self.prepare_sendfile() | 
					
						
							|  |  |  |         cli_proto.transport.close() | 
					
						
							|  |  |  |         with self.assertRaisesRegex(RuntimeError, "is closing"): | 
					
						
							|  |  |  |             self.run_loop(self.loop.sendfile(cli_proto.transport, self.file)) | 
					
						
							|  |  |  |         self.run_loop(srv_proto.done) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.nbytes, 0) | 
					
						
							|  |  |  |         self.assertEqual(self.file.tell(), 0) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sendfile_pre_and_post_data(self): | 
					
						
							|  |  |  |         srv_proto, cli_proto = self.prepare_sendfile() | 
					
						
							|  |  |  |         PREFIX = b'PREFIX__' * 1024  # 8 KiB | 
					
						
							|  |  |  |         SUFFIX = b'--SUFFIX' * 1024  # 8 KiB | 
					
						
							|  |  |  |         cli_proto.transport.write(PREFIX) | 
					
						
							|  |  |  |         ret = self.run_loop( | 
					
						
							|  |  |  |             self.loop.sendfile(cli_proto.transport, self.file)) | 
					
						
							|  |  |  |         cli_proto.transport.write(SUFFIX) | 
					
						
							|  |  |  |         cli_proto.transport.close() | 
					
						
							|  |  |  |         self.run_loop(srv_proto.done) | 
					
						
							|  |  |  |         self.assertEqual(ret, len(self.DATA)) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) | 
					
						
							|  |  |  |         self.assertEqual(self.file.tell(), len(self.DATA)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sendfile_ssl_pre_and_post_data(self): | 
					
						
							|  |  |  |         srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) | 
					
						
							|  |  |  |         PREFIX = b'zxcvbnm' * 1024 | 
					
						
							|  |  |  |         SUFFIX = b'0987654321' * 1024 | 
					
						
							|  |  |  |         cli_proto.transport.write(PREFIX) | 
					
						
							|  |  |  |         ret = self.run_loop( | 
					
						
							|  |  |  |             self.loop.sendfile(cli_proto.transport, self.file)) | 
					
						
							|  |  |  |         cli_proto.transport.write(SUFFIX) | 
					
						
							|  |  |  |         cli_proto.transport.close() | 
					
						
							|  |  |  |         self.run_loop(srv_proto.done) | 
					
						
							|  |  |  |         self.assertEqual(ret, len(self.DATA)) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) | 
					
						
							|  |  |  |         self.assertEqual(self.file.tell(), len(self.DATA)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sendfile_partial(self): | 
					
						
							|  |  |  |         srv_proto, cli_proto = self.prepare_sendfile() | 
					
						
							|  |  |  |         ret = self.run_loop( | 
					
						
							|  |  |  |             self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) | 
					
						
							|  |  |  |         cli_proto.transport.close() | 
					
						
							|  |  |  |         self.run_loop(srv_proto.done) | 
					
						
							|  |  |  |         self.assertEqual(ret, 100) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.nbytes, 100) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.data, self.DATA[1000:1100]) | 
					
						
							|  |  |  |         self.assertEqual(self.file.tell(), 1100) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sendfile_ssl_partial(self): | 
					
						
							|  |  |  |         srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) | 
					
						
							|  |  |  |         ret = self.run_loop( | 
					
						
							|  |  |  |             self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) | 
					
						
							|  |  |  |         cli_proto.transport.close() | 
					
						
							|  |  |  |         self.run_loop(srv_proto.done) | 
					
						
							|  |  |  |         self.assertEqual(ret, 100) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.nbytes, 100) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.data, self.DATA[1000:1100]) | 
					
						
							|  |  |  |         self.assertEqual(self.file.tell(), 1100) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sendfile_close_peer_after_receiving(self): | 
					
						
							|  |  |  |         srv_proto, cli_proto = self.prepare_sendfile( | 
					
						
							|  |  |  |             close_after=len(self.DATA)) | 
					
						
							|  |  |  |         ret = self.run_loop( | 
					
						
							|  |  |  |             self.loop.sendfile(cli_proto.transport, self.file)) | 
					
						
							|  |  |  |         cli_proto.transport.close() | 
					
						
							|  |  |  |         self.run_loop(srv_proto.done) | 
					
						
							|  |  |  |         self.assertEqual(ret, len(self.DATA)) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.nbytes, len(self.DATA)) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.data, self.DATA) | 
					
						
							|  |  |  |         self.assertEqual(self.file.tell(), len(self.DATA)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sendfile_ssl_close_peer_after_receiving(self): | 
					
						
							|  |  |  |         srv_proto, cli_proto = self.prepare_sendfile( | 
					
						
							|  |  |  |             is_ssl=True, close_after=len(self.DATA)) | 
					
						
							|  |  |  |         ret = self.run_loop( | 
					
						
							|  |  |  |             self.loop.sendfile(cli_proto.transport, self.file)) | 
					
						
							|  |  |  |         self.run_loop(srv_proto.done) | 
					
						
							|  |  |  |         self.assertEqual(ret, len(self.DATA)) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.nbytes, len(self.DATA)) | 
					
						
							|  |  |  |         self.assertEqual(srv_proto.data, self.DATA) | 
					
						
							|  |  |  |         self.assertEqual(self.file.tell(), len(self.DATA)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-16 13:52:26 +02:00
										 |  |  |     # On Solaris, lowering SO_RCVBUF on a TCP connection after it has been | 
					
						
							|  |  |  |     # established has no effect. Due to its age, this bug affects both Oracle | 
					
						
							|  |  |  |     # Solaris as well as all other OpenSolaris forks (unless they fixed it | 
					
						
							|  |  |  |     # themselves). | 
					
						
							|  |  |  |     @unittest.skipIf(sys.platform.startswith('sunos'), | 
					
						
							|  |  |  |                      "Doesn't work on Solaris") | 
					
						
							| 
									
										
										
										
											2018-10-09 07:52:57 +03:00
										 |  |  |     def test_sendfile_close_peer_in_the_middle_of_receiving(self): | 
					
						
							|  |  |  |         srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) | 
					
						
							|  |  |  |         with self.assertRaises(ConnectionError): | 
					
						
							|  |  |  |             self.run_loop( | 
					
						
							|  |  |  |                 self.loop.sendfile(cli_proto.transport, self.file)) | 
					
						
							|  |  |  |         self.run_loop(srv_proto.done) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), | 
					
						
							|  |  |  |                         srv_proto.nbytes) | 
					
						
							|  |  |  |         self.assertTrue(1024 <= self.file.tell() < len(self.DATA), | 
					
						
							|  |  |  |                         self.file.tell()) | 
					
						
							|  |  |  |         self.assertTrue(cli_proto.transport.is_closing()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def sendfile_native(transp, file, offset, count): | 
					
						
							|  |  |  |             # to raise SendfileNotAvailableError | 
					
						
							|  |  |  |             return base_events.BaseEventLoop._sendfile_native( | 
					
						
							|  |  |  |                 self.loop, transp, file, offset, count) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.loop._sendfile_native = sendfile_native | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) | 
					
						
							|  |  |  |         with self.assertRaises(ConnectionError): | 
					
						
							|  |  |  |             self.run_loop( | 
					
						
							|  |  |  |                 self.loop.sendfile(cli_proto.transport, self.file)) | 
					
						
							|  |  |  |         self.run_loop(srv_proto.done) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), | 
					
						
							|  |  |  |                         srv_proto.nbytes) | 
					
						
							|  |  |  |         self.assertTrue(1024 <= self.file.tell() < len(self.DATA), | 
					
						
							|  |  |  |                         self.file.tell()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @unittest.skipIf(not hasattr(os, 'sendfile'), | 
					
						
							|  |  |  |                      "Don't have native sendfile support") | 
					
						
							|  |  |  |     def test_sendfile_prevents_bare_write(self): | 
					
						
							|  |  |  |         srv_proto, cli_proto = self.prepare_sendfile() | 
					
						
							|  |  |  |         fut = self.loop.create_future() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         async def coro(): | 
					
						
							|  |  |  |             fut.set_result(None) | 
					
						
							|  |  |  |             return await self.loop.sendfile(cli_proto.transport, self.file) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         t = self.loop.create_task(coro()) | 
					
						
							|  |  |  |         self.run_loop(fut) | 
					
						
							|  |  |  |         with self.assertRaisesRegex(RuntimeError, | 
					
						
							|  |  |  |                                     "sendfile is in progress"): | 
					
						
							|  |  |  |             cli_proto.transport.write(b'data') | 
					
						
							|  |  |  |         ret = self.run_loop(t) | 
					
						
							|  |  |  |         self.assertEqual(ret, len(self.DATA)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_sendfile_no_fallback_for_fallback_transport(self): | 
					
						
							|  |  |  |         transport = mock.Mock() | 
					
						
							|  |  |  |         transport.is_closing.side_effect = lambda: False | 
					
						
							|  |  |  |         transport._sendfile_compatible = constants._SendfileMode.FALLBACK | 
					
						
							|  |  |  |         with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'): | 
					
						
							|  |  |  |             self.loop.run_until_complete( | 
					
						
							|  |  |  |                 self.loop.sendfile(transport, None, fallback=False)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SendfileTestsBase(SendfileMixin, SockSendfileMixin): | 
					
						
							|  |  |  |     pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if sys.platform == 'win32': | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     class SelectEventLoopTests(SendfileTestsBase, | 
					
						
							|  |  |  |                                test_utils.TestCase): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def create_event_loop(self): | 
					
						
							|  |  |  |             return asyncio.SelectorEventLoop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     class ProactorEventLoopTests(SendfileTestsBase, | 
					
						
							|  |  |  |                                  test_utils.TestCase): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def create_event_loop(self): | 
					
						
							|  |  |  |             return asyncio.ProactorEventLoop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | else: | 
					
						
							|  |  |  |     import selectors | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if hasattr(selectors, 'KqueueSelector'): | 
					
						
							|  |  |  |         class KqueueEventLoopTests(SendfileTestsBase, | 
					
						
							|  |  |  |                                    test_utils.TestCase): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def create_event_loop(self): | 
					
						
							|  |  |  |                 return asyncio.SelectorEventLoop( | 
					
						
							|  |  |  |                     selectors.KqueueSelector()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if hasattr(selectors, 'EpollSelector'): | 
					
						
							|  |  |  |         class EPollEventLoopTests(SendfileTestsBase, | 
					
						
							|  |  |  |                                   test_utils.TestCase): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def create_event_loop(self): | 
					
						
							|  |  |  |                 return asyncio.SelectorEventLoop(selectors.EpollSelector()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if hasattr(selectors, 'PollSelector'): | 
					
						
							|  |  |  |         class PollEventLoopTests(SendfileTestsBase, | 
					
						
							|  |  |  |                                  test_utils.TestCase): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def create_event_loop(self): | 
					
						
							|  |  |  |                 return asyncio.SelectorEventLoop(selectors.PollSelector()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Should always exist. | 
					
						
							|  |  |  |     class SelectEventLoopTests(SendfileTestsBase, | 
					
						
							|  |  |  |                                test_utils.TestCase): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def create_event_loop(self): | 
					
						
							|  |  |  |             return asyncio.SelectorEventLoop(selectors.SelectSelector()) | 
					
						
							| 
									
										
										
										
											2022-01-22 03:54:07 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == '__main__': | 
					
						
							|  |  |  |     unittest.main() |