diff --git a/asserver.py b/asserver.py index 48d5a66..c1318b7 100755 --- a/asserver.py +++ b/asserver.py @@ -8,7 +8,6 @@ import asyncio from argparse import ArgumentParser from getpass import getpass from pathlib import Path -from sys import stdout from sys import stderr import asyncssh @@ -38,63 +37,50 @@ class SSHServer(asyncssh.SSHServer): return False -def broadcast(msg: str, use_stderr: bool = False): +def broadcast(msg: str): # Broadcast a message to all connected clients assert type(msg) == str msg = msg.strip("\r\n") - if use_stderr: - msg += "\r\n" # we need CRLF for stderr apparently - for c in connected_clients: - c.stderr.write(msg) - else: - msg += "\n" - for c in connected_clients: - c.stdout.write(msg) + msg += "\n" + if enable_logging: + stderr.write(msg) + for c in connected_clients: + c.stdout.write(msg) + +def cleanup(process: asyncssh.SSHServerProcess, username: str): + disconnected_msg = f"[disconnected] {username}" + process.exit(0) + connected_clients.remove(process) + broadcast(disconnected_msg) async def handle_connection(process: asyncssh.SSHServerProcess): connected_clients.append(process) username = process.get_extra_info("username") - try: - # hello there - connected_msg = f"[connected] {username}\n" - if enable_logging: - stderr.write(connected_msg) - broadcast(connected_msg, True) - if process.command is not None: - # client has provided a command as a ssh commandline argument - line = process.command.strip("\r\n") - msg = f"{username}: {line}\n" - if enable_logging: - stdout.write(msg) - broadcast(msg) - else: - # client wants an interactive session - while True: - try: - async for line in process.stdin: - if line == "": raise asyncssh.BreakReceived(0) - line = line.strip('\r\n') - msg = f"{username}: {line}\n" - if enable_logging: - stdout.write(msg) - broadcast(msg) - except asyncssh.TerminalSizeChanged: - continue # we don't want to exit when the client changes its terminal size lol - finally: # but otherwise, we do want to - break # exit this loop. - except asyncssh.BreakReceived: - pass # we don't want to write an error message on this exception - except Exception as e: - stderr.write(f"An error occured: {type(e).__name__} {e}\n") - stderr.flush() - finally: - # exit process, remove client from list, inform other clients - disconnected_msg = f"[disconnected] {username}\n" - process.exit(0) - connected_clients.remove(process) - if enable_logging: - stderr.write(disconnected_msg) - broadcast(disconnected_msg, True) + # hello there + connected_msg = f"[connected] {username}" + broadcast(connected_msg) + if process.command is not None: + # client has provided a command as a ssh commandline argument + line = process.command.strip("\r\n") + msg = f"{username}: {line}\n" + broadcast(msg) + cleanup(process, username) + else: + async def listen(): + try: + async for line in process.stdin: + line = line.strip('\r\n') + msg = f"{username}: {line}" + broadcast(msg) + except asyncssh.TerminalSizeChanged: + await listen() # we don't want to exit yet. + except asyncssh.BreakReceived: + pass # we don't want to write an error message on this exception + except Exception as e: + stderr.write(f"An error occured: {type(e).__name__} {e}\n") + stderr.flush() + await listen() + cleanup(process, username) if __name__ == "__main__": @@ -102,7 +88,7 @@ if __name__ == "__main__": argp = ArgumentParser() argp.add_argument("config", type=Path, help="The path to the config file") argp.add_argument("pkey", type=Path, help="The path to the ssh private key") - argp.add_argument("--log", action="store_true", help="Enable logging to stdout and stderr") + argp.add_argument("--log", action="store_true", help="Enable logging") args = argp.parse_args() # read config config = yaml.safe_load(args.config.read_text())