Refactored some of the server code

This commit is contained in:
ChaoticByte 2024-06-23 14:13:43 +02:00
parent 4e2e322066
commit ff4860f630
No known key found for this signature in database

View file

@ -8,7 +8,6 @@ import asyncio
from argparse import ArgumentParser from argparse import ArgumentParser
from getpass import getpass from getpass import getpass
from pathlib import Path from pathlib import Path
from sys import stdout
from sys import stderr from sys import stderr
import asyncssh import asyncssh
@ -38,63 +37,50 @@ class SSHServer(asyncssh.SSHServer):
return False return False
def broadcast(msg: str, use_stderr: bool = False): def broadcast(msg: str):
# Broadcast a message to all connected clients # Broadcast a message to all connected clients
assert type(msg) == str assert type(msg) == str
msg = msg.strip("\r\n") msg = msg.strip("\r\n")
if use_stderr:
msg += "\r\n" # we need CRLF for stderr apparently
for c in connected_clients:
c.stderr.write(msg)
else:
msg += "\n" msg += "\n"
if enable_logging:
stderr.write(msg)
for c in connected_clients: for c in connected_clients:
c.stdout.write(msg) c.stdout.write(msg)
def cleanup(process: asyncssh.SSHServerProcess, username: str):
disconnected_msg = f"[disconnected] {username}"
process.exit(0)
connected_clients.remove(process)
broadcast(disconnected_msg)
async def handle_connection(process: asyncssh.SSHServerProcess): async def handle_connection(process: asyncssh.SSHServerProcess):
connected_clients.append(process) connected_clients.append(process)
username = process.get_extra_info("username") username = process.get_extra_info("username")
try:
# hello there # hello there
connected_msg = f"[connected] {username}\n" connected_msg = f"[connected] {username}"
if enable_logging: broadcast(connected_msg)
stderr.write(connected_msg)
broadcast(connected_msg, True)
if process.command is not None: if process.command is not None:
# client has provided a command as a ssh commandline argument # client has provided a command as a ssh commandline argument
line = process.command.strip("\r\n") line = process.command.strip("\r\n")
msg = f"{username}: {line}\n" msg = f"{username}: {line}\n"
if enable_logging:
stdout.write(msg)
broadcast(msg) broadcast(msg)
cleanup(process, username)
else: else:
# client wants an interactive session async def listen():
while True:
try: try:
async for line in process.stdin: async for line in process.stdin:
if line == "": raise asyncssh.BreakReceived(0)
line = line.strip('\r\n') line = line.strip('\r\n')
msg = f"{username}: {line}\n" msg = f"{username}: {line}"
if enable_logging:
stdout.write(msg)
broadcast(msg) broadcast(msg)
except asyncssh.TerminalSizeChanged: except asyncssh.TerminalSizeChanged:
continue # we don't want to exit when the client changes its terminal size lol await listen() # we don't want to exit yet.
finally: # but otherwise, we do want to
break # exit this loop.
except asyncssh.BreakReceived: except asyncssh.BreakReceived:
pass # we don't want to write an error message on this exception pass # we don't want to write an error message on this exception
except Exception as e: except Exception as e:
stderr.write(f"An error occured: {type(e).__name__} {e}\n") stderr.write(f"An error occured: {type(e).__name__} {e}\n")
stderr.flush() stderr.flush()
finally: await listen()
# exit process, remove client from list, inform other clients cleanup(process, username)
disconnected_msg = f"[disconnected] {username}\n"
process.exit(0)
connected_clients.remove(process)
if enable_logging:
stderr.write(disconnected_msg)
broadcast(disconnected_msg, True)
if __name__ == "__main__": if __name__ == "__main__":
@ -102,7 +88,7 @@ if __name__ == "__main__":
argp = ArgumentParser() argp = ArgumentParser()
argp.add_argument("config", type=Path, help="The path to the config file") argp.add_argument("config", type=Path, help="The path to the config file")
argp.add_argument("pkey", type=Path, help="The path to the ssh private key") argp.add_argument("pkey", type=Path, help="The path to the ssh private key")
argp.add_argument("--log", action="store_true", help="Enable logging to stdout and stderr") argp.add_argument("--log", action="store_true", help="Enable logging")
args = argp.parse_args() args = argp.parse_args()
# read config # read config
config = yaml.safe_load(args.config.read_text()) config = yaml.safe_load(args.config.read_text())