asyncio: Sync with github repo

This commit is contained in:
Yury Selivanov 2015-05-11 13:48:16 -04:00
parent a032e46df6
commit 90ecfe65e6
5 changed files with 110 additions and 36 deletions

View file

@ -197,6 +197,7 @@ def __init__(self):
# exceed this duration in seconds, the slow callback/task is logged. # exceed this duration in seconds, the slow callback/task is logged.
self.slow_callback_duration = 0.1 self.slow_callback_duration = 0.1
self._current_handle = None self._current_handle = None
self._task_factory = None
def __repr__(self): def __repr__(self):
return ('<%s running=%s closed=%s debug=%s>' return ('<%s running=%s closed=%s debug=%s>'
@ -209,11 +210,32 @@ def create_task(self, coro):
Return a task object. Return a task object.
""" """
self._check_closed() self._check_closed()
task = tasks.Task(coro, loop=self) if self._task_factory is None:
if task._source_traceback: task = tasks.Task(coro, loop=self)
del task._source_traceback[-1] if task._source_traceback:
del task._source_traceback[-1]
else:
task = self._task_factory(self, coro)
return task return task
def set_task_factory(self, factory):
"""Set a task factory that will be used by loop.create_task().
If factory is None the default task factory will be set.
If factory is a callable, it should have a signature matching
'(loop, coro)', where 'loop' will be a reference to the active
event loop, 'coro' will be a coroutine object. The callable
must return a Future.
"""
if factory is not None and not callable(factory):
raise TypeError('task factory must be a callable or None')
self._task_factory = factory
def get_task_factory(self):
"""Return a task factory, or None if the default one is in use."""
return self._task_factory
def _make_socket_transport(self, sock, protocol, waiter=None, *, def _make_socket_transport(self, sock, protocol, waiter=None, *,
extra=None, server=None): extra=None, server=None):
"""Create socket transport.""" """Create socket transport."""
@ -465,25 +487,25 @@ def call_soon_threadsafe(self, callback, *args):
self._write_to_self() self._write_to_self()
return handle return handle
def run_in_executor(self, executor, callback, *args): def run_in_executor(self, executor, func, *args):
if (coroutines.iscoroutine(callback) if (coroutines.iscoroutine(func)
or coroutines.iscoroutinefunction(callback)): or coroutines.iscoroutinefunction(func)):
raise TypeError("coroutines cannot be used with run_in_executor()") raise TypeError("coroutines cannot be used with run_in_executor()")
self._check_closed() self._check_closed()
if isinstance(callback, events.Handle): if isinstance(func, events.Handle):
assert not args assert not args
assert not isinstance(callback, events.TimerHandle) assert not isinstance(func, events.TimerHandle)
if callback._cancelled: if func._cancelled:
f = futures.Future(loop=self) f = futures.Future(loop=self)
f.set_result(None) f.set_result(None)
return f return f
callback, args = callback._callback, callback._args func, args = func._callback, func._args
if executor is None: if executor is None:
executor = self._default_executor executor = self._default_executor
if executor is None: if executor is None:
executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS)
self._default_executor = executor self._default_executor = executor
return futures.wrap_future(executor.submit(callback, *args), loop=self) return futures.wrap_future(executor.submit(func, *args), loop=self)
def set_default_executor(self, executor): def set_default_executor(self, executor):
self._default_executor = executor self._default_executor = executor

View file

@ -277,7 +277,7 @@ def create_task(self, coro):
def call_soon_threadsafe(self, callback, *args): def call_soon_threadsafe(self, callback, *args):
raise NotImplementedError raise NotImplementedError
def run_in_executor(self, executor, callback, *args): def run_in_executor(self, executor, func, *args):
raise NotImplementedError raise NotImplementedError
def set_default_executor(self, executor): def set_default_executor(self, executor):
@ -438,6 +438,14 @@ def add_signal_handler(self, sig, callback, *args):
def remove_signal_handler(self, sig): def remove_signal_handler(self, sig):
raise NotImplementedError raise NotImplementedError
# Task factory.
def set_task_factory(self, factory):
raise NotImplementedError
def get_task_factory(self):
raise NotImplementedError
# Error handlers. # Error handlers.
def set_exception_handler(self, handler): def set_exception_handler(self, handler):

View file

@ -1,6 +1,7 @@
"""Queues""" """Queues"""
__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty'] __all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty',
'JoinableQueue']
import collections import collections
import heapq import heapq
@ -286,3 +287,7 @@ def _put(self, item):
def _get(self): def _get(self):
return self._queue.pop() return self._queue.pop()
JoinableQueue = Queue
"""Deprecated alias for Queue."""

View file

