| 
									
										
										
										
											2001-07-10 11:52:38 +00:00
										 |  |  | # Test suite for SocketServer.py | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2002-07-23 19:04:11 +00:00
										 |  |  | from test import test_support | 
					
						
							|  |  |  | from test.test_support import verbose, verify, TESTFN, TestSkipped | 
					
						
							| 
									
										
										
										
											2001-09-18 02:18:57 +00:00
# Gate the whole module on the 'network' test resource.
# NOTE(review): presumably test_support.requires() raises a skip/denial
# exception at import time when that resource is not enabled — confirm
# against test_support.
test_support.requires('network')
					
						
							| 
									
										
										
										
											2001-07-10 11:52:38 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | from SocketServer import * | 
					
						
							|  |  |  | import socket | 
					
						
							|  |  |  | import select | 
					
						
							|  |  |  | import time | 
					
						
							|  |  |  | import threading | 
					
						
							|  |  |  | import os | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | NREQ = 3 | 
					
						
							|  |  |  | DELAY = 0.5 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class MyMixinHandler: | 
					
						
							|  |  |  |     def handle(self): | 
					
						
							|  |  |  |         time.sleep(DELAY) | 
					
						
							|  |  |  |         line = self.rfile.readline() | 
					
						
							|  |  |  |         time.sleep(DELAY) | 
					
						
							|  |  |  |         self.wfile.write(line) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class MyStreamHandler(MyMixinHandler, StreamRequestHandler): | 
					
						
							|  |  |  |     pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class MyDatagramHandler(MyMixinHandler, DatagramRequestHandler): | 
					
						
							|  |  |  |     pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class MyMixinServer: | 
					
						
							|  |  |  |     def serve_a_few(self): | 
					
						
							|  |  |  |         for i in range(NREQ): | 
					
						
							|  |  |  |             self.handle_request() | 
					
						
							|  |  |  |     def handle_error(self, request, client_address): | 
					
						
							|  |  |  |         self.close_request(request) | 
					
						
							|  |  |  |         self.server_close() | 
					
						
							|  |  |  |         raise | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | teststring = "hello world\n" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def receive(sock, n, timeout=20): | 
					
						
							|  |  |  |     r, w, x = select.select([sock], [], [], timeout) | 
					
						
							|  |  |  |     if sock in r: | 
					
						
							|  |  |  |         return sock.recv(n) | 
					
						
							|  |  |  |     else: | 
					
						
							| 
									
										
										
										
											2004-02-12 17:35:32 +00:00
										 |  |  |         raise RuntimeError, "timed out on %r" % (sock,) | 
					
						
							| 
									
										
										
										
											2001-07-10 11:52:38 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | def testdgram(proto, addr): | 
					
						
							|  |  |  |     s = socket.socket(proto, socket.SOCK_DGRAM) | 
					
						
							|  |  |  |     s.sendto(teststring, addr) | 
					
						
							|  |  |  |     buf = data = receive(s, 100) | 
					
						
							|  |  |  |     while data and '\n' not in buf: | 
					
						
							|  |  |  |         data = receive(s, 100) | 
					
						
							|  |  |  |         buf += data | 
					
						
							|  |  |  |     verify(buf == teststring) | 
					
						
							|  |  |  |     s.close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def teststream(proto, addr): | 
					
						
							|  |  |  |     s = socket.socket(proto, socket.SOCK_STREAM) | 
					
						
							|  |  |  |     s.connect(addr) | 
					
						
							| 
									
										
										
										
											2001-10-29 07:18:02 +00:00
										 |  |  |     s.sendall(teststring) | 
					
						
							| 
									
										
										
										
											2001-07-10 11:52:38 +00:00
										 |  |  |     buf = data = receive(s, 100) | 
					
						
							|  |  |  |     while data and '\n' not in buf: | 
					
						
							|  |  |  |         data = receive(s, 100) | 
					
						
							|  |  |  |         buf += data | 
					
						
							|  |  |  |     verify(buf == teststring) | 
					
						
							|  |  |  |     s.close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ServerThread(threading.Thread): | 
					
						
							|  |  |  |     def __init__(self, addr, svrcls, hdlrcls): | 
					
						
							|  |  |  |         threading.Thread.__init__(self) | 
					
						
							|  |  |  |         self.__addr = addr | 
					
						
							|  |  |  |         self.__svrcls = svrcls | 
					
						
							|  |  |  |         self.__hdlrcls = hdlrcls | 
					
						
							|  |  |  |     def run(self): | 
					
						
							|  |  |  |         class svrcls(MyMixinServer, self.__svrcls): | 
					
						
							|  |  |  |             pass | 
					
						
							|  |  |  |         if verbose: print "thread: creating server" | 
					
						
							|  |  |  |         svr = svrcls(self.__addr, self.__hdlrcls) | 
					
						
							|  |  |  |         if verbose: print "thread: serving three times" | 
					
						
							|  |  |  |         svr.serve_a_few() | 
					
						
							|  |  |  |         if verbose: print "thread: done" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | seed = 0 | 
					
						
							|  |  |  | def pickport(): | 
					
						
							|  |  |  |     global seed | 
					
						
							|  |  |  |     seed += 1 | 
					
						
							|  |  |  |     return 10000 + (os.getpid() % 1000)*10 + seed | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2001-07-10 15:46:34 +00:00
										 |  |  | host = "localhost" | 
					
						
							| 
									
										
										
										
											2001-07-10 11:52:38 +00:00
										 |  |  | testfiles = [] | 
					
						
							|  |  |  | def pickaddr(proto): | 
					
						
							|  |  |  |     if proto == socket.AF_INET: | 
					
						
							|  |  |  |         return (host, pickport()) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         fn = TESTFN + str(pickport()) | 
					
						
							| 
									
										
										
										
											2004-04-11 12:03:57 +00:00
										 |  |  |         if os.name == 'os2': | 
					
						
							|  |  |  |             # AF_UNIX socket names on OS/2 require a specific prefix | 
					
						
							|  |  |  |             # which can't include a drive letter and must also use | 
					
						
							|  |  |  |             # backslashes as directory separators | 
					
						
							|  |  |  |             if fn[1] == ':': | 
					
						
							|  |  |  |                 fn = fn[2:] | 
					
						
							|  |  |  |             if fn[0] in (os.sep, os.altsep): | 
					
						
							|  |  |  |                 fn = fn[1:] | 
					
						
							|  |  |  |             fn = os.path.join('\socket', fn) | 
					
						
							|  |  |  |             if os.sep == '/': | 
					
						
							|  |  |  |                 fn = fn.replace(os.sep, os.altsep) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 fn = fn.replace(os.altsep, os.sep) | 
					
						
							| 
									
										
										
										
											2001-07-10 11:52:38 +00:00
										 |  |  |         testfiles.append(fn) | 
					
						
							|  |  |  |         return fn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def cleanup(): | 
					
						
							|  |  |  |     for fn in testfiles: | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             os.remove(fn) | 
					
						
							|  |  |  |         except os.error: | 
					
						
							|  |  |  |             pass | 
					
						
							|  |  |  |     testfiles[:] = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def testloop(proto, servers, hdlrcls, testfunc): | 
					
						
							|  |  |  |     for svrcls in servers: | 
					
						
							|  |  |  |         addr = pickaddr(proto) | 
					
						
							|  |  |  |         if verbose: | 
					
						
							|  |  |  |             print "ADDR =", addr | 
					
						
							|  |  |  |             print "CLASS =", svrcls | 
					
						
							|  |  |  |         t = ServerThread(addr, svrcls, hdlrcls) | 
					
						
							|  |  |  |         if verbose: print "server created" | 
					
						
							|  |  |  |         t.start() | 
					
						
							|  |  |  |         if verbose: print "server running" | 
					
						
							|  |  |  |         for i in range(NREQ): | 
					
						
							|  |  |  |             time.sleep(DELAY) | 
					
						
							|  |  |  |             if verbose: print "test client", i | 
					
						
							|  |  |  |             testfunc(proto, addr) | 
					
						
							|  |  |  |         if verbose: print "waiting for server" | 
					
						
							|  |  |  |         t.join() | 
					
						
							|  |  |  |         if verbose: print "done" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | tcpservers = [TCPServer, ThreadingTCPServer] | 
					
						
							| 
									
										
										
										
											2003-01-02 12:49:00 +00:00
										 |  |  | if hasattr(os, 'fork') and os.name not in ('os2',): | 
					
						
							| 
									
										
										
										
											2001-07-10 11:52:38 +00:00
										 |  |  |     tcpservers.append(ForkingTCPServer) | 
					
						
							|  |  |  | udpservers = [UDPServer, ThreadingUDPServer] | 
					
						
							| 
									
										
										
										
											2003-01-02 12:49:00 +00:00
										 |  |  | if hasattr(os, 'fork') and os.name not in ('os2',): | 
					
						
							| 
									
										
										
										
											2001-07-10 11:52:38 +00:00
										 |  |  |     udpservers.append(ForkingUDPServer) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if not hasattr(socket, 'AF_UNIX'): | 
					
						
							|  |  |  |     streamservers = [] | 
					
						
							|  |  |  |     dgramservers = [] | 
					
						
							|  |  |  | else: | 
					
						
							|  |  |  |     class ForkingUnixStreamServer(ForkingMixIn, UnixStreamServer): pass | 
					
						
							| 
									
										
										
										
											2004-04-11 12:03:57 +00:00
										 |  |  |     streamservers = [UnixStreamServer, ThreadingUnixStreamServer] | 
					
						
							|  |  |  |     if hasattr(os, 'fork') and os.name not in ('os2',): | 
					
						
							|  |  |  |         streamservers.append(ForkingUnixStreamServer) | 
					
						
							| 
									
										
										
										
											2001-07-10 11:52:38 +00:00
										 |  |  |     class ForkingUnixDatagramServer(ForkingMixIn, UnixDatagramServer): pass | 
					
						
							| 
									
										
										
										
											2004-04-11 12:03:57 +00:00
										 |  |  |     dgramservers = [UnixDatagramServer, ThreadingUnixDatagramServer] | 
					
						
							|  |  |  |     if hasattr(os, 'fork') and os.name not in ('os2',): | 
					
						
							|  |  |  |         dgramservers.append(ForkingUnixDatagramServer) | 
					
						
							| 
									
										
										
										
											2001-07-10 11:52:38 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | def testall(): | 
					
						
							|  |  |  |     testloop(socket.AF_INET, tcpservers, MyStreamHandler, teststream) | 
					
						
							|  |  |  |     testloop(socket.AF_INET, udpservers, MyDatagramHandler, testdgram) | 
					
						
							| 
									
										
										
										
											2001-07-10 15:46:34 +00:00
										 |  |  |     if hasattr(socket, 'AF_UNIX'): | 
					
						
							|  |  |  |         testloop(socket.AF_UNIX, streamservers, MyStreamHandler, teststream) | 
					
						
							|  |  |  |         # Alas, on Linux (at least) recvfrom() doesn't return a meaningful | 
					
						
							|  |  |  |         # client address so this cannot work: | 
					
						
							|  |  |  |         ##testloop(socket.AF_UNIX, dgramservers, MyDatagramHandler, testdgram) | 
					
						
							| 
									
										
										
										
											2001-07-10 11:52:38 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2001-09-17 23:56:20 +00:00
										 |  |  | def test_main(): | 
					
						
							|  |  |  |     import imp | 
					
						
							|  |  |  |     if imp.lock_held(): | 
					
						
							|  |  |  |         # If the import lock is held, the threads will hang. | 
					
						
							|  |  |  |         raise TestSkipped("can't run when import lock is held") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2001-07-10 11:52:38 +00:00
										 |  |  |     try: | 
					
						
							|  |  |  |         testall() | 
					
						
							|  |  |  |     finally: | 
					
						
							|  |  |  |         cleanup() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2001-09-17 23:56:20 +00:00
										 |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     test_main() |