| 
									
										
										
										
											2013-08-14 15:35:41 +01:00
										 |  |  | import errno | 
					
						
							|  |  |  | import os | 
					
						
							| 
									
										
										
										
											2013-09-05 20:46:49 +02:00
										 |  |  | import selectors | 
					
						
							| 
									
										
										
										
											2013-08-14 15:35:41 +01:00
										 |  |  | import signal | 
					
						
							|  |  |  | import socket | 
					
						
							|  |  |  | import struct | 
					
						
							|  |  |  | import sys | 
					
						
							|  |  |  | import threading | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from . import connection | 
					
						
							|  |  |  | from . import process | 
					
						
							|  |  |  | from . import reduction | 
					
						
							| 
									
										
										
										
											2013-08-22 11:38:57 +01:00
										 |  |  | from . import semaphore_tracker | 
					
						
							| 
									
										
										
										
											2013-08-14 15:35:41 +01:00
										 |  |  | from . import spawn | 
					
						
							|  |  |  | from . import util | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | __all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process', | 
					
						
							|  |  |  |            'set_forkserver_preload'] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | MAXFDS_TO_SEND = 256 | 
					
						
							|  |  |  | UNSIGNED_STRUCT = struct.Struct('Q')     # large enough for pid_t | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							| 
									
										
										
										
											2013-10-16 16:41:56 +01:00
										 |  |  | # Forkserver class | 
					
						
							| 
									
										
										
										
											2013-08-14 15:35:41 +01:00
										 |  |  | # | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-10-16 16:41:56 +01:00
										 |  |  | class ForkServer(object): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self): | 
					
						
							|  |  |  |         self._forkserver_address = None | 
					
						
							|  |  |  |         self._forkserver_alive_fd = None | 
					
						
							|  |  |  |         self._inherited_fds = None | 
					
						
							|  |  |  |         self._lock = threading.Lock() | 
					
						
							|  |  |  |         self._preload_modules = ['__main__'] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def set_forkserver_preload(self, modules_names): | 
					
						
							|  |  |  |         '''Set list of module names to try to load in forkserver process.''' | 
					
						
							|  |  |  |         if not all(type(mod) is str for mod in self._preload_modules): | 
					
						
							|  |  |  |             raise TypeError('module_names must be a list of strings') | 
					
						
							|  |  |  |         self._preload_modules = modules_names | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_inherited_fds(self): | 
					
						
							|  |  |  |         '''Return list of fds inherited from parent process.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         This returns None if the current process was not started by fork | 
					
						
							|  |  |  |         server. | 
					
						
							|  |  |  |         '''
 | 
					
						
							|  |  |  |         return self._inherited_fds | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def connect_to_new_process(self, fds): | 
					
						
							|  |  |  |         '''Request forkserver to create a child process.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Returns a pair of fds (status_r, data_w).  The calling process can read | 
					
						
							|  |  |  |         the child process's pid and (eventually) its returncode from status_r. | 
					
						
							|  |  |  |         The calling process should write to data_w the pickled preparation and | 
					
						
							|  |  |  |         process data. | 
					
						
							|  |  |  |         '''
 | 
					
						
							|  |  |  |         self.ensure_running() | 
					
						
							|  |  |  |         if len(fds) + 4 >= MAXFDS_TO_SEND: | 
					
						
							|  |  |  |             raise ValueError('too many fds') | 
					
						
							|  |  |  |         with socket.socket(socket.AF_UNIX) as client: | 
					
						
							|  |  |  |             client.connect(self._forkserver_address) | 
					
						
							|  |  |  |             parent_r, child_w = os.pipe() | 
					
						
							|  |  |  |             child_r, parent_w = os.pipe() | 
					
						
							|  |  |  |             allfds = [child_r, child_w, self._forkserver_alive_fd, | 
					
						
							|  |  |  |                       semaphore_tracker.getfd()] | 
					
						
							|  |  |  |             allfds += fds | 
					
						
							| 
									
										
										
										
											2013-08-22 11:38:55 +01:00
										 |  |  |             try: | 
					
						
							| 
									
										
										
										
											2013-10-16 16:41:56 +01:00
										 |  |  |                 reduction.sendfds(client, allfds) | 
					
						
							|  |  |  |                 return parent_r, parent_w | 
					
						
							| 
									
										
										
										
											2013-08-22 11:38:55 +01:00
										 |  |  |             except: | 
					
						
							| 
									
										
										
										
											2013-10-16 16:41:56 +01:00
										 |  |  |                 os.close(parent_r) | 
					
						
							|  |  |  |                 os.close(parent_w) | 
					
						
							| 
									
										
										
										
											2013-08-22 11:38:55 +01:00
										 |  |  |                 raise | 
					
						
							|  |  |  |             finally: | 
					
						
							| 
									
										
										
										
											2013-10-16 16:41:56 +01:00
										 |  |  |                 os.close(child_r) | 
					
						
							|  |  |  |                 os.close(child_w) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def ensure_running(self): | 
					
						
							|  |  |  |         '''Make sure that a fork server is running.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         This can be called from any process.  Note that usually a child | 
					
						
							|  |  |  |         process will just reuse the forkserver started by its parent, so | 
					
						
							|  |  |  |         ensure_running() will do nothing. | 
					
						
							|  |  |  |         '''
 | 
					
						
							|  |  |  |         with self._lock: | 
					
						
							|  |  |  |             semaphore_tracker.ensure_running() | 
					
						
							|  |  |  |             if self._forkserver_alive_fd is not None: | 
					
						
							|  |  |  |                 return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             cmd = ('from multiprocessing.forkserver import main; ' + | 
					
						
							|  |  |  |                    'main(%d, %d, %r, **%r)') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if self._preload_modules: | 
					
						
							|  |  |  |                 desired_keys = {'main_path', 'sys_path'} | 
					
						
							|  |  |  |                 data = spawn.get_preparation_data('ignore') | 
					
						
							|  |  |  |                 data = dict((x,y) for (x,y) in data.items() | 
					
						
							|  |  |  |                             if x in desired_keys) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 data = {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             with socket.socket(socket.AF_UNIX) as listener: | 
					
						
							|  |  |  |                 address = connection.arbitrary_address('AF_UNIX') | 
					
						
							|  |  |  |                 listener.bind(address) | 
					
						
							|  |  |  |                 os.chmod(address, 0o600) | 
					
						
							|  |  |  |                 listener.listen(100) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # all client processes own the write end of the "alive" pipe; | 
					
						
							|  |  |  |                 # when they all terminate the read end becomes ready. | 
					
						
							|  |  |  |                 alive_r, alive_w = os.pipe() | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     fds_to_pass = [listener.fileno(), alive_r] | 
					
						
							|  |  |  |                     cmd %= (listener.fileno(), alive_r, self._preload_modules, | 
					
						
							|  |  |  |                             data) | 
					
						
							|  |  |  |                     exe = spawn.get_executable() | 
					
						
							|  |  |  |                     args = [exe] + util._args_from_interpreter_flags() | 
					
						
							|  |  |  |                     args += ['-c', cmd] | 
					
						
							|  |  |  |                     pid = util.spawnv_passfds(exe, args, fds_to_pass) | 
					
						
							|  |  |  |                 except: | 
					
						
							|  |  |  |                     os.close(alive_w) | 
					
						
							|  |  |  |                     raise | 
					
						
							|  |  |  |                 finally: | 
					
						
							|  |  |  |                     os.close(alive_r) | 
					
						
							|  |  |  |                 self._forkserver_address = address | 
					
						
							|  |  |  |                 self._forkserver_alive_fd = alive_w | 
					
						
							| 
									
										
										
										
											2013-08-14 15:35:41 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-10-16 16:41:56 +01:00
										 |  |  | # | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # | 
					
						
							| 
									
										
										
										
											2013-08-14 15:35:41 +01:00
										 |  |  | 
 | 
					
						
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    Entry point of the fork server process; ForkServer.ensure_running()
    starts it with ``python -c "...main(...)"``.

    *listener_fd* is the inherited fd of the listening AF_UNIX socket and
    *alive_r* the read end of the "alive" pipe.  *preload* is a list of
    module names to try to import first; when it contains '__main__' and
    *main_path* is given, *main_path* is passed to spawn.import_main_path().
    *sys_path* is accepted as a keyword but not used in this function.

    Accepts connections forever, forking once per connection; raises
    SystemExit once the alive pipe reports EOF.
    '''
    if preload:
        if '__main__' in preload and main_path is not None:
            # Flag set only for the duration of the __main__ import and
            # removed afterwards; presumably checked by code run during that
            # import -- confirm in process/spawn.
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            # Preloading is best-effort: modules that fail to import are
            # skipped.
            try:
                __import__(modname)
            except ImportError:
                pass

    # close sys.stdin
    if sys.stdin is not None:
        try:
            sys.stdin.close()
            sys.stdin = open(os.devnull)
        except (OSError, ValueError):
            pass

    # ignoring SIGCHLD means no need to reap zombie processes
    # (the previous handler is kept so forked children can restore it)
    handler = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)

        while True:
            try:
                # Block until at least one registered fd is ready.
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    # NOTE(review): under -O this assert (and its read) is
                    # skipped; harmless because SystemExit follows anyway.
                    assert os.read(alive_r, 1) == b''
                    raise SystemExit

                assert listener in rfds
                with listener.accept()[0] as s:
                    # Exit status of the forked child.  The child's real
                    # returncode is reported to the client through the status
                    # pipe by _serve_one(); this status is discarded because
                    # SIGCHLD is ignored above.
                    code = 1
                    if os.fork() == 0:
                        # In the child: serve exactly one request, then exit
                        # without returning into this loop.
                        try:
                            _serve_one(s, listener, alive_r, handler)
                        except Exception:
                            sys.excepthook(*sys.exc_info())
                            sys.stderr.flush()
                        finally:
                            os._exit(code)

            except InterruptedError:
                # A signal interrupted a system call; go round again.
                pass
            except OSError as e:
                # A client aborting its connection is tolerated; any other
                # OSError (including a failed fork) stops the server.
                if e.errno != errno.ECONNABORTED:
                    raise
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _serve_one(s, listener, alive_r, handler): | 
					
						
							|  |  |  |     # close unnecessary stuff and reset SIGCHLD handler | 
					
						
							|  |  |  |     listener.close() | 
					
						
							|  |  |  |     os.close(alive_r) | 
					
						
							|  |  |  |     signal.signal(signal.SIGCHLD, handler) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # receive fds from parent process | 
					
						
							|  |  |  |     fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1) | 
					
						
							|  |  |  |     s.close() | 
					
						
							|  |  |  |     assert len(fds) <= MAXFDS_TO_SEND | 
					
						
							| 
									
										
										
										
											2013-10-16 16:41:56 +01:00
										 |  |  |     (child_r, child_w, _forkserver._forkserver_alive_fd, | 
					
						
							|  |  |  |      stfd, *_forkserver._inherited_fds) = fds | 
					
						
							|  |  |  |     semaphore_tracker._semaphore_tracker._fd = stfd | 
					
						
							| 
									
										
										
										
											2013-08-14 15:35:41 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # send pid to client processes | 
					
						
							|  |  |  |     write_unsigned(child_w, os.getpid()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # reseed random number generator | 
					
						
							|  |  |  |     if 'random' in sys.modules: | 
					
						
							|  |  |  |         import random | 
					
						
							|  |  |  |         random.seed() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # run process object received over pipe | 
					
						
							|  |  |  |     code = spawn._main(child_r) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # write the exit code to the pipe | 
					
						
							|  |  |  |     write_unsigned(child_w, code) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Read and write unsigned numbers | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def read_unsigned(fd): | 
					
						
							|  |  |  |     data = b'' | 
					
						
							|  |  |  |     length = UNSIGNED_STRUCT.size | 
					
						
							|  |  |  |     while len(data) < length: | 
					
						
							|  |  |  |         while True: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 s = os.read(fd, length - len(data)) | 
					
						
							|  |  |  |             except InterruptedError: | 
					
						
							|  |  |  |                 pass | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 break | 
					
						
							|  |  |  |         if not s: | 
					
						
							|  |  |  |             raise EOFError('unexpected EOF') | 
					
						
							|  |  |  |         data += s | 
					
						
							|  |  |  |     return UNSIGNED_STRUCT.unpack(data)[0] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def write_unsigned(fd, n): | 
					
						
							|  |  |  |     msg = UNSIGNED_STRUCT.pack(n) | 
					
						
							|  |  |  |     while msg: | 
					
						
							|  |  |  |         while True: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 nbytes = os.write(fd, msg) | 
					
						
							|  |  |  |             except InterruptedError: | 
					
						
							|  |  |  |                 pass | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 break | 
					
						
							|  |  |  |         if nbytes == 0: | 
					
						
							|  |  |  |             raise RuntimeError('should not get here') | 
					
						
							|  |  |  |         msg = msg[nbytes:] | 
					
						
							| 
									
										
										
										
											2013-10-16 16:41:56 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | _forkserver = ForkServer() | 
					
						
							|  |  |  | ensure_running = _forkserver.ensure_running | 
					
						
							|  |  |  | get_inherited_fds = _forkserver.get_inherited_fds | 
					
						
							|  |  |  | connect_to_new_process = _forkserver.connect_to_new_process | 
					
						
							|  |  |  | set_forkserver_preload = _forkserver.set_forkserver_preload |