Refactored some of the server code

This commit is contained in:
ChaoticByte 2024-06-23 14:13:43 +02:00
parent 4e2e322066
commit ff4860f630
No known key found for this signature in database

View file

@ -8,7 +8,6 @@ import asyncio
from argparse import ArgumentParser
from getpass import getpass
from pathlib import Path
from sys import stdout
from sys import stderr
import asyncssh
@ -38,63 +37,50 @@ class SSHServer(asyncssh.SSHServer):
return False
def broadcast(msg: str, use_stderr: bool = False):
def broadcast(msg: str):
# Broadcast a message to all connected clients
assert type(msg) == str
msg = msg.strip("\r\n")
if use_stderr:
msg += "\r\n" # we need CRLF for stderr apparently
for c in connected_clients:
c.stderr.write(msg)
else:
msg += "\n"
for c in connected_clients:
c.stdout.write(msg)
msg += "\n"
if enable_logging:
stderr.write(msg)
for c in connected_clients:
c.stdout.write(msg)
def cleanup(process: asyncssh.SSHServerProcess, username: str):
disconnected_msg = f"[disconnected] {username}"
process.exit(0)
connected_clients.remove(process)
broadcast(disconnected_msg)
async def handle_connection(process: asyncssh.SSHServerProcess):
connected_clients.append(process)
username = process.get_extra_info("username")
try:
# hello there
connected_msg = f"[connected] {username}\n"
if enable_logging:
stderr.write(connected_msg)
broadcast(connected_msg, True)
if process.command is not None:
# client has provided a command as a ssh commandline argument
line = process.command.strip("\r\n")
msg = f"{username}: {line}\n"
if enable_logging:
stdout.write(msg)
broadcast(msg)
else:
# client wants an interactive session
while True:
try:
async for line in process.stdin:
if line == "": raise asyncssh.BreakReceived(0)
line = line.strip('\r\n')
msg = f"{username}: {line}\n"
if enable_logging:
stdout.write(msg)
broadcast(msg)
except asyncssh.TerminalSizeChanged:
continue # we don't want to exit when the client changes its terminal size lol
finally: # but otherwise, we do want to
break # exit this loop.
except asyncssh.BreakReceived:
pass # we don't want to write an error message on this exception
except Exception as e:
stderr.write(f"An error occured: {type(e).__name__} {e}\n")
stderr.flush()
finally:
# exit process, remove client from list, inform other clients
disconnected_msg = f"[disconnected] {username}\n"
process.exit(0)
connected_clients.remove(process)
if enable_logging:
stderr.write(disconnected_msg)
broadcast(disconnected_msg, True)
# hello there
connected_msg = f"[connected] {username}"
broadcast(connected_msg)
if process.command is not None:
# client has provided a command as a ssh commandline argument
line = process.command.strip("\r\n")
msg = f"{username}: {line}\n"
broadcast(msg)
cleanup(process, username)
else:
async def listen():
try:
async for line in process.stdin:
line = line.strip('\r\n')
msg = f"{username}: {line}"
broadcast(msg)
except asyncssh.TerminalSizeChanged:
await listen() # we don't want to exit yet.
except asyncssh.BreakReceived:
pass # we don't want to write an error message on this exception
except Exception as e:
stderr.write(f"An error occured: {type(e).__name__} {e}\n")
stderr.flush()
await listen()
cleanup(process, username)
if __name__ == "__main__":
@ -102,7 +88,7 @@ if __name__ == "__main__":
argp = ArgumentParser()
argp.add_argument("config", type=Path, help="The path to the config file")
argp.add_argument("pkey", type=Path, help="The path to the ssh private key")
argp.add_argument("--log", action="store_true", help="Enable logging to stdout and stderr")
argp.add_argument("--log", action="store_true", help="Enable logging")
args = argp.parse_args()
# read config
config = yaml.safe_load(args.config.read_text())