asyncio: Refactor tests: add a base TestCase class

This commit is contained in:
Victor Stinner 2014-06-18 01:36:32 +02:00
parent d6f02fc649
commit c73701de72
13 changed files with 145 additions and 219 deletions

View file

@ -11,6 +11,7 @@
import tempfile import tempfile
import threading import threading
import time import time
import unittest
from unittest import mock from unittest import mock
from http.server import HTTPServer from http.server import HTTPServer
@ -379,3 +380,20 @@ def get_function_source(func):
if source is None: if source is None:
raise ValueError("unable to get the source of %r" % (func,)) raise ValueError("unable to get the source of %r" % (func,))
return source return source
class TestCase(unittest.TestCase):
def set_event_loop(self, loop, *, cleanup=True):
assert loop is not None
# ensure that the event loop is passed explicitly in asyncio
events.set_event_loop(None)
if cleanup:
self.addCleanup(loop.close)
def new_test_loop(self, gen=None):
loop = TestLoop(gen)
self.set_event_loop(loop)
return loop
def tearDown(self):
events.set_event_loop(None)

View file

@ -19,12 +19,12 @@
PY34 = sys.version_info >= (3, 4) PY34 = sys.version_info >= (3, 4)
class BaseEventLoopTests(unittest.TestCase): class BaseEventLoopTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = base_events.BaseEventLoop() self.loop = base_events.BaseEventLoop()
self.loop._selector = mock.Mock() self.loop._selector = mock.Mock()
asyncio.set_event_loop(None) self.set_event_loop(self.loop)
def test_not_implemented(self): def test_not_implemented(self):
m = mock.Mock() m = mock.Mock()
@ -548,14 +548,11 @@ def connection_lost(self, exc):
self.done.set_result(None) self.done.set_result(None)
class BaseEventLoopWithSelectorTests(unittest.TestCase): class BaseEventLoopWithSelectorTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None) self.set_event_loop(self.loop)
def tearDown(self):
self.loop.close()
@mock.patch('asyncio.base_events.socket') @mock.patch('asyncio.base_events.socket')
def test_create_connection_multiple_errors(self, m_socket): def test_create_connection_multiple_errors(self, m_socket):

View file

@ -224,7 +224,7 @@ class EventLoopTestsMixin:
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.loop = self.create_event_loop() self.loop = self.create_event_loop()
asyncio.set_event_loop(None) self.set_event_loop(self.loop)
def tearDown(self): def tearDown(self):
# just in case if we have transport close callbacks # just in case if we have transport close callbacks
@ -1629,14 +1629,14 @@ def connect(cmd=None, **kwds):
if sys.platform == 'win32': if sys.platform == 'win32':
class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): class SelectEventLoopTests(EventLoopTestsMixin, test_utils.TestCase):
def create_event_loop(self): def create_event_loop(self):
return asyncio.SelectorEventLoop() return asyncio.SelectorEventLoop()
class ProactorEventLoopTests(EventLoopTestsMixin, class ProactorEventLoopTests(EventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): test_utils.TestCase):
def create_event_loop(self): def create_event_loop(self):
return asyncio.ProactorEventLoop() return asyncio.ProactorEventLoop()
@ -1691,7 +1691,7 @@ def tearDown(self):
if hasattr(selectors, 'KqueueSelector'): if hasattr(selectors, 'KqueueSelector'):
class KqueueEventLoopTests(UnixEventLoopTestsMixin, class KqueueEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): test_utils.TestCase):
def create_event_loop(self): def create_event_loop(self):
return asyncio.SelectorEventLoop( return asyncio.SelectorEventLoop(
@ -1716,7 +1716,7 @@ def test_write_pty(self):
if hasattr(selectors, 'EpollSelector'): if hasattr(selectors, 'EpollSelector'):
class EPollEventLoopTests(UnixEventLoopTestsMixin, class EPollEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): test_utils.TestCase):
def create_event_loop(self): def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.EpollSelector()) return asyncio.SelectorEventLoop(selectors.EpollSelector())
@ -1724,7 +1724,7 @@ def create_event_loop(self):
if hasattr(selectors, 'PollSelector'): if hasattr(selectors, 'PollSelector'):
class PollEventLoopTests(UnixEventLoopTestsMixin, class PollEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): test_utils.TestCase):
def create_event_loop(self): def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.PollSelector()) return asyncio.SelectorEventLoop(selectors.PollSelector())
@ -1732,7 +1732,7 @@ def create_event_loop(self):
# Should always exist. # Should always exist.
class SelectEventLoopTests(UnixEventLoopTestsMixin, class SelectEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): test_utils.TestCase):
def create_event_loop(self): def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.SelectSelector()) return asyncio.SelectorEventLoop(selectors.SelectSelector())

