mirror of
				https://github.com/python/cpython.git
				synced 2025-11-03 23:21:29 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			279 lines
		
	
	
	
		
			7.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			279 lines
		
	
	
	
		
			7.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""Utilities shared by tests."""
 | 
						|
 | 
						|
import collections
 | 
						|
import contextlib
 | 
						|
import io
 | 
						|
import os
 | 
						|
import sys
 | 
						|
import threading
 | 
						|
import time
 | 
						|
import unittest
 | 
						|
import unittest.mock
 | 
						|
from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
 | 
						|
try:
 | 
						|
    import ssl
 | 
						|
except ImportError:  # pragma: no cover
 | 
						|
    ssl = None
 | 
						|
 | 
						|
from . import tasks
 | 
						|
from . import base_events
 | 
						|
from . import events
 | 
						|
from . import selectors
 | 
						|
 | 
						|
 | 
						|
if sys.platform == 'win32':  # pragma: no cover
 | 
						|
    from .windows_utils import socketpair
 | 
						|
else:
 | 
						|
    from socket import socketpair  # pragma: no cover
 | 
						|
 | 
						|
 | 
						|
def dummy_ssl_context():
    """Return a fresh default ``ssl.SSLContext``.

    Returns None when the ``ssl`` module could not be imported.
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
 | 
						|
 | 
						|
 | 
						|
def run_briefly(loop):
    """Drive *loop* until a trivial no-op coroutine task has completed.

    Presumably used to let callbacks that are already pending get a turn;
    confirm against callers.  The coroutine object is always closed
    afterwards, even if running the loop raises.
    """
    @tasks.coroutine
    def noop():
        pass

    coro = noop()
    task = tasks.Task(coro, loop=loop)
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
 | 
						|
 | 
						|
 | 
						|
def run_until(loop, pred, timeout=None, poll_interval=0.001):
    """Run *loop* until ``pred()`` returns a true value.

    Args:
        loop: the event loop to drive.
        pred: zero-argument callable checked before every step.
        timeout: maximum number of seconds to wait, or None to wait
            indefinitely (the loop is then stepped with ``run_briefly``).
        poll_interval: with a *timeout*, the longest time (in seconds) the
            loop sleeps between two checks of *pred*.  Sleeping only this
            long, rather than the whole remaining timeout, means a
            predicate that becomes true early is noticed promptly.

    Returns:
        True once ``pred()`` is true, False if *timeout* expires first.
    """
    if timeout is not None:
        # monotonic() is immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
    while not pred():
        if timeout is None:
            run_briefly(loop)
            continue
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        loop.run_until_complete(
            tasks.sleep(min(remaining, poll_interval), loop=loop))
    return True
 | 
						|
 | 
						|
 | 
						|
def run_once(loop):
    """Run *loop* only until the stop request queued here is processed.

    ``loop.stop()`` schedules ``_raise_stop_error()``, and ``run_forever()``
    runs until that callback fires.  This won't work if the test waits for
    I/O events, because ``_raise_stop_error()`` runs before any I/O event
    callbacks.
    """
    # Queue the stop first, so run_forever() returns after processing it.
    loop.stop()
    loop.run_forever()
 | 
						|
 | 
						|
 | 
						|
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run a throwaway WSGI server in a background thread for a ``with`` block.

    Every request is answered with ``200 OK``, a ``Content-type: text/plain``
    header and the body ``b'Test message'``.  Request logging, stderr output
    and per-request error reports are suppressed so test output stays quiet.

    Args (keyword-only):
        host: interface to bind.
        port: port to bind; 0 lets the OS choose a free port.
        use_ssl: if true, wrap every accepted connection in TLS using
            ``ssl_key.pem`` / ``ssl_cert.pem`` from the asyncio test
            directory.  The key and certificate are loaded once, before the
            server starts, so a missing file raises here instead of failing
            silently inside the server thread.

    Yields:
        The server object.  ``httpd.address`` is set to its
        ``server_address`` (the address actually bound).

    On exit the server loop is shut down, its socket closed and the serving
    thread joined.
    """

    def load_server_ssl_context():
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        # ssl.wrap_socket() was removed in Python 3.12; an explicit context
        # provides the same server-side configuration on all versions.
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.load_cert_chain(certfile=os.path.join(here, 'ssl_cert.pem'),
                                keyfile=os.path.join(here, 'ssl_key.pem'))
        return context

    class SilentWSGIRequestHandler(WSGIRequestHandler):
        def get_stderr(self):
            # Discard anything the WSGI handler writes to "stderr".
            return io.StringIO()

        def log_message(self, format, *args):
            pass

    class SilentWSGIServer(WSGIServer):
        def handle_error(self, request, client_address):
            pass

    class SSLWSGIServer(SilentWSGIServer):
        # Assigned below, before the server is created, when use_ssl is true.
        ssl_context = None

        def finish_request(self, request, client_address):
            ssock = self.ssl_context.wrap_socket(request, server_side=True)
            try:
                try:
                    self.RequestHandlerClass(ssock, client_address, self)
                finally:
                    # Close even when the handler fails, so the TLS socket
                    # is never leaked.
                    ssock.close()
            except OSError:
                # maybe socket has been closed by peer
                pass

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        return [b'Test message']

    if use_ssl:
        SSLWSGIServer.ssl_context = load_server_ssl_context()
        server_class = SSLWSGIServer
    else:
        server_class = SilentWSGIServer

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    httpd = make_server(host, port, app,
                        server_class, SilentWSGIRequestHandler)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(target=httpd.serve_forever)
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
 | 
						|
 | 
						|
 | 
						|
def make_test_protocol(base):
    """Instantiate a throwaway subclass of *base* whose attributes are mocks.

    Every name reported by ``dir(base)`` except dunder names is replaced by
    a ``unittest.mock.Mock(return_value=None)`` class attribute, so calls
    return None and can be inspected afterwards.  The generated class is
    named ``TestProtocol`` and its bases are *base* followed by *base*'s own
    direct bases.
    """
    mocked_attrs = {
        attr_name: unittest.mock.Mock(return_value=None)
        for attr_name in dir(base)
        # skip magic names
        if not (attr_name.startswith('__') and attr_name.endswith('__'))
    }
    bases = (base,) + tuple(base.__bases__)
    protocol_cls = type('TestProtocol', bases, mocked_attrs)
    return protocol_cls()
 | 
						|
 | 
						|
 | 
						|
class TestSelector(selectors.BaseSelector):
    """Selector stand-in that records registrations and reports no events.

    ``self.keys`` maps each registered file object to its SelectorKey.
    """

    def __init__(self):
        self.keys = {}

    @property
    def resolution(self):
        """Constant resolution value (0.001)."""
        return 0.001

    def register(self, fileobj, events, data=None):
        """Record *fileobj* and return its key; the key's fd is always 0."""
        registration = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = registration
        return registration

    def unregister(self, fileobj):
        """Forget *fileobj* and return its key (KeyError if unknown)."""
        registration = self.keys[fileobj]
        del self.keys[fileobj]
        return registration

    def select(self, timeout):
        """Return an empty event list regardless of *timeout*."""
        return []

    def get_map(self):
        """Return the live fileobj -> SelectorKey dict (not a copy)."""
        return self.keys
 | 
						|
 | 
						|
 | 
						|