@ -310,7 +310,10 @@ def _select(self, r, w, _, timeout=None):
def select(self, timeout=None): def select(self, timeout=None):
timeout = None if timeout is None else max(timeout, 0) timeout = None if timeout is None else max(timeout, 0)
ready = [] ready = []
r, w, _ = self._select(self._readers, self._writers, [], timeout) try:
r, w, _ = self._select(self._readers, self._writers, [], timeout)
except InterruptedError:
return ready
r = set(r) r = set(r)
w = set(w) w = set(w)
for fd in r | w: for fd in r | w:
@ -359,10 +362,11 @@ def select(self, timeout=None):
# poll() has a resolution of 1 millisecond, round away from # poll() has a resolution of 1 millisecond, round away from
# zero to wait *at least* timeout seconds. # zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3) timeout = math.ceil(timeout * 1e3)
fd_event_list = self._poll.poll(timeout)
ready = [] ready = []
try:
fd_event_list = self._poll.poll(timeout)
except InterruptedError:
return ready
for fd, event in fd_event_list: for fd, event in fd_event_list:
events = 0 events = 0
if event & ~select.POLLIN: if event & ~select.POLLIN:
@ -423,9 +427,11 @@ def select(self, timeout=None):
# FD is registered. # FD is registered.
max_ev = max(len(self._fd_to_key), 1) max_ev = max(len(self._fd_to_key), 1)
fd_event_list = self._epoll.poll(timeout, max_ev)
ready = [] ready = []
try:
fd_event_list = self._epoll.poll(timeout, max_ev)
except InterruptedError:
return ready
for fd, event in fd_event_list: for fd, event in fd_event_list:
events = 0 events = 0
if event & ~select.EPOLLIN: if event & ~select.EPOLLIN:
@ -439,10 +445,8 @@ def select(self, timeout=None):
return ready return ready
def close(self): def close(self):
try: self._epoll.close()
self._epoll.close() super().close()
finally:
super().close()
if hasattr(select, 'devpoll'): if hasattr(select, 'devpoll'):
@ -481,10 +485,11 @@ def select(self, timeout=None):
# devpoll() has a resolution of 1 millisecond, round away from # devpoll() has a resolution of 1 millisecond, round away from
# zero to wait *at least* timeout seconds. # zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3) timeout = math.ceil(timeout * 1e3)
fd_event_list = self._devpoll.poll(timeout)
ready = [] ready = []
try:
fd_event_list = self._devpoll.poll(timeout)
except InterruptedError:
return ready
for fd, event in fd_event_list: for fd, event in fd_event_list:
events = 0 events = 0
if event & ~select.POLLIN: if event & ~select.POLLIN:
@ -498,10 +503,8 @@ def select(self, timeout=None):
return ready return ready
def close(self): def close(self):
try: self._devpoll.close()
self._devpoll.close() super().close()
finally:
super().close()
if hasattr(select, 'kqueue'): if hasattr(select, 'kqueue'):
@ -552,9 +555,11 @@ def unregister(self, fileobj):
def select(self, timeout=None): def select(self, timeout=None):
timeout = None if timeout is None else max(timeout, 0) timeout = None if timeout is None else max(timeout, 0)
max_ev = len(self._fd_to_key) max_ev = len(self._fd_to_key)
kev_list = self._kqueue.control(None, max_ev, timeout)
ready = [] ready = []
try:
kev_list = self._kqueue.control(None, max_ev, timeout)
except InterruptedError:
return ready
for kev in kev_list: for kev in kev_list:
fd = kev.ident fd = kev.ident
flag = kev.filter flag = kev.filter
@ -570,10 +575,8 @@ def select(self, timeout=None):
return ready return ready
def close(self): def close(self):
try: self._kqueue.close()
self._kqueue.close() super().close()
finally:
super().close()
# Choose the best implementation, roughly: # Choose the best implementation, roughly:

View file

@ -623,6 +623,42 @@ def custom_handler(loop, context):
self.assertIs(type(_context['context']['exception']), self.assertIs(type(_context['context']['exception']),
ZeroDivisionError) ZeroDivisionError)
def test_set_task_factory_invalid(self):
with self.assertRaisesRegex(
TypeError, 'task factory must be a callable or None'):
self.loop.set_task_factory(1)
self.assertIsNone(self.loop.get_task_factory())
def test_set_task_factory(self):
self.loop._process_events = mock.Mock()
class MyTask(asyncio.Task):
pass
@asyncio.coroutine
def coro():
pass
factory = lambda loop, coro: MyTask(coro, loop=loop)
self.assertIsNone(self.loop.get_task_factory())
self.loop.set_task_factory(factory)
self.assertIs(self.loop.get_task_factory(), factory)
task = self.loop.create_task(coro())
self.assertTrue(isinstance(task, MyTask))
self.loop.run_until_complete(task)
self.loop.set_task_factory(None)
self.assertIsNone(self.loop.get_task_factory())
task = self.loop.create_task(coro())
self.assertTrue(isinstance(task, asyncio.Task))
self.assertFalse(isinstance(task, MyTask))
self.loop.run_until_complete(task)
def test_env_var_debug(self): def test_env_var_debug(self):
code = '\n'.join(( code = '\n'.join((
'import asyncio', 'import asyncio',