View file

@ -13,14 +13,10 @@ def _fakefunc(f):
return f return f
class FutureTests(unittest.TestCase): class FutureTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def test_initial_state(self): def test_initial_state(self):
f = asyncio.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
@ -30,12 +26,9 @@ def test_initial_state(self):
self.assertTrue(f.cancelled()) self.assertTrue(f.cancelled())
def test_init_constructor_default_loop(self): def test_init_constructor_default_loop(self):
try:
asyncio.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
f = asyncio.Future() f = asyncio.Future()
self.assertIs(f._loop, self.loop) self.assertIs(f._loop, self.loop)
finally:
asyncio.set_event_loop(None)
def test_constructor_positional(self): def test_constructor_positional(self):
# Make sure Future doesn't accept a positional argument # Make sure Future doesn't accept a positional argument
@ -264,14 +257,10 @@ def test_wrap_future_cancel2(self):
self.assertTrue(f2.cancelled()) self.assertTrue(f2.cancelled())
class FutureDoneCallbackTests(unittest.TestCase): class FutureDoneCallbackTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def run_briefly(self): def run_briefly(self):
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)

View file

@ -17,14 +17,10 @@
RGX_REPR = re.compile(STR_RGX_REPR) RGX_REPR = re.compile(STR_RGX_REPR)
class LockTests(unittest.TestCase): class LockTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def test_ctor_loop(self): def test_ctor_loop(self):
loop = mock.Mock() loop = mock.Mock()
@ -35,12 +31,9 @@ def test_ctor_loop(self):
self.assertIs(lock._loop, self.loop) self.assertIs(lock._loop, self.loop)
def test_ctor_noloop(self): def test_ctor_noloop(self):
try:
asyncio.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
lock = asyncio.Lock() lock = asyncio.Lock()
self.assertIs(lock._loop, self.loop) self.assertIs(lock._loop, self.loop)
finally:
asyncio.set_event_loop(None)
def test_repr(self): def test_repr(self):
lock = asyncio.Lock(loop=self.loop) lock = asyncio.Lock(loop=self.loop)
@ -240,14 +233,10 @@ def test_context_manager_no_yield(self):
self.assertFalse(lock.locked()) self.assertFalse(lock.locked())
class EventTests(unittest.TestCase): class EventTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def test_ctor_loop(self): def test_ctor_loop(self):
loop = mock.Mock() loop = mock.Mock()
@ -258,12 +247,9 @@ def test_ctor_loop(self):
self.assertIs(ev._loop, self.loop) self.assertIs(ev._loop, self.loop)
def test_ctor_noloop(self): def test_ctor_noloop(self):
try:
asyncio.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
ev = asyncio.Event() ev = asyncio.Event()
self.assertIs(ev._loop, self.loop) self.assertIs(ev._loop, self.loop)
finally:
asyncio.set_event_loop(None)
def test_repr(self): def test_repr(self):
ev = asyncio.Event(loop=self.loop) ev = asyncio.Event(loop=self.loop)
@ -376,14 +362,10 @@ def c1(result):
self.assertTrue(t.result()) self.assertTrue(t.result())
class ConditionTests(unittest.TestCase): class ConditionTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def test_ctor_loop(self): def test_ctor_loop(self):
loop = mock.Mock() loop = mock.Mock()
@ -394,12 +376,9 @@ def test_ctor_loop(self):
self.assertIs(cond._loop, self.loop) self.assertIs(cond._loop, self.loop)
def test_ctor_noloop(self): def test_ctor_noloop(self):
try:
asyncio.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
cond = asyncio.Condition() cond = asyncio.Condition()
self.assertIs(cond._loop, self.loop) self.assertIs(cond._loop, self.loop)
finally:
asyncio.set_event_loop(None)
def test_wait(self): def test_wait(self):
cond = asyncio.Condition(loop=self.loop) cond = asyncio.Condition(loop=self.loop)
@ -678,14 +657,10 @@ def test_context_manager_no_yield(self):
self.assertFalse(cond.locked()) self.assertFalse(cond.locked())
class SemaphoreTests(unittest.TestCase): class SemaphoreTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def test_ctor_loop(self): def test_ctor_loop(self):
loop = mock.Mock() loop = mock.Mock()
@ -696,12 +671,9 @@ def test_ctor_loop(self):
self.assertIs(sem._loop, self.loop) self.assertIs(sem._loop, self.loop)
def test_ctor_noloop(self): def test_ctor_noloop(self):
try:
asyncio.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
sem = asyncio.Semaphore() sem = asyncio.Semaphore()
self.assertIs(sem._loop, self.loop) self.assertIs(sem._loop, self.loop)
finally:
asyncio.set_event_loop(None)
def test_initial_value_zero(self): def test_initial_value_zero(self):
sem = asyncio.Semaphore(0, loop=self.loop) sem = asyncio.Semaphore(0, loop=self.loop)

