| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | """Utilities shared by tests.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import collections | 
					
						
							|  |  |  | import contextlib | 
					
						
							|  |  |  | import io | 
					
						
							| 
									
										
										
										
											2014-07-14 22:26:34 +02:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2014-02-18 18:02:19 -05:00
										 |  |  | import re | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  | import socket | 
					
						
							|  |  |  | import socketserver | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | import sys | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  | import tempfile | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | import threading | 
					
						
							| 
									
										
										
										
											2013-10-20 01:51:25 +02:00
										 |  |  | import time | 
					
						
							| 
									
										
										
										
											2014-06-18 01:36:32 +02:00
										 |  |  | import unittest | 
					
						
							| 
									
										
										
										
											2016-10-05 17:48:59 -04:00
										 |  |  | import weakref | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-02-26 10:25:02 +01:00
										 |  |  | from unittest import mock | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | from http.server import HTTPServer | 
					
						
							| 
									
										
										
										
											2014-02-20 10:37:27 +01:00
										 |  |  | from wsgiref.simple_server import WSGIRequestHandler, WSGIServer | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | try: | 
					
						
							|  |  |  |     import ssl | 
					
						
							|  |  |  | except ImportError:  # pragma: no cover | 
					
						
							|  |  |  |     ssl = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from . import base_events | 
					
						
							| 
									
										
										
										
											2015-11-20 12:57:34 -05:00
										 |  |  | from . import compat | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | from . import events | 
					
						
							| 
									
										
										
										
											2014-03-06 01:00:36 +01:00
										 |  |  | from . import futures | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | from . import selectors | 
					
						
							| 
									
										
										
										
											2014-03-06 01:00:36 +01:00
										 |  |  | from . import tasks | 
					
						
							| 
									
										
										
										
											2014-06-29 00:46:45 +02:00
										 |  |  | from .coroutines import coroutine | 
					
						
							| 
									
										
										
										
											2014-07-14 22:26:34 +02:00
										 |  |  | from .log import logger | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if sys.platform == 'win32':  # pragma: no cover | 
					
						
							|  |  |  |     from .windows_utils import socketpair | 
					
						
							|  |  |  | else: | 
					
						
							|  |  |  |     from socket import socketpair  # pragma: no cover | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def dummy_ssl_context(): | 
					
						
							|  |  |  |     if ssl is None: | 
					
						
							|  |  |  |         return None | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         return ssl.SSLContext(ssl.PROTOCOL_SSLv23) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def run_briefly(loop): | 
					
						
							| 
									
										
										
										
											2014-06-29 00:46:45 +02:00
										 |  |  |     @coroutine | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |     def once(): | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  |     gen = once() | 
					
						
							| 
									
										
										
										
											2014-07-08 11:29:25 +02:00
										 |  |  |     t = loop.create_task(gen) | 
					
						
							| 
									
										
										
										
											2014-06-30 14:51:04 +02:00
										 |  |  |     # Don't log a warning if the task is not done after run_until_complete(). | 
					
						
							|  |  |  |     # It occurs if the loop is stopped or if a task raises a BaseException. | 
					
						
							|  |  |  |     t._log_destroy_pending = False | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |     try: | 
					
						
							|  |  |  |         loop.run_until_complete(t) | 
					
						
							|  |  |  |     finally: | 
					
						
							|  |  |  |         gen.close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-03-06 01:00:36 +01:00
										 |  |  | def run_until(loop, pred, timeout=30): | 
					
						
							|  |  |  |     deadline = time.time() + timeout | 
					
						
							| 
									
										
										
										
											2013-10-20 01:51:25 +02:00
										 |  |  |     while not pred(): | 
					
						
							|  |  |  |         if timeout is not None: | 
					
						
							|  |  |  |             timeout = deadline - time.time() | 
					
						
							|  |  |  |             if timeout <= 0: | 
					
						
							| 
									
										
										
										
											2014-03-06 01:00:36 +01:00
										 |  |  |                 raise futures.TimeoutError() | 
					
						
							|  |  |  |         loop.run_until_complete(tasks.sleep(0.001, loop=loop)) | 
					
						
							| 
									
										
										
										
											2013-10-20 01:51:25 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | def run_once(loop): | 
					
						
							| 
									
										
										
										
											2015-11-19 13:28:47 -08:00
										 |  |  |     """Legacy API to run once through the event loop.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     This is the recommended pattern for test code.  It will poll the | 
					
						
							|  |  |  |     selector once and run all callbacks scheduled in response to I/O | 
					
						
							|  |  |  |     events. | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2015-11-19 13:28:47 -08:00
										 |  |  |     loop.call_soon(loop.stop) | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |     loop.run_forever() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  | class SilentWSGIRequestHandler(WSGIRequestHandler): | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  |     def get_stderr(self): | 
					
						
							|  |  |  |         return io.StringIO() | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  |     def log_message(self, format, *args): | 
					
						
							|  |  |  |         pass | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | class SilentWSGIServer(WSGIServer): | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-15 16:58:21 +02:00
										 |  |  |     request_timeout = 2 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_request(self): | 
					
						
							|  |  |  |         request, client_addr = super().get_request() | 
					
						
							|  |  |  |         request.settimeout(self.request_timeout) | 
					
						
							|  |  |  |         return request, client_addr | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  |     def handle_error(self, request, client_address): | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SSLWSGIServerMixin: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def finish_request(self, request, client_address): | 
					
						
							|  |  |  |         # The relative location of our test directory (which | 
					
						
							|  |  |  |         # contains the ssl key and certificate files) differs | 
					
						
							|  |  |  |         # between the stdlib and stand-alone asyncio. | 
					
						
							|  |  |  |         # Prefer our own if we can find it. | 
					
						
							|  |  |  |         here = os.path.join(os.path.dirname(__file__), '..', 'tests') | 
					
						
							|  |  |  |         if not os.path.isdir(here): | 
					
						
							|  |  |  |             here = os.path.join(os.path.dirname(os.__file__), | 
					
						
							|  |  |  |                                 'test', 'test_asyncio') | 
					
						
							|  |  |  |         keyfile = os.path.join(here, 'ssl_key.pem') | 
					
						
							|  |  |  |         certfile = os.path.join(here, 'ssl_cert.pem') | 
					
						
							| 
									
										
										
										
											2016-09-10 23:23:33 +02:00
										 |  |  |         context = ssl.SSLContext() | 
					
						
							|  |  |  |         context.load_cert_chain(certfile, keyfile) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ssock = context.wrap_socket(request, server_side=True) | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  |         try: | 
					
						
							|  |  |  |             self.RequestHandlerClass(ssock, client_address, self) | 
					
						
							|  |  |  |             ssock.close() | 
					
						
							|  |  |  |         except OSError: | 
					
						
							|  |  |  |             # maybe socket has been closed by peer | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |             pass | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer): | 
					
						
							|  |  |  |     pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls): | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def app(environ, start_response): | 
					
						
							|  |  |  |         status = '200 OK' | 
					
						
							|  |  |  |         headers = [('Content-type', 'text/plain')] | 
					
						
							|  |  |  |         start_response(status, headers) | 
					
						
							|  |  |  |         return [b'Test message'] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Run the test WSGI server in a separate thread in order not to | 
					
						
							|  |  |  |     # interfere with event handling in the main thread | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  |     server_class = server_ssl_cls if use_ssl else server_cls | 
					
						
							|  |  |  |     httpd = server_class(address, SilentWSGIRequestHandler) | 
					
						
							|  |  |  |     httpd.set_app(app) | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |     httpd.address = httpd.server_address | 
					
						
							| 
									
										
										
										
											2014-10-15 16:58:21 +02:00
										 |  |  |     server_thread = threading.Thread( | 
					
						
							|  |  |  |         target=lambda: httpd.serve_forever(poll_interval=0.05)) | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |     server_thread.start() | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         yield httpd | 
					
						
							|  |  |  |     finally: | 
					
						
							|  |  |  |         httpd.shutdown() | 
					
						
							| 
									
										
										
										
											2013-10-20 23:26:23 +02:00
										 |  |  |         httpd.server_close() | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |         server_thread.join() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  | if hasattr(socket, 'AF_UNIX'): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def server_bind(self): | 
					
						
							|  |  |  |             socketserver.UnixStreamServer.server_bind(self) | 
					
						
							|  |  |  |             self.server_name = '127.0.0.1' | 
					
						
							|  |  |  |             self.server_port = 80 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     class UnixWSGIServer(UnixHTTPServer, WSGIServer): | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-15 16:58:21 +02:00
										 |  |  |         request_timeout = 2 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  |         def server_bind(self): | 
					
						
							|  |  |  |             UnixHTTPServer.server_bind(self) | 
					
						
							|  |  |  |             self.setup_environ() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def get_request(self): | 
					
						
							|  |  |  |             request, client_addr = super().get_request() | 
					
						
							| 
									
										
										
										
											2014-10-15 16:58:21 +02:00
										 |  |  |             request.settimeout(self.request_timeout) | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  |             # Code in the stdlib expects that get_request | 
					
						
							|  |  |  |             # will return a socket and a tuple (host, port). | 
					
						
							|  |  |  |             # However, this isn't true for UNIX sockets, | 
					
						
							|  |  |  |             # as the second return value will be a path; | 
					
						
							|  |  |  |             # hence we return some fake data sufficient | 
					
						
							|  |  |  |             # to get the tests going | 
					
						
							|  |  |  |             return request, ('127.0.0.1', '') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     class SilentUnixWSGIServer(UnixWSGIServer): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def handle_error(self, request, client_address): | 
					
						
							|  |  |  |             pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer): | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def gen_unix_socket_path(): | 
					
						
							|  |  |  |         with tempfile.NamedTemporaryFile() as file: | 
					
						
							|  |  |  |             return file.name | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @contextlib.contextmanager | 
					
						
							|  |  |  |     def unix_socket_path(): | 
					
						
							|  |  |  |         path = gen_unix_socket_path() | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             yield path | 
					
						
							|  |  |  |         finally: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 os.unlink(path) | 
					
						
							|  |  |  |             except OSError: | 
					
						
							|  |  |  |                 pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @contextlib.contextmanager | 
					
						
							|  |  |  |     def run_test_unix_server(*, use_ssl=False): | 
					
						
							|  |  |  |         with unix_socket_path() as path: | 
					
						
							|  |  |  |             yield from _run_test_server(address=path, use_ssl=use_ssl, | 
					
						
							|  |  |  |                                         server_cls=SilentUnixWSGIServer, | 
					
						
							|  |  |  |                                         server_ssl_cls=UnixSSLWSGIServer) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @contextlib.contextmanager | 
					
						
							|  |  |  | def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): | 
					
						
							|  |  |  |     yield from _run_test_server(address=(host, port), use_ssl=use_ssl, | 
					
						
							|  |  |  |                                 server_cls=SilentWSGIServer, | 
					
						
							|  |  |  |                                 server_ssl_cls=SSLWSGIServer) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | def make_test_protocol(base): | 
					
						
							|  |  |  |     dct = {} | 
					
						
							|  |  |  |     for name in dir(base): | 
					
						
							|  |  |  |         if name.startswith('__') and name.endswith('__'): | 
					
						
							|  |  |  |             # skip magic names | 
					
						
							|  |  |  |             continue | 
					
						
							| 
									
										
										
										
											2014-02-11 11:34:30 +01:00
										 |  |  |         dct[name] = MockCallback(return_value=None) | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |     return type('TestProtocol', (base,) + base.__bases__, dct)() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class TestSelector(selectors.BaseSelector): | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-12-01 11:04:17 +01:00
										 |  |  |     def __init__(self): | 
					
						
							|  |  |  |         self.keys = {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def register(self, fileobj, events, data=None): | 
					
						
							|  |  |  |         key = selectors.SelectorKey(fileobj, 0, events, data) | 
					
						
							|  |  |  |         self.keys[fileobj] = key | 
					
						
							|  |  |  |         return key | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def unregister(self, fileobj): | 
					
						
							|  |  |  |         return self.keys.pop(fileobj) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |     def select(self, timeout): | 
					
						
							|  |  |  |         return [] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-12-01 11:04:17 +01:00
										 |  |  |     def get_map(self): | 
					
						
							|  |  |  |         return self.keys | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | class TestLoop(base_events.BaseEventLoop): | 
					
						
							|  |  |  |     """Loop for unittests.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     It manages self time directly. | 
					
						
							|  |  |  |     If something scheduled to be executed later then | 
					
						
							|  |  |  |     on next loop iteration after all ready handlers done | 
					
						
							|  |  |  |     generator passed to __init__ is calling. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Generator should be like this: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def gen(): | 
					
						
							|  |  |  |             ... | 
					
						
							|  |  |  |             when = yield ... | 
					
						
							|  |  |  |             ... = yield time_advance | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-02-18 22:27:48 -05:00
										 |  |  |     Value returned by yield is absolute time of next scheduled handler. | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |     Value passed to yield is time advance to move loop's time forward. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, gen=None): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if gen is None: | 
					
						
							|  |  |  |             def gen(): | 
					
						
							|  |  |  |                 yield | 
					
						
							|  |  |  |             self._check_on_close = False | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             self._check_on_close = True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self._gen = gen() | 
					
						
							|  |  |  |         next(self._gen) | 
					
						
							|  |  |  |         self._time = 0 | 
					
						
							| 
									
										
										
										
											2014-02-11 09:03:47 +01:00
										 |  |  |         self._clock_resolution = 1e-9 | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |         self._timers = [] | 
					
						
							|  |  |  |         self._selector = TestSelector() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.readers = {} | 
					
						
							|  |  |  |         self.writers = {} | 
					
						
							|  |  |  |         self.reset_counters() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-10-05 17:48:59 -04:00
										 |  |  |         self._transports = weakref.WeakValueDictionary() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |     def time(self): | 
					
						
							|  |  |  |         return self._time | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def advance_time(self, advance): | 
					
						
							|  |  |  |         """Move test time forward.""" | 
					
						
							|  |  |  |         if advance: | 
					
						
							|  |  |  |             self._time += advance | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def close(self): | 
					
						
							| 
									
										
										
										
											2015-01-15 00:04:21 +01:00
										 |  |  |         super().close() | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |         if self._check_on_close: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 self._gen.send(0) | 
					
						
							|  |  |  |             except StopIteration: | 
					
						
							|  |  |  |                 pass | 
					
						
							|  |  |  |             else:  # pragma: no cover | 
					
						
							|  |  |  |                 raise AssertionError("Time generator is not finished") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-10-05 17:48:59 -04:00
										 |  |  |     def _add_reader(self, fd, callback, *args): | 
					
						
							| 
									
										
										
										
											2014-02-18 18:02:19 -05:00
										 |  |  |         self.readers[fd] = events.Handle(callback, args, self) | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-10-05 17:48:59 -04:00
										 |  |  |     def _remove_reader(self, fd): | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |         self.remove_reader_count[fd] += 1 | 
					
						
							|  |  |  |         if fd in self.readers: | 
					
						
							|  |  |  |             del self.readers[fd] | 
					
						
							|  |  |  |             return True | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def assert_reader(self, fd, callback, *args): | 
					
						
							|  |  |  |         assert fd in self.readers, 'fd {} is not registered'.format(fd) | 
					
						
							|  |  |  |         handle = self.readers[fd] | 
					
						
							|  |  |  |         assert handle._callback == callback, '{!r} != {!r}'.format( | 
					
						
							|  |  |  |             handle._callback, callback) | 
					
						
							|  |  |  |         assert handle._args == args, '{!r} != {!r}'.format( | 
					
						
							|  |  |  |             handle._args, args) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-10-05 17:48:59 -04:00
										 |  |  |     def _add_writer(self, fd, callback, *args): | 
					
						
							| 
									
										
										
										
											2014-02-18 18:02:19 -05:00
										 |  |  |         self.writers[fd] = events.Handle(callback, args, self) | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-10-05 17:48:59 -04:00
										 |  |  |     def _remove_writer(self, fd): | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |         self.remove_writer_count[fd] += 1 | 
					
						
							|  |  |  |         if fd in self.writers: | 
					
						
							|  |  |  |             del self.writers[fd] | 
					
						
							|  |  |  |             return True | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def assert_writer(self, fd, callback, *args): | 
					
						
							|  |  |  |         assert fd in self.writers, 'fd {} is not registered'.format(fd) | 
					
						
							|  |  |  |         handle = self.writers[fd] | 
					
						
							|  |  |  |         assert handle._callback == callback, '{!r} != {!r}'.format( | 
					
						
							|  |  |  |             handle._callback, callback) | 
					
						
							|  |  |  |         assert handle._args == args, '{!r} != {!r}'.format( | 
					
						
							|  |  |  |             handle._args, args) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-10-05 17:48:59 -04:00
										 |  |  |     def _ensure_fd_no_transport(self, fd): | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             transport = self._transports[fd] | 
					
						
							|  |  |  |         except KeyError: | 
					
						
							|  |  |  |             pass | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             raise RuntimeError( | 
					
						
							|  |  |  |                 'File descriptor {!r} is used by transport {!r}'.format( | 
					
						
							|  |  |  |                     fd, transport)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def add_reader(self, fd, callback, *args): | 
					
						
							|  |  |  |         """Add a reader callback.""" | 
					
						
							|  |  |  |         self._ensure_fd_no_transport(fd) | 
					
						
							|  |  |  |         return self._add_reader(fd, callback, *args) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def remove_reader(self, fd): | 
					
						
							|  |  |  |         """Remove a reader callback.""" | 
					
						
							|  |  |  |         self._ensure_fd_no_transport(fd) | 
					
						
							|  |  |  |         return self._remove_reader(fd) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def add_writer(self, fd, callback, *args): | 
					
						
							|  |  |  |         """Add a writer callback..""" | 
					
						
							|  |  |  |         self._ensure_fd_no_transport(fd) | 
					
						
							|  |  |  |         return self._add_writer(fd, callback, *args) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def remove_writer(self, fd): | 
					
						
							|  |  |  |         """Remove a writer callback.""" | 
					
						
							|  |  |  |         self._ensure_fd_no_transport(fd) | 
					
						
							|  |  |  |         return self._remove_writer(fd) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-10-17 13:40:50 -07:00
										 |  |  |     def reset_counters(self): | 
					
						
							|  |  |  |         self.remove_reader_count = collections.defaultdict(int) | 
					
						
							|  |  |  |         self.remove_writer_count = collections.defaultdict(int) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _run_once(self): | 
					
						
							|  |  |  |         super()._run_once() | 
					
						
							|  |  |  |         for when in self._timers: | 
					
						
							|  |  |  |             advance = self._gen.send(when) | 
					
						
							|  |  |  |             self.advance_time(advance) | 
					
						
							|  |  |  |         self._timers = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def call_at(self, when, callback, *args): | 
					
						
							|  |  |  |         self._timers.append(when) | 
					
						
							|  |  |  |         return super().call_at(when, callback, *args) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _process_events(self, event_list): | 
					
						
							|  |  |  |         return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _write_to_self(self): | 
					
						
							|  |  |  |         pass | 
					
						
							| 
									
										
										
										
											2014-02-11 11:34:30 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-02-18 12:15:06 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-02-11 11:34:30 +01:00
										 |  |  | def MockCallback(**kwargs): | 
					
						
							| 
									
										
										
										
											2014-02-26 10:25:02 +01:00
										 |  |  |     return mock.Mock(spec=['__call__'], **kwargs) | 
					
						
							| 
									
										
										
										
											2014-02-18 18:02:19 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class MockPattern(str): | 
					
						
							|  |  |  |     """A regex based str with a fuzzy __eq__.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Use this helper with 'mock.assert_called_with', or anywhere | 
					
						
							| 
									
										
										
										
											2014-02-18 22:27:48 -05:00
										 |  |  |     where a regex comparison between strings is needed. | 
					
						
							| 
									
										
										
										
											2014-02-18 18:02:19 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |     For instance: | 
					
						
							|  |  |  |        mock_call.assert_called_with(MockPattern('spam.*ham')) | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     def __eq__(self, other): | 
					
						
							|  |  |  |         return bool(re.search(str(self), other, re.S)) | 
					
						
							| 
									
										
										
										
											2014-06-12 18:39:26 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_function_source(func): | 
					
						
							|  |  |  |     source = events._get_function_source(func) | 
					
						
							|  |  |  |     if source is None: | 
					
						
							|  |  |  |         raise ValueError("unable to get the source of %r" % (func,)) | 
					
						
							|  |  |  |     return source | 
					
						
							| 
									
										
										
										
											2014-06-18 01:36:32 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class TestCase(unittest.TestCase): | 
					
						
							|  |  |  |     def set_event_loop(self, loop, *, cleanup=True): | 
					
						
							|  |  |  |         assert loop is not None | 
					
						
							|  |  |  |         # ensure that the event loop is passed explicitly in asyncio | 
					
						
							|  |  |  |         events.set_event_loop(None) | 
					
						
							|  |  |  |         if cleanup: | 
					
						
							|  |  |  |             self.addCleanup(loop.close) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def new_test_loop(self, gen=None): | 
					
						
							|  |  |  |         loop = TestLoop(gen) | 
					
						
							|  |  |  |         self.set_event_loop(loop) | 
					
						
							|  |  |  |         return loop | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2017-03-02 20:07:11 -05:00
										 |  |  |     def unpatch_get_running_loop(self): | 
					
						
							|  |  |  |         events._get_running_loop = self._get_running_loop | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-11-04 14:29:28 -04:00
										 |  |  |     def setUp(self): | 
					
						
							|  |  |  |         self._get_running_loop = events._get_running_loop | 
					
						
							|  |  |  |         events._get_running_loop = lambda: None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-06-18 01:36:32 +02:00
										 |  |  |     def tearDown(self): | 
					
						
							| 
									
										
										
										
											2017-03-02 20:07:11 -05:00
										 |  |  |         self.unpatch_get_running_loop() | 
					
						
							| 
									
										
										
										
											2016-11-04 14:29:28 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-06-18 01:36:32 +02:00
										 |  |  |         events.set_event_loop(None) | 
					
						
							| 
									
										
										
										
											2014-07-14 22:26:34 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-02-02 18:36:31 +01:00
										 |  |  |         # Detect CPython bug #23353: ensure that yield/yield-from is not used | 
					
						
							|  |  |  |         # in an except block of a generator | 
					
						
							|  |  |  |         self.assertEqual(sys.exc_info(), (None, None, None)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-11-20 12:57:34 -05:00
										 |  |  |     if not compat.PY34: | 
					
						
							|  |  |  |         # Python 3.3 compatibility | 
					
						
							|  |  |  |         def subTest(self, *args, **kwargs): | 
					
						
							|  |  |  |             class EmptyCM: | 
					
						
							|  |  |  |                 def __enter__(self): | 
					
						
							|  |  |  |                     pass | 
					
						
							|  |  |  |                 def __exit__(self, *exc): | 
					
						
							|  |  |  |                     pass | 
					
						
							|  |  |  |             return EmptyCM() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-07-14 22:26:34 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | @contextlib.contextmanager | 
					
						
							|  |  |  | def disable_logger(): | 
					
						
							|  |  |  |     """Context manager to disable asyncio logger.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     For example, it can be used to ignore warnings in debug mode. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     old_level = logger.level | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         logger.setLevel(logging.CRITICAL+1) | 
					
						
							|  |  |  |         yield | 
					
						
							|  |  |  |     finally: | 
					
						
							|  |  |  |         logger.setLevel(old_level) | 
					
						
							| 
									
										
										
										
											2014-08-25 23:20:52 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-12-16 19:31:17 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM, | 
					
						
							|  |  |  |                             family=socket.AF_INET): | 
					
						
							| 
									
										
										
										
											2014-08-25 23:20:52 +02:00
										 |  |  |     """Create a mock of a non-blocking socket.""" | 
					
						
							| 
									
										
										
										
											2015-12-16 19:31:17 -05:00
										 |  |  |     sock = mock.MagicMock(socket.socket) | 
					
						
							|  |  |  |     sock.proto = proto | 
					
						
							|  |  |  |     sock.type = type | 
					
						
							|  |  |  |     sock.family = family | 
					
						
							| 
									
										
										
										
											2014-08-25 23:20:52 +02:00
										 |  |  |     sock.gettimeout.return_value = 0.0 | 
					
						
							|  |  |  |     return sock | 
					
						
							| 
									
										
										
										
											2015-01-14 00:19:09 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def force_legacy_ssl_support(): | 
					
						
							|  |  |  |     return mock.patch('asyncio.sslproto._is_sslproto_available', | 
					
						
							|  |  |  |                       return_value=False) |