Added more comments and moved the read-privatekey-part after the read-config-part

This commit is contained in:
ChaoticByte 2024-06-16 09:07:57 +02:00
parent aca4f8c0be
commit 8193410226
No known key found for this signature in database

View file

@ -31,16 +31,17 @@ class SSHServer(asyncssh.SSHServer):
def validate_public_key(self, username: str, key: asyncssh.SSHKey) -> bool: def validate_public_key(self, username: str, key: asyncssh.SSHKey) -> bool:
try: try:
return config_clients[username].validate(key, "", "") is not None return config_clients[username].validate(key, "", "") is not None # checks client key
except: except:
return False return False
def broadcast(msg: str, use_stderr: bool = False): def broadcast(msg: str, use_stderr: bool = False):
# Broadcast a message to all connected clients
assert type(msg) == str assert type(msg) == str
msg = msg.strip("\r\n") msg = msg.strip("\r\n")
if use_stderr: if use_stderr:
msg += "\r\n" msg += "\r\n" # we need CRLF for stderr apparently
for c in connected_clients: for c in connected_clients:
c.stderr.write(msg) c.stderr.write(msg)
else: else:
@ -52,15 +53,18 @@ async def handle_connection(process: asyncssh.SSHServerProcess):
connected_clients.append(process) connected_clients.append(process)
username = process.get_extra_info("username") username = process.get_extra_info("username")
try: try:
# hello there
connected_msg = f"[connected] {username}\n" connected_msg = f"[connected] {username}\n"
stderr.write(connected_msg) stderr.write(connected_msg)
broadcast(connected_msg, True) broadcast(connected_msg, True)
if process.command is not None: if process.command is not None:
# client has provided a command as a ssh commandline argument
line = process.command.strip("\r\n") line = process.command.strip("\r\n")
msg = f"{username}: {line}\n" msg = f"{username}: {line}\n"
stdout.write(msg) stdout.write(msg)
broadcast(msg) broadcast(msg)
else: else:
# client wants an interactive session
while True: while True:
try: try:
async for line in process.stdin: async for line in process.stdin:
@ -70,18 +74,19 @@ async def handle_connection(process: asyncssh.SSHServerProcess):
stdout.write(msg) stdout.write(msg)
broadcast(msg) broadcast(msg)
except asyncssh.TerminalSizeChanged: except asyncssh.TerminalSizeChanged:
continue continue # we don't want to exit when the client changes its terminal size lol
finally: finally: # but otherwise, we do want to
break break # exit this loop.
except asyncssh.BreakReceived: except asyncssh.BreakReceived:
pass pass # we don't want to write an error message on this exception
except Exception as e: except Exception as e:
stderr.write(f"An error occured: {type(e).__name__} {e}\n") stderr.write(f"An error occured: {type(e).__name__} {e}\n")
stderr.flush() stderr.flush()
finally: finally:
# exit process, remove client from list, inform other clients
disconnected_msg = f"[disconnected] {username}\n"
process.exit(0) process.exit(0)
connected_clients.remove(process) connected_clients.remove(process)
disconnected_msg = f"[disconnected] {username}\n"
stderr.write(disconnected_msg) stderr.write(disconnected_msg)
broadcast(disconnected_msg, True) broadcast(disconnected_msg, True)
@ -97,11 +102,12 @@ if __name__ == "__main__":
config_host = str(config["host"]) config_host = str(config["host"])
config_port = int(config["port"]) config_port = int(config["port"])
config_private_key = asyncssh.import_private_key(args.pkey.read_text()) config_private_key = asyncssh.import_private_key(args.pkey.read_text())
for c in config["clients"]:
config_clients[str(c)] = asyncssh.import_authorized_keys(str(config["clients"][c]))
# read private key
server_public_key = config_private_key.export_public_key("openssh").decode().strip("\n\r") server_public_key = config_private_key.export_public_key("openssh").decode().strip("\n\r")
stderr.write(f"Server public key is \"{server_public_key}\"\n") stderr.write(f"Server public key is \"{server_public_key}\"\n")
stderr.flush() stderr.flush()
for c in config["clients"]:
config_clients[str(c)] = asyncssh.import_authorized_keys(str(config["clients"][c]))
# start server # start server
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete( loop.run_until_complete(