View file

@ -12,10 +12,10 @@
from asyncio import test_utils from asyncio import test_utils
class ProactorSocketTransportTests(unittest.TestCase): class ProactorSocketTransportTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.proactor = mock.Mock() self.proactor = mock.Mock()
self.loop._proactor = self.proactor self.loop._proactor = self.proactor
self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
@ -343,7 +343,7 @@ def test_pause_resume_reading(self):
tr.close() tr.close()
class BaseProactorEventLoopTests(unittest.TestCase): class BaseProactorEventLoopTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.sock = mock.Mock(socket.socket) self.sock = mock.Mock(socket.socket)
@ -356,6 +356,7 @@ def _socketpair(s):
return (self.ssock, self.csock) return (self.ssock, self.csock)
self.loop = EventLoop(self.proactor) self.loop = EventLoop(self.proactor)
self.set_event_loop(self.loop, cleanup=False)
@mock.patch.object(BaseProactorEventLoop, 'call_soon') @mock.patch.object(BaseProactorEventLoop, 'call_soon')
@mock.patch.object(BaseProactorEventLoop, '_socketpair') @mock.patch.object(BaseProactorEventLoop, '_socketpair')

View file

@ -7,14 +7,10 @@
from asyncio import test_utils from asyncio import test_utils
class _QueueTestBase(unittest.TestCase): class _QueueTestBase(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
class QueueBasicTests(_QueueTestBase): class QueueBasicTests(_QueueTestBase):
@ -32,8 +28,7 @@ def gen():
self.assertAlmostEqual(0.2, when) self.assertAlmostEqual(0.2, when)
yield 0.1 yield 0.1
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
q = asyncio.Queue(loop=loop) q = asyncio.Queue(loop=loop)
self.assertTrue(fn(q).startswith('<Queue'), fn(q)) self.assertTrue(fn(q).startswith('<Queue'), fn(q))
@ -80,12 +75,9 @@ def test_ctor_loop(self):
self.assertIs(q._loop, self.loop) self.assertIs(q._loop, self.loop)
def test_ctor_noloop(self): def test_ctor_noloop(self):
try:
asyncio.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
q = asyncio.Queue() q = asyncio.Queue()
self.assertIs(q._loop, self.loop) self.assertIs(q._loop, self.loop)
finally:
asyncio.set_event_loop(None)
def test_repr(self): def test_repr(self):
self._test_repr_or_str(repr, True) self._test_repr_or_str(repr, True)
@ -126,8 +118,7 @@ def gen():
self.assertAlmostEqual(0.02, when) self.assertAlmostEqual(0.02, when)
yield 0.01 yield 0.01
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
q = asyncio.Queue(maxsize=2, loop=loop) q = asyncio.Queue(maxsize=2, loop=loop)
self.assertEqual(2, q.maxsize) self.assertEqual(2, q.maxsize)
@ -194,8 +185,7 @@ def gen():
self.assertAlmostEqual(0.01, when) self.assertAlmostEqual(0.01, when)
yield 0.01 yield 0.01
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
q = asyncio.Queue(loop=loop) q = asyncio.Queue(loop=loop)
started = asyncio.Event(loop=loop) started = asyncio.Event(loop=loop)
@ -241,8 +231,7 @@ def gen():
self.assertAlmostEqual(0.061, when) self.assertAlmostEqual(0.061, when)
yield 0.05 yield 0.05
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
q = asyncio.Queue(loop=loop) q = asyncio.Queue(loop=loop)
@ -302,8 +291,7 @@ def gen():
self.assertAlmostEqual(0.01, when) self.assertAlmostEqual(0.01, when)
yield 0.01 yield 0.01
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
q = asyncio.Queue(maxsize=1, loop=loop) q = asyncio.Queue(maxsize=1, loop=loop)
started = asyncio.Event(loop=loop) started = asyncio.Event(loop=loop)

View file

@ -37,11 +37,12 @@ def list_to_buffer(l=()):
return bytearray().join(l) return bytearray().join(l)
class BaseSelectorEventLoopTests(unittest.TestCase): class BaseSelectorEventLoopTests(test_utils.TestCase):
def setUp(self): def setUp(self):
selector = mock.Mock() selector = mock.Mock()
self.loop = TestBaseSelectorEventLoop(selector) self.loop = TestBaseSelectorEventLoop(selector)
self.set_event_loop(self.loop, cleanup=False)
def test_make_socket_transport(self): def test_make_socket_transport(self):
m = mock.Mock() m = mock.Mock()
@ -597,10 +598,10 @@ def test_process_events_write_cancelled(self):
self.loop.remove_writer.assert_called_with(1) self.loop.remove_writer.assert_called_with(1)
class SelectorTransportTests(unittest.TestCase): class SelectorTransportTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket) self.sock = mock.Mock(socket.socket)
self.sock.fileno.return_value = 7 self.sock.fileno.return_value = 7
@ -684,14 +685,14 @@ def test_connection_lost(self):
self.assertEqual(2, sys.getrefcount(self.protocol), self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol))) pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop) self.assertIsNone(tr._loop)
self.assertEqual(2, sys.getrefcount(self.loop), self.assertEqual(3, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop))) pprint.pformat(gc.get_referrers(self.loop)))
class SelectorSocketTransportTests(unittest.TestCase): class SelectorSocketTransportTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket) self.sock = mock.Mock(socket.socket)
self.sock_fd = self.sock.fileno.return_value = 7 self.sock_fd = self.sock.fileno.return_value = 7
@ -1061,10 +1062,10 @@ def test_write_eof_buffer(self):
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
class SelectorSslTransportTests(unittest.TestCase): class SelectorSslTransportTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket) self.sock = mock.Mock(socket.socket)
self.sock.fileno.return_value = 7 self.sock.fileno.return_value = 7
@ -1396,10 +1397,10 @@ def test_ssl_transport_requires_ssl_module(self):
_SelectorSslTransport(Mock(), Mock(), Mock(), Mock()) _SelectorSslTransport(Mock(), Mock(), Mock(), Mock())
class SelectorDatagramTransportTests(unittest.TestCase): class SelectorDatagramTransportTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol) self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
self.sock = mock.Mock(spec_set=socket.socket) self.sock = mock.Mock(spec_set=socket.socket)
self.sock.fileno.return_value = 7 self.sock.fileno.return_value = 7

