mirror of
				https://github.com/python/cpython.git
				synced 2025-10-25 18:54:53 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			413 lines
		
	
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			413 lines
		
	
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| Test suite for socketserver.
 | |
| """
 | |
| 
 | |
| import contextlib
 | |
| import os
 | |
| import select
 | |
| import signal
 | |
| import socket
 | |
| import tempfile
 | |
| import unittest
 | |
| import socketserver
 | |
| 
 | |
| import test.support
 | |
| from test.support import reap_children, reap_threads, verbose
 | |
| try:
 | |
|     import threading
 | |
| except ImportError:
 | |
|     threading = None
 | |
| 
 | |
test.support.requires("network")

# Payload echoed by the test servers, and the host they bind to.
TEST_STR = b"hello world\n"
HOST = test.support.HOST

# Platform feature probes, each paired with the skip decorator used by
# tests that need that feature.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
requires_unix_sockets = unittest.skipUnless(
    HAVE_UNIX_SOCKETS, 'requires Unix sockets')
HAVE_FORKING = hasattr(os, "fork")
requires_forking = unittest.skipUnless(
    HAVE_FORKING, 'requires forking')
 | |
| 
 | |
def signal_alarm(n):
    """Invoke signal.alarm(n) if this platform provides it (Windows does not)."""
    alarm = getattr(signal, 'alarm', None)
    if alarm is not None:
        alarm(n)
 | |
| 
 | |
# Keep a handle on the genuine select.select so that tests which mock the
# select module cannot interfere with receive().
_real_select = select.select


def receive(sock, n, timeout=20):
    """Return up to *n* bytes read from *sock*.

    Waits at most *timeout* seconds for *sock* to become readable and
    raises RuntimeError if it does not.
    """
    readable, _, _ = _real_select([sock], [], [], timeout)
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
 | |
| 
 | |
if HAVE_UNIX_SOCKETS:
    # Forking flavours of the Unix-domain servers, used by the
    # test_ForkingUnix*Server tests.
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        """socketserver.UnixStreamServer combined with ForkingMixIn."""

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        """socketserver.UnixDatagramServer combined with ForkingMixIn."""
 | |
| 
 | |
| 
 | |
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Tests that a custom child process is not waited on (Issue 1540386)

    Forks a child that immediately exits with status 72, then runs the
    ``with`` block.  On exit the parent reaps the child itself.  If the
    block completed normally, *testcase* asserts that waitpid() returned
    that child with exit status 72, i.e. nothing inside the block had
    already waited on it.

    The child is reaped even when the block raises, so a failing test does
    not leave a zombie process behind.  The status assertions are skipped
    in that case so they cannot mask the original exception.
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        # Always reap the child, whether or not the block succeeded.
        pid2, status = os.waitpid(pid, 0)
    testcase.assertEqual(pid2, pid)
    testcase.assertEqual(72 << 8, status)
 | |
| 
 | |
| 
 | |
@unittest.skipUnless(threading, 'Threading required for this test.')
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        self.test_files = []

    def tearDown(self):
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        # Remove the AF_UNIX socket paths handed out by pickaddr().  A path
        # may never have been bound, so a missing file is not an error.
        for fn in self.test_files:
            try:
                os.remove(fn)
            except OSError:
                pass
        self.test_files[:] = []

    def pickaddr(self, proto):
        """Return an address for a new socket of address family *proto*.

        AF_INET gets (HOST, 0) so the OS chooses a free port.  Any other
        family gets a temporary path name (used for AF_UNIX).  The path is
        recorded in self.test_files so tearDown() removes it.
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.
        fn = tempfile.mktemp(prefix='unix_socket.', dir=None)
        self.test_files.append(fn)
        return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create a line-echo server of class *svrcls* bound to *addr*.

        The handler, derived from *hdlrbase*, reads one line and writes it
        back.  The server's handle_error() closes the request and re-raises
        the exception currently being handled.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                self.close_request(request)
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose: print("creating server")
        server = MyServer(addr, MyHandler)
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve with *svrcls* in a thread and run *testfunc* against it.

        *testfunc(address_family, addr)* is called three times.  Afterwards
        the server is shut down and its socket must report fileno() == -1.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval':0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        if verbose: print("server running")
        try:
            for i in range(3):
                if verbose: print("test client", i)
                testfunc(svrcls.address_family, addr)
        finally:
            # Stop the serving thread and release the listening socket even
            # when a client check fails.  Otherwise a failure would leak the
            # socket and leave serve_forever() running.
            if verbose: print("waiting for server")
            server.shutdown()
            t.join()
            server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if verbose: print("done")

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and check it is echoed back."""
        # The with-statement closes the client socket even if the
        # assertion fails.
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram and check it is echoed back."""
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # Bind the client too, so the server has an address to
                # reply to.
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        # Like the other forking tests, also check that the server does not
        # reap a child process it did not create.
        with simple_subprocess(self):
            self.run_server(ForkingUnixDatagramServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval':0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this result
        # in FD exhaustion.
        for i in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        with socketserver.TCPServer((HOST, 0),
                                    socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())
 | |
| 
 | |
| 
 | |
class ErrorHandlerTest(unittest.TestCase):
    """Check which handler exceptions are routed to handle_error().

    Ordinary exceptions raised by the request handler must be passed to
    handle_error(); exiting exceptions such as SystemExit and
    KeyboardInterrupt must not be.
    """

    def check_result(self, handled):
        # The handler always logs its call.  handle_error() adds a second
        # line only when the exception was passed to it.
        expected = 'Handler called\n' + 'Error handled\n' * handled
        with open(test.support.TESTFN) as log:
            contents = log.read()
        self.assertEqual(contents, expected)

    def tearDown(self):
        test.support.unlink(test.support.TESTFN)

    # Plain (synchronous) server: SystemExit from the handler escapes
    # the server's constructor.
    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    # Threading server.
    @unittest.skipUnless(threading, 'Threading required for this test.')
    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @unittest.skipUnless(threading, 'Threading required for this test.')
    def test_threading_not_handled(self):
        ThreadingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    # Forking server.
    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)
 | |
