diff --git a/asserver.py b/asserver.py index 2901f91..6a15d1d 100755 --- a/asserver.py +++ b/asserver.py @@ -31,16 +31,17 @@ class SSHServer(asyncssh.SSHServer): def validate_public_key(self, username: str, key: asyncssh.SSHKey) -> bool: try: - return config_clients[username].validate(key, "", "") is not None + return config_clients[username].validate(key, "", "") is not None # checks client key except: return False def broadcast(msg: str, use_stderr: bool = False): + # Broadcast a message to all connected clients assert type(msg) == str msg = msg.strip("\r\n") if use_stderr: - msg += "\r\n" + msg += "\r\n" # we need CRLF for stderr apparently for c in connected_clients: c.stderr.write(msg) else: @@ -52,15 +53,18 @@ async def handle_connection(process: asyncssh.SSHServerProcess): connected_clients.append(process) username = process.get_extra_info("username") try: + # hello there connected_msg = f"[connected] {username}\n" stderr.write(connected_msg) broadcast(connected_msg, True) if process.command is not None: + # client has provided a command as a ssh commandline argument line = process.command.strip("\r\n") msg = f"{username}: {line}\n" stdout.write(msg) broadcast(msg) else: + # client wants an interactive session while True: try: async for line in process.stdin: @@ -70,18 +74,19 @@ async def handle_connection(process: asyncssh.SSHServerProcess): stdout.write(msg) broadcast(msg) except asyncssh.TerminalSizeChanged: - continue - finally: - break + continue # we don't want to exit when the client changes its terminal size lol + finally: # but otherwise, we do want to + break # exit this loop. except asyncssh.BreakReceived: - pass + pass # we don't want to write an error message on this exception except Exception as e: stderr.write(f"An error occured: {type(e).__name__} {e}\n") stderr.flush() finally: + # exit process, remove client from list, inform other clients + disconnected_msg = f"[disconnected] {username}\n" process.exit(0) connected_clients.remove(process) - disconnected_msg = f"[disconnected] {username}\n" stderr.write(disconnected_msg) broadcast(disconnected_msg, True) @@ -97,11 +102,12 @@ if __name__ == "__main__": config_host = str(config["host"]) config_port = int(config["port"]) config_private_key = asyncssh.import_private_key(args.pkey.read_text()) + for c in config["clients"]: + config_clients[str(c)] = asyncssh.import_authorized_keys(str(config["clients"][c])) + # read private key server_public_key = config_private_key.export_public_key("openssh").decode().strip("\n\r") stderr.write(f"Server public key is \"{server_public_key}\"\n") stderr.flush() - for c in config["clients"]: - config_clients[str(c)] = asyncssh.import_authorized_keys(str(config["clients"][c])) # start server loop = asyncio.get_event_loop() loop.run_until_complete(