mirror of
				https://github.com/python/cpython.git
				synced 2025-10-31 13:41:24 +00:00 
			
		
		
		
	 f28896d0bb
			
		
	
	
		f28896d0bb
		
	
	
	
	
		
			
			calls (which don't exist on non-Unix platforms). Thanks to Trent Nelson for the report and patch.
		
			
				
	
	
		
			274 lines
		
	
	
	
		
			8.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			274 lines
		
	
	
	
		
			8.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| Test suite for SocketServer.py.
 | |
| """
 | |
| 
 | |
| import contextlib
 | |
| import errno
 | |
| import imp
 | |
| import os
 | |
| import select
 | |
| import signal
 | |
| import socket
 | |
| import tempfile
 | |
| import threading
 | |
| import time
 | |
| import unittest
 | |
| import SocketServer
 | |
| 
 | |
| import test.test_support
 | |
| from test.test_support import reap_children, verbose, TestSkipped
 | |
| from test.test_support import TESTFN as TEST_FILE
 | |
| 
 | |
| test.test_support.requires("network")
 | |
| 
 | |
# Number of requests every server handles and every client loop sends.
NREQ = 3
# Payload sent by the clients; newline-terminated because the test handler
# echoes back exactly one line read with readline().
TEST_STR = "hello world\n"
HOST = "localhost"

# Platform feature probes used to decide which test cases get defined.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
HAVE_FORKING = os.name != "os2" and hasattr(os, "fork")
 | |
| 
 | |
def signal_alarm(n):
    """Forward n to signal.alarm() on platforms that provide it.

    Where the signal module has no ``alarm`` attribute (i.e. on Windows)
    the call is silently skipped.
    """
    alarm = getattr(signal, 'alarm', None)
    if alarm is not None:
        alarm(n)
 | |
| 
 | |
def receive(sock, n, timeout=20):
    """Return up to n bytes read from sock, waiting at most timeout seconds.

    select() is used to wait until the socket is readable so that a server
    which never answers cannot block the test forever.

    Arguments:
        sock: the socket to read from.
        n: maximum number of bytes handed to sock.recv().
        timeout: seconds to wait for the socket to become readable.

    Raises:
        RuntimeError: if the socket is not readable within the timeout.
    """
    readable, _, _ = select.select([sock], [], [], timeout)
    if sock not in readable:
        # Call form of raise: valid on every Python version, unlike the
        # legacy "raise Class, value" statement.
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
 | |
| 
 | |
if HAVE_UNIX_SOCKETS:
    # Forking flavours of the AF_UNIX servers, assembled here from
    # ForkingMixIn and the corresponding SocketServer base class.
    class ForkingUnixStreamServer(SocketServer.ForkingMixIn,
                                  SocketServer.UnixStreamServer):
        """UnixStreamServer combined with ForkingMixIn."""

    class ForkingUnixDatagramServer(SocketServer.ForkingMixIn,
                                    SocketServer.UnixDatagramServer):
        """UnixDatagramServer combined with ForkingMixIn."""
 | |
| 
 | |
| 
 | |
class MyMixinServer:
    """Mixin giving a server a bounded serve loop and fail-fast errors.

    It is meant to be listed before a SocketServer class in the bases so
    that handle_error() here takes precedence.
    """

    def serve_a_few(self):
        """Call handle_request() NREQ times, then return."""
        for _ in range(NREQ):
            self.handle_request()

    def handle_error(self, request, client_address):
        """Close the request and the server, then re-raise the error.

        The bare ``raise`` re-raises the exception currently being
        handled, so this must be called while an exception is active.
        """
        self.close_request(request)
        self.server_close()
        raise
 | |
| 
 | |
| 
 | |
class ServerThread(threading.Thread):
    """Thread that creates a server and lets it serve a few requests.

    The server is built inside run(); once it is bound, the real address
    is published as ``addr``/``port`` and the ``ready`` event is set so
    the client side knows where to connect.
    """

    def __init__(self, addr, svrcls, hdlrcls):
        """Remember the requested address, server class and handler class."""
        threading.Thread.__init__(self)
        self.__addr, self.__svrcls, self.__hdlrcls = addr, svrcls, hdlrcls
        self.ready = threading.Event()

    def run(self):
        """Create the server, publish its address and serve the requests."""
        def log(message):
            # Progress output only when the test run is verbose.
            if verbose:
                print(message)

        # Put the mixin first so its handle_error() overrides the base one.
        class svrcls(MyMixinServer, self.__svrcls):
            pass

        log("thread: creating server")
        svr = svrcls(self.__addr, self.__hdlrcls)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        self.addr = svr.server_address
        self.port = self.addr[1]
        # Sanity check: the advertised address must match the bound socket.
        bound = svr.socket.getsockname()
        if self.addr != bound:
            raise RuntimeError('server_address was %s, expected %s' %
                               (self.addr, bound))
        self.ready.set()
        log("thread: serving three times")
        svr.serve_a_few()
        log("thread: done")
 | |
| 
 | |
| 
 | |
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Context manager that keeps a short-lived child alive around a block.

    A child process is forked and exits immediately with status 72.  After
    the ``with`` body finishes, the child is waited for and the test case
    is used to assert that this very child was collected here with that
    exit status.

    Arguments:
        testcase: object providing assertEqual(), normally the running
            unittest.TestCase.

    The child is waited for even when the body raises, so it is not left
    behind; the assertions only run when the body completed normally, so
    they cannot hide the body's own exception.
    """
    pid = os.fork()
    if pid == 0:
        # Don't throw an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        pid2, status = os.waitpid(pid, 0)
    testcase.assertEqual(pid2, pid)
    # waitpid() reports the exit code in the high byte of the status word.
    testcase.assertEqual(72 << 8, status)
 | |
| 
 | |
| 
 | |
class SocketServerTest(unittest.TestCase):
    """Test all socket servers.

    Every test starts a server in a ServerThread, sends it NREQ requests
    with a small client function and checks that each request is echoed
    back unchanged.
    """

    def setUp(self):
        signal_alarm(20)  # Kill deadlocks after 20 seconds.
        self.port_seed = 0
        # Paths of AF_UNIX socket files to delete in tearDown().
        self.test_files = []

    def tearDown(self):
        reap_children()

        for fn in self.test_files:
            try:
                os.remove(fn)
            except os.error:
                # Best effort: the file may never have been created.
                pass
        self.test_files[:] = []
        signal_alarm(0)  # Didn't deadlock.

    def pickaddr(self, proto):
        """Return an address to bind a server of address family proto to.

        For AF_INET this is (HOST, 0), letting the OS choose the port.
        For anything else a fresh temporary file name is returned and
        recorded in self.test_files for later removal.
        """
        if proto == socket.AF_INET:
            return (HOST, 0)

        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.
        on_os2 = os.name == 'os2'
        tmpdir = None
        if on_os2:
            # Raw string: the backslash is literal, not an escape sequence.
            tmpdir = r'\socket'
        # NOTE(review): mktemp() only returns a name and is race-prone; it
        # is kept because the server has to create the socket file itself.
        fn = tempfile.mktemp(prefix='unix_socket.', dir=tmpdir)
        if on_os2:
            # AF_UNIX socket names on OS/2 require a specific prefix
            # which can't include a drive letter and must also use
            # backslashes as directory separators
            if fn[1] == ':':
                fn = fn[2:]
            if fn[0] in (os.sep, os.altsep):
                fn = fn[1:]
            if os.sep == '/':
                fn = fn.replace(os.sep, os.altsep)
            else:
                fn = fn.replace(os.altsep, os.sep)
        self.test_files.append(fn)
        return fn

    def run_server(self, svrcls, hdlrbase, testfunc):
        """Start svrcls in a thread and run testfunc against it NREQ times.

        Arguments:
            svrcls: the server class under test.
            hdlrbase: request handler base class; a subclass that echoes
                one line back is what the server actually uses.
            testfunc: client callable taking (address_family, address).
        """
        class MyHandler(hdlrbase):
            def handle(self):
                # Echo a single line back to the client.
                line = self.rfile.readline()
                self.wfile.write(line)

        addr = self.pickaddr(svrcls.address_family)
        if verbose:
            print("ADDR = %s" % (addr,))
            print("CLASS = %s" % (svrcls,))
        t = ServerThread(addr, svrcls, MyHandler)
        if verbose:
            print("server created")
        t.start()
        if verbose:
            print("server running")
        t.ready.wait(10)
        self.assertTrue(t.ready.isSet(),
                        "%s not ready within a reasonable time" % svrcls)
        # Use the address the server really bound to (the OS chose the port).
        addr = t.addr
        for i in range(NREQ):
            if verbose:
                print("test client %d" % i)
            testfunc(svrcls.address_family, addr)
        if verbose:
            print("waiting for server")
        t.join()
        if verbose:
            print("done")

    def _check_echo(self, s):
        """Read from s until a newline arrives and compare with TEST_STR."""
        buf = data = receive(s, 100)
        while data and '\n' not in buf:
            data = receive(s, 100)
            buf += data
        self.assertEqual(buf, TEST_STR)

    def stream_examine(self, proto, addr):
        """Connect a stream socket to addr and check TEST_STR is echoed."""
        s = socket.socket(proto, socket.SOCK_STREAM)
        try:
            s.connect(addr)
            s.sendall(TEST_STR)
            self._check_echo(s)
        finally:
            # Close even when the check fails, so the socket is not leaked.
            s.close()

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram to addr and check it is echoed."""
        s = socket.socket(proto, socket.SOCK_DGRAM)
        try:
            s.sendto(TEST_STR, addr)
            self._check_echo(s)
        finally:
            # Close even when the check fails, so the socket is not leaked.
            s.close()

    def test_TCPServer(self):
        self.run_server(SocketServer.TCPServer,
                        SocketServer.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(SocketServer.ThreadingTCPServer,
                        SocketServer.StreamRequestHandler,
                        self.stream_examine)

    if HAVE_FORKING:
        def test_ForkingTCPServer(self):
            with simple_subprocess(self):
                self.run_server(SocketServer.ForkingTCPServer,
                                SocketServer.StreamRequestHandler,
                                self.stream_examine)

    if HAVE_UNIX_SOCKETS:
        def test_UnixStreamServer(self):
            self.run_server(SocketServer.UnixStreamServer,
                            SocketServer.StreamRequestHandler,
                            self.stream_examine)

        def test_ThreadingUnixStreamServer(self):
            self.run_server(SocketServer.ThreadingUnixStreamServer,
                            SocketServer.StreamRequestHandler,
                            self.stream_examine)

        if HAVE_FORKING:
            def test_ForkingUnixStreamServer(self):
                with simple_subprocess(self):
                    self.run_server(ForkingUnixStreamServer,
                                    SocketServer.StreamRequestHandler,
                                    self.stream_examine)

    def test_UDPServer(self):
        self.run_server(SocketServer.UDPServer,
                        SocketServer.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(SocketServer.ThreadingUDPServer,
                        SocketServer.DatagramRequestHandler,
                        self.dgram_examine)

    if HAVE_FORKING:
        def test_ForkingUDPServer(self):
            with simple_subprocess(self):
                self.run_server(SocketServer.ForkingUDPServer,
                                SocketServer.DatagramRequestHandler,
                                self.dgram_examine)

    # Alas, on Linux (at least) recvfrom() doesn't return a meaningful
    # client address so this cannot work:

    # if HAVE_UNIX_SOCKETS:
    #     def test_UnixDatagramServer(self):
    #         self.run_server(SocketServer.UnixDatagramServer,
    #                         SocketServer.DatagramRequestHandler,
    #                         self.dgram_examine)
    #
    #     def test_ThreadingUnixDatagramServer(self):
    #         self.run_server(SocketServer.ThreadingUnixDatagramServer,
    #                         SocketServer.DatagramRequestHandler,
    #                         self.dgram_examine)
    #
    #     if HAVE_FORKING:
    #         def test_ForkingUnixDatagramServer(self):
    #             self.run_server(SocketServer.ForkingUnixDatagramServer,
    #                             SocketServer.DatagramRequestHandler,
    #                             self.dgram_examine)
 | |
| 
 | |
| 
 | |
def test_main():
    """Run SocketServerTest, unless the import lock is currently held.

    Raises:
        TestSkipped: when imp.lock_held() reports that the import lock is
            held.
    """
    import_lock_taken = imp.lock_held()
    if import_lock_taken:
        # If the import lock is held, the threads will hang
        raise TestSkipped("can't run when import lock is held")

    test.test_support.run_unittest(SocketServerTest)
 | |
| 
 | |
if __name__ == "__main__":
    # Script entry point: run the whole suite first ...
    test_main()
    # ... then arm a final alarm (where the platform has one) so that a
    # hanging interpreter exit gets interrupted.
    signal_alarm(3)  # Shutdown shouldn't take more than 3 seconds.
 |