| 
 | |
| 
 | |
class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server that serves exactly one request during construction.

    Instantiating it does the following, in order:
    - binds to (HOST, 0);
    - opens and closes one client connection;
    - dispatches that request to BadHandler, which raises *exception*;
    - closes the server;
    - calls wait_done() so subclasses can wait for an asynchronous
      handler to finish.
    """

    def __init__(self, exception):
        # BadHandler reads this attribute to choose what to raise.
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        try:
            # Queue one connection so handle_request() has a request to
            # process; the client side is closed immediately.
            with socket.create_connection(self.server_address):
                pass
            self.handle_request()
        finally:
            # Connecting is inside the try so that a failed connect cannot
            # leak the listening socket.
            self.server_close()
        self.wait_done()

    def handle_error(self, request, client_address):
        """Record in the log file that an error reached handle_error()."""
        with open(test.support.TESTFN, 'a') as log:
            log.write('Error handled\n')

    def wait_done(self):
        """Hook for subclasses; the base server has nothing to wait for."""
        pass
 | |
| 
 | |
| 
 | |
class BadHandler(socketserver.BaseRequestHandler):
    """Request handler that logs that it ran, then fails deliberately.

    The exception class raised is taken from ``self.server.exception``.
    """

    def handle(self):
        log_path = test.support.TESTFN
        with open(log_path, 'a') as log:
            log.write('Handler called\n')
        exc_type = self.server.exception
        raise exc_type('Test error')
 | |
| 
 | |
| 
 | |
class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    """Threaded error-test server that waits for its request to be shut down."""

    def __init__(self, *args, **kwargs):
        # The base constructor serves the request, so the event has to
        # exist before it runs.
        finished = threading.Event()
        self.done = finished
        super().__init__(*args, **kwargs)

    def shutdown_request(self, *args, **kwargs):
        parent = super()
        parent.shutdown_request(*args, **kwargs)
        # Release wait_done() now that the request has been shut down.
        self.done.set()

    def wait_done(self):
        finished = self.done
        finished.wait()
 | |
| 
 | |
| 
 | |
class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
    """Forking error-test server that reaps its single child itself."""

    def wait_done(self):
        # Exactly one child is expected; the unpacking fails otherwise.
        (child_pid,) = self.active_children
        os.waitpid(child_pid, 0)
        # Drop the reaped pid from the server's bookkeeping.
        self.active_children.clear()
 | |
| 
 | |
| 
 | |
| class MiscTestCase(unittest.TestCase):
 | |
| 
 | |
|     def test_all(self):
 | |
|         # objects defined in the module should be in __all__
 | |
|         expected = []
 | |
|         for name in dir(socketserver):
 | |
|             if not name.startswith('_'):
 | |
|                 mod_object = getattr(socketserver, name)
 | |
|                 if getattr(mod_object, '__module__', None) == 'socketserver':
 | |
|                     expected.append(name)
 | |
|         self.assertCountEqual(socketserver.__all__, expected)
 | |
| 
 | |
|     def test_shutdown_request_called_if_verify_request_false(self):
 | |
|         # Issue #26309: BaseServer should call shutdown_request even if
 | |
|         # verify_request is False
 | |
| 
 | |
|         class MyServer(socketserver.TCPServer):
 | |
|             def verify_request(self, request, client_address):
 | |
|                 return False
 | |
| 
 | |
|             shutdown_called = 0
 | |
|             def shutdown_request(self, request):
 | |
|                 self.shutdown_called += 1
 | |
|                 socketserver.TCPServer.shutdown_request(self, request)
 | |
| 
 | |
|         server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
 | |
|         s = socket.socket(server.address_family, socket.SOCK_STREAM)
 | |
|         s.connect(server.server_address)
 | |
|         s.close()
 | |
|         server.handle_request()
 | |
|         self.assertEqual(server.shutdown_called, 1)
 | |
|         server.server_close()
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
    # Run the test suite when this file is executed as a script.
    unittest.main()
 | 