View file

@ -15,13 +15,13 @@
from asyncio import test_utils from asyncio import test_utils
class StreamReaderTests(unittest.TestCase): class StreamReaderTests(test_utils.TestCase):
DATA = b'line1\nline2\nline3\n' DATA = b'line1\nline2\nline3\n'
def setUp(self): def setUp(self):
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None) self.set_event_loop(self.loop)
def tearDown(self): def tearDown(self):
# just in case if we have transport close callbacks # just in case if we have transport close callbacks
@ -29,6 +29,7 @@ def tearDown(self):
self.loop.close() self.loop.close()
gc.collect() gc.collect()
super().tearDown()
@mock.patch('asyncio.streams.events') @mock.patch('asyncio.streams.events')
def test_ctor_global_loop(self, m_events): def test_ctor_global_loop(self, m_events):

View file

@ -1,4 +1,5 @@
from asyncio import subprocess from asyncio import subprocess
from asyncio import test_utils
import asyncio import asyncio
import signal import signal
import sys import sys
@ -151,21 +152,21 @@ def tearDown(self):
policy = asyncio.get_event_loop_policy() policy = asyncio.get_event_loop_policy()
policy.set_child_watcher(None) policy.set_child_watcher(None)
self.loop.close() self.loop.close()
policy.set_event_loop(None) super().tearDown()
class SubprocessSafeWatcherTests(SubprocessWatcherMixin, class SubprocessSafeWatcherTests(SubprocessWatcherMixin,
unittest.TestCase): test_utils.TestCase):
Watcher = unix_events.SafeChildWatcher Watcher = unix_events.SafeChildWatcher
class SubprocessFastWatcherTests(SubprocessWatcherMixin, class SubprocessFastWatcherTests(SubprocessWatcherMixin,
unittest.TestCase): test_utils.TestCase):
Watcher = unix_events.FastChildWatcher Watcher = unix_events.FastChildWatcher
else: else:
# Windows # Windows
class SubprocessProactorTests(SubprocessMixin, unittest.TestCase): class SubprocessProactorTests(SubprocessMixin, test_utils.TestCase):
def setUp(self): def setUp(self):
policy = asyncio.get_event_loop_policy() policy = asyncio.get_event_loop_policy()
@ -178,6 +179,7 @@ def tearDown(self):
policy = asyncio.get_event_loop_policy() policy = asyncio.get_event_loop_policy()
self.loop.close() self.loop.close()
policy.set_event_loop(None) policy.set_event_loop(None)
super().tearDown()
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -30,15 +30,10 @@ def __call__(self, *args):
pass pass
class TaskTests(unittest.TestCase): class TaskTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
gc.collect()
def test_task_class(self): def test_task_class(self):
@asyncio.coroutine @asyncio.coroutine
@ -51,6 +46,7 @@ def notmuch():
self.assertIs(t._loop, self.loop) self.assertIs(t._loop, self.loop)
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
self.set_event_loop(loop)
t = asyncio.Task(notmuch(), loop=loop) t = asyncio.Task(notmuch(), loop=loop)
self.assertIs(t._loop, loop) self.assertIs(t._loop, loop)
loop.close() loop.close()
@ -66,6 +62,7 @@ def notmuch():
self.assertIs(t._loop, self.loop) self.assertIs(t._loop, self.loop)
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
self.set_event_loop(loop)
t = asyncio.async(notmuch(), loop=loop) t = asyncio.async(notmuch(), loop=loop)
self.assertIs(t._loop, loop) self.assertIs(t._loop, loop)
loop.close() loop.close()
@ -81,6 +78,7 @@ def test_async_future(self):
self.assertIs(f, f_orig) self.assertIs(f, f_orig)
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
self.set_event_loop(loop)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
f = asyncio.async(f_orig, loop=loop) f = asyncio.async(f_orig, loop=loop)
@ -102,6 +100,7 @@ def notmuch():
self.assertIs(t, t_orig) self.assertIs(t, t_orig)
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
self.set_event_loop(loop)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
t = asyncio.async(t_orig, loop=loop) t = asyncio.async(t_orig, loop=loop)
@ -220,8 +219,7 @@ def gen():
self.assertAlmostEqual(10.0, when) self.assertAlmostEqual(10.0, when)
yield 0 yield 0
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
@asyncio.coroutine @asyncio.coroutine
def task(): def task():
@ -346,7 +344,7 @@ def task():
def test_cancel_current_task(self): def test_cancel_current_task(self):
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
self.addCleanup(loop.close) self.set_event_loop(loop)
@asyncio.coroutine @asyncio.coroutine
def task(): def task():
@ -374,8 +372,7 @@ def gen():
self.assertAlmostEqual(0.3, when) self.assertAlmostEqual(0.3, when)
yield 0.1 yield 0.1
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
x = 0 x = 0
waiters = [] waiters = []
@ -410,8 +407,7 @@ def gen():
self.assertAlmostEqual(0.1, when) self.assertAlmostEqual(0.1, when)
when = yield 0.1 when = yield 0.1
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
foo_running = None foo_running = None
@ -436,8 +432,7 @@ def foo():
self.assertEqual(foo_running, False) self.assertEqual(foo_running, False)
def test_wait_for_blocking(self): def test_wait_for_blocking(self):
loop = test_utils.TestLoop() loop = self.new_test_loop()
self.addCleanup(loop.close)
@asyncio.coroutine @asyncio.coroutine
def coro(): def coro():
@ -457,8 +452,7 @@ def gen():
self.assertAlmostEqual(0.01, when) self.assertAlmostEqual(0.01, when)
yield 0.01 yield 0.01
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
@asyncio.coroutine @asyncio.coroutine
def foo(): def foo():
@ -486,8 +480,7 @@ def gen():
self.assertAlmostEqual(0.15, when) self.assertAlmostEqual(0.15, when)
yield 0.15 yield 0.15
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop) b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop)
@ -517,8 +510,7 @@ def gen():
self.assertAlmostEqual(0.015, when) self.assertAlmostEqual(0.015, when)
yield 0.015 yield 0.015
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
a = asyncio.Task(asyncio.sleep(0.01, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(0.01, loop=loop), loop=loop)
b = asyncio.Task(asyncio.sleep(0.015, loop=loop), loop=loop) b = asyncio.Task(asyncio.sleep(0.015, loop=loop), loop=loop)
@ -531,11 +523,8 @@ def foo():
return 42 return 42
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
try:
res = loop.run_until_complete( res = loop.run_until_complete(
asyncio.Task(foo(), loop=loop)) asyncio.Task(foo(), loop=loop))
finally:
asyncio.set_event_loop(None)
self.assertEqual(res, 42) self.assertEqual(res, 42)
@ -573,8 +562,7 @@ def gen():
self.assertAlmostEqual(0.1, when) self.assertAlmostEqual(0.1, when)
yield 0.1 yield 0.1
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop)
b = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) b = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
@ -629,8 +617,7 @@ def gen():
self.assertAlmostEqual(10.0, when) self.assertAlmostEqual(10.0, when)
yield 0 yield 0
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
# first_exception, task already has exception # first_exception, task already has exception
a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop)
@ -663,8 +650,7 @@ def gen():
self.assertAlmostEqual(0.01, when) self.assertAlmostEqual(0.01, when)
yield 0.01 yield 0.01
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
# first_exception, exception during waiting # first_exception, exception during waiting
a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop)
@ -696,8 +682,7 @@ def gen():
self.assertAlmostEqual(0.15, when) self.assertAlmostEqual(0.15, when)
yield 0.15 yield 0.15
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
@ -733,8 +718,7 @@ def gen():
self.assertAlmostEqual(0.11, when) self.assertAlmostEqual(0.11, when)
yield 0.11 yield 0.11
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop) b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop)
@ -764,8 +748,7 @@ def gen():
self.assertAlmostEqual(0.1, when) self.assertAlmostEqual(0.1, when)
yield 0.1 yield 0.1
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop) b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop)
@ -789,8 +772,7 @@ def gen():
yield 0.01 yield 0.01
yield 0 yield 0
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
completed = set() completed = set()
time_shifted = False time_shifted = False
@ -833,8 +815,7 @@ def gen():
yield 0 yield 0
yield 0.1 yield 0.1
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
a = asyncio.sleep(0.1, 'a', loop=loop) a = asyncio.sleep(0.1, 'a', loop=loop)
b = asyncio.sleep(0.15, 'b', loop=loop) b = asyncio.sleep(0.15, 'b', loop=loop)
@ -870,8 +851,7 @@ def gen():
yield 0 yield 0
yield 0.01 yield 0.01
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
a = asyncio.sleep(0.01, 'a', loop=loop) a = asyncio.sleep(0.01, 'a', loop=loop)
@ -890,8 +870,7 @@ def gen():
yield 0.05 yield 0.05
yield 0 yield 0
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
a = asyncio.sleep(0.05, 'a', loop=loop) a = asyncio.sleep(0.05, 'a', loop=loop)
b = asyncio.sleep(0.10, 'b', loop=loop) b = asyncio.sleep(0.10, 'b', loop=loop)
@ -916,8 +895,7 @@ def gen():
self.assertAlmostEqual(0.05, when) self.assertAlmostEqual(0.05, when)
yield 0.05 yield 0.05
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
a = asyncio.sleep(0.05, 'a', loop=loop) a = asyncio.sleep(0.05, 'a', loop=loop)
b = asyncio.sleep(0.05, 'b', loop=loop) b = asyncio.sleep(0.05, 'b', loop=loop)
@ -958,8 +936,7 @@ def gen():
self.assertAlmostEqual(0.1, when) self.assertAlmostEqual(0.1, when)
yield 0.05 yield 0.05
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
@asyncio.coroutine @asyncio.coroutine
def sleeper(dt, arg): def sleeper(dt, arg):
@ -980,8 +957,7 @@ def gen():
self.assertAlmostEqual(10.0, when) self.assertAlmostEqual(10.0, when)
yield 0 yield 0
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
t = asyncio.Task(asyncio.sleep(10.0, 'yeah', loop=loop), t = asyncio.Task(asyncio.sleep(10.0, 'yeah', loop=loop),
loop=loop) loop=loop)
@ -1012,8 +988,7 @@ def gen():
self.assertAlmostEqual(5000, when) self.assertAlmostEqual(5000, when)
yield 0.1 yield 0.1
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
@asyncio.coroutine @asyncio.coroutine
def sleep(dt): def sleep(dt):
@ -1123,8 +1098,7 @@ def gen():
self.assertAlmostEqual(10.0, when) self.assertAlmostEqual(10.0, when)
yield 0 yield 0
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
@asyncio.coroutine @asyncio.coroutine
def sleeper(): def sleeper():
@ -1536,12 +1510,9 @@ def foo(): yield from []
class GatherTestsBase: class GatherTestsBase:
def setUp(self): def setUp(self):
self.one_loop = test_utils.TestLoop() self.one_loop = self.new_test_loop()
self.other_loop = test_utils.TestLoop() self.other_loop = self.new_test_loop()
self.set_event_loop(self.one_loop, cleanup=False)
def tearDown(self):
self.one_loop.close()
self.other_loop.close()
def _run_loop(self, loop): def _run_loop(self, loop):
while loop._ready: while loop._ready:
@ -1633,7 +1604,7 @@ def test_env_var_debug(self):
self.assertEqual(stdout.rstrip(), b'False') self.assertEqual(stdout.rstrip(), b'False')
class FutureGatherTests(GatherTestsBase, unittest.TestCase): class FutureGatherTests(GatherTestsBase, test_utils.TestCase):
def wrap_futures(self, *futures): def wrap_futures(self, *futures):
return futures return futures
@ -1717,16 +1688,12 @@ def test_result_exception_one_cancellation(self):
cb.assert_called_once_with(fut) cb.assert_called_once_with(fut)
class CoroutineGatherTests(GatherTestsBase, unittest.TestCase): class CoroutineGatherTests(GatherTestsBase, test_utils.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
asyncio.set_event_loop(self.one_loop) asyncio.set_event_loop(self.one_loop)
def tearDown(self):
asyncio.set_event_loop(None)
super().tearDown()
def wrap_futures(self, *futures): def wrap_futures(self, *futures):
coros = [] coros = []
for fut in futures: for fut in futures:

View file

@ -29,14 +29,11 @@
@unittest.skipUnless(signal, 'Signals are not supported') @unittest.skipUnless(signal, 'Signals are not supported')
class SelectorEventLoopSignalTests(unittest.TestCase): class SelectorEventLoopSignalTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = asyncio.SelectorEventLoop() self.loop = asyncio.SelectorEventLoop()
asyncio.set_event_loop(None) self.set_event_loop(self.loop)
def tearDown(self):
self.loop.close()
def test_check_signal(self): def test_check_signal(self):
self.assertRaises( self.assertRaises(
@ -208,14 +205,11 @@ def test_close(self, m_signal):
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), @unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
'UNIX Sockets are not supported') 'UNIX Sockets are not supported')
class SelectorEventLoopUnixSocketTests(unittest.TestCase): class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = asyncio.SelectorEventLoop() self.loop = asyncio.SelectorEventLoop()
asyncio.set_event_loop(None) self.set_event_loop(self.loop)
def tearDown(self):
self.loop.close()
def test_create_unix_server_existing_path_sock(self): def test_create_unix_server_existing_path_sock(self):
with test_utils.unix_socket_path() as path: with test_utils.unix_socket_path() as path:
@ -304,10 +298,10 @@ def test_create_unix_connection_ssl_noserverhost(self):
self.loop.run_until_complete(coro) self.loop.run_until_complete(coro)
class UnixReadPipeTransportTests(unittest.TestCase): class UnixReadPipeTransportTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.pipe = mock.Mock(spec_set=io.RawIOBase) self.pipe = mock.Mock(spec_set=io.RawIOBase)
self.pipe.fileno.return_value = 5 self.pipe.fileno.return_value = 5
@ -451,7 +445,7 @@ def test__call_connection_lost(self):
self.assertEqual(2, sys.getrefcount(self.protocol), self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol))) pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop) self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop), self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop))) pprint.pformat(gc.get_referrers(self.loop)))
def test__call_connection_lost_with_err(self): def test__call_connection_lost_with_err(self):
@ -468,14 +462,14 @@ def test__call_connection_lost_with_err(self):
self.assertEqual(2, sys.getrefcount(self.protocol), self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol))) pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop) self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop), self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop))) pprint.pformat(gc.get_referrers(self.loop)))
class UnixWritePipeTransportTests(unittest.TestCase): class UnixWritePipeTransportTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol) self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol)
self.pipe = mock.Mock(spec_set=io.RawIOBase) self.pipe = mock.Mock(spec_set=io.RawIOBase)
self.pipe.fileno.return_value = 5 self.pipe.fileno.return_value = 5
@ -737,7 +731,7 @@ def test__call_connection_lost(self):
self.assertEqual(2, sys.getrefcount(self.protocol), self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol))) pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop) self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop), self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop))) pprint.pformat(gc.get_referrers(self.loop)))
def test__call_connection_lost_with_err(self): def test__call_connection_lost_with_err(self):
@ -753,7 +747,7 @@ def test__call_connection_lost_with_err(self):
self.assertEqual(2, sys.getrefcount(self.protocol), self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol))) pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop) self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop), self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop))) pprint.pformat(gc.get_referrers(self.loop)))
def test_close(self): def test_close(self):
@ -834,7 +828,7 @@ class ChildWatcherTestsMixin:
ignore_warnings = mock.patch.object(log.logger, "warning") ignore_warnings = mock.patch.object(log.logger, "warning")
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.running = False self.running = False
self.zombies = {} self.zombies = {}
@ -1392,7 +1386,7 @@ def test_set_loop(self, m):
# attach a new loop # attach a new loop
old_loop = self.loop old_loop = self.loop
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
patch = mock.patch.object patch = mock.patch.object
with patch(old_loop, "remove_signal_handler") as m_old_remove, \ with patch(old_loop, "remove_signal_handler") as m_old_remove, \
@ -1447,7 +1441,7 @@ def test_set_loop_race_condition(self, m):
self.assertFalse(callback3.called) self.assertFalse(callback3.called)
# attach a new loop # attach a new loop
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
with mock.patch.object( with mock.patch.object(
self.loop, "add_signal_handler") as m_add_signal_handler: self.loop, "add_signal_handler") as m_add_signal_handler:
@ -1505,12 +1499,12 @@ def test_close(self, m):
self.assertFalse(self.watcher._zombies) self.assertFalse(self.watcher._zombies)
class SafeChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): class SafeChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase):
def create_watcher(self): def create_watcher(self):
return asyncio.SafeChildWatcher() return asyncio.SafeChildWatcher()
class FastChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): class FastChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase):
def create_watcher(self): def create_watcher(self):
return asyncio.FastChildWatcher() return asyncio.FastChildWatcher()

View file

@ -26,15 +26,11 @@ def data_received(self, data):
self.trans.close() self.trans.close()
class ProactorTests(unittest.TestCase): class ProactorTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = asyncio.ProactorEventLoop() self.loop = asyncio.ProactorEventLoop()
asyncio.set_event_loop(None) self.set_event_loop(self.loop)
def tearDown(self):
self.loop.close()
self.loop = None
def test_close(self): def test_close(self):
a, b = self.loop._socketpair() a, b = self.loop._socketpair()