class TestLoop(base_events.BaseEventLoop):
    """Loop for unittests.

    It manages its own time directly: ``time()`` returns a virtual clock
    that starts at 0 and only moves when the time generator says so.

    If something was scheduled with ``call_at()`` during a loop iteration,
    then after all ready handlers of that iteration are done, the generator
    passed to ``__init__`` is resumed once for each scheduled time.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value returned by ``yield`` is the absolute time at which a handler
    was scheduled.  The value passed to ``yield`` is the time advance used
    to move the loop's clock forward (a falsy value leaves it unchanged).

    When a generator is given, ``close()`` checks that it has finished.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            def gen():
                yield
            # NOTE: this default generator is exhausted by the first send()
            # in _run_once(), so a loop built without *gen* cannot service
            # call_at() timers (StopIteration would escape _run_once()).
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        next(self._gen)  # advance to the first yield so send() can be used
        self._time = 0
        self._timers = []  # times passed to call_at() since the last iteration
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        """Return the loop's virtual time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance*; falsy values are ignored."""
        if advance:
            self._time += advance

    def close(self):
        """Check that the time generator has finished, if one was given.

        Raises:
            AssertionError: the generator yielded again instead of finishing.
        """
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def add_reader(self, fd, callback, *args):
        """Register ``callback(*args)`` as the reader handle for *fd*."""
        self.readers[fd] = events.make_handle(callback, args)

    def remove_reader(self, fd):
        """Forget the reader for *fd*; return True if one was registered.

        Every call is counted in ``remove_reader_count[fd]``.
        """
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        return False

    def assert_reader(self, fd, callback, *args):
        """Fail unless *fd* has a reader with exactly *callback* and *args*.

        Raises:
            AssertionError: *fd* is unregistered or the handle differs.
        """
        self._assert_handle_registered(self.readers, fd, callback, args)

    def add_writer(self, fd, callback, *args):
        """Register ``callback(*args)`` as the writer handle for *fd*."""
        self.writers[fd] = events.make_handle(callback, args)

    def remove_writer(self, fd):
        """Forget the writer for *fd*; return True if one was registered.

        Every call is counted in ``remove_writer_count[fd]``.
        """
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        return False

    def assert_writer(self, fd, callback, *args):
        """Fail unless *fd* has a writer with exactly *callback* and *args*.

        Raises:
            AssertionError: *fd* is unregistered or the handle differs.
        """
        self._assert_handle_registered(self.writers, fd, callback, args)

    @staticmethod
    def _assert_handle_registered(handles, fd, callback, args):
        """Shared check for assert_reader()/assert_writer()."""
        # Raise explicitly rather than using ``assert`` so these test
        # checks still run when Python is started with -O.
        if fd not in handles:
            raise AssertionError('fd {} is not registered'.format(fd))
        handle = handles[fd]
        if handle._callback != callback:
            raise AssertionError('{!r} != {!r}'.format(
                handle._callback, callback))
        if handle._args != args:
            raise AssertionError('{!r} != {!r}'.format(
                handle._args, args))

    def reset_counters(self):
        """Reset the per-fd remove_reader/remove_writer call counters."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        """Run one base-loop iteration, then feed the time generator.

        Each time recorded by call_at() is sent to the generator, and the
        value it yields back is applied with advance_time().
        """
        super()._run_once()
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args):
        """Schedule via the base loop, remembering *when* for _run_once()."""
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        # Events are ignored; TestSelector.select() always returns [] anyway.
        return

    def _write_to_self(self):
        # Intentionally a no-op for this test loop.
        pass
 |