Refactored some of the server code

This commit is contained in:
ChaoticByte 2024-06-23 14:13:43 +02:00
parent 4e2e322066
commit ff4860f630
No known key found for this signature in database

View file

@ -8,7 +8,6 @@ import asyncio
from argparse import ArgumentParser from argparse import ArgumentParser
from getpass import getpass from getpass import getpass
from pathlib import Path from pathlib import Path
from sys import stdout
from sys import stderr from sys import stderr
import asyncssh import asyncssh
@ -38,63 +37,50 @@ class SSHServer(asyncssh.SSHServer):
return False return False
def broadcast(msg: str, use_stderr: bool = False): def broadcast(msg: str):
# Broadcast a message to all connected clients # Broadcast a message to all connected clients
assert type(msg) == str assert type(msg) == str
msg = msg.strip("\r\n") msg = msg.strip("\r\n")
if use_stderr: msg += "\n"
msg += "\r\n" # we need CRLF for stderr apparently if enable_logging:
for c in connected_clients: stderr.write(msg)
c.stderr.write(msg) for c in connected_clients:
else: c.stdout.write(msg)
msg += "\n"
for c in connected_clients: def cleanup(process: asyncssh.SSHServerProcess, username: str):
c.stdout.write(msg) disconnected_msg = f"[disconnected] {username}"
process.exit(0)
connected_clients.remove(process)
broadcast(disconnected_msg)
async def handle_connection(process: asyncssh.SSHServerProcess): async def handle_connection(process: asyncssh.SSHServerProcess):
connected_clients.append(process) connected_clients.append(process)
username = process.get_extra_info("username") username = process.get_extra_info("username")
try: # hello there
# hello there connected_msg = f"[connected] {username}"
connected_msg = f"[connected] {username}\n" broadcast(connected_msg)
if enable_logging: if process.command is not None:
stderr.write(connected_msg) # client has provided a command as a ssh commandline argument
broadcast(connected_msg, True) line = process.command.strip("\r\n")
if process.command is not None: msg = f"{username}: {line}\n"
# client has provided a command as a ssh commandline argument broadcast(msg)
line = process.command.strip("\r\n") cleanup(process, username)
msg = f"{username}: {line}\n" else:
if enable_logging: async def listen():
stdout.write(msg) try:
broadcast(msg) async for line in process.stdin:
else: line = line.strip('\r\n')
# client wants an interactive session msg = f"{username}: {line}"
while True: broadcast(msg)
try: except asyncssh.TerminalSizeChanged:
async for line in process.stdin: await listen() # we don't want to exit yet.
if line == "": raise asyncssh.BreakReceived(0) except asyncssh.BreakReceived:
line = line.strip('\r\n') pass # we don't want to write an error message on this exception
msg = f"{username}: {line}\n" except Exception as e:
if enable_logging: stderr.write(f"An error occured: {type(e).__name__} {e}\n")
stdout.write(msg) stderr.flush()
broadcast(msg) await listen()
except asyncssh.TerminalSizeChanged: cleanup(process, username)
continue # we don't want to exit when the client changes its terminal size lol
finally: # but otherwise, we do want to
break # exit this loop.
except asyncssh.BreakReceived:
pass # we don't want to write an error message on this exception
except Exception as e:
stderr.write(f"An error occured: {type(e).__name__} {e}\n")
stderr.flush()
finally:
# exit process, remove client from list, inform other clients
disconnected_msg = f"[disconnected] {username}\n"
process.exit(0)
connected_clients.remove(process)
if enable_logging:
stderr.write(disconnected_msg)
broadcast(disconnected_msg, True)
if __name__ == "__main__": if __name__ == "__main__":
@ -102,7 +88,7 @@ if __name__ == "__main__":
argp = ArgumentParser() argp = ArgumentParser()
argp.add_argument("config", type=Path, help="The path to the config file") argp.add_argument("config", type=Path, help="The path to the config file")
argp.add_argument("pkey", type=Path, help="The path to the ssh private key") argp.add_argument("pkey", type=Path, help="The path to the ssh private key")
argp.add_argument("--log", action="store_true", help="Enable logging to stdout and stderr") argp.add_argument("--log", action="store_true", help="Enable logging")
args = argp.parse_args() args = argp.parse_args()
# read config # read config
config = yaml.safe_load(args.config.read_text()) config = yaml.safe_load(args.config.read_text())