Refactored some of the server code
This commit is contained in:
parent
4e2e322066
commit
ff4860f630
1 changed files with 38 additions and 52 deletions
50
asserver.py
50
asserver.py
|
@ -8,7 +8,6 @@ import asyncio
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from getpass import getpass
|
from getpass import getpass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from sys import stdout
|
|
||||||
from sys import stderr
|
from sys import stderr
|
||||||
|
|
||||||
import asyncssh
|
import asyncssh
|
||||||
|
@ -38,63 +37,50 @@ class SSHServer(asyncssh.SSHServer):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def broadcast(msg: str, use_stderr: bool = False):
|
def broadcast(msg: str):
|
||||||
# Broadcast a message to all connected clients
|
# Broadcast a message to all connected clients
|
||||||
assert type(msg) == str
|
assert type(msg) == str
|
||||||
msg = msg.strip("\r\n")
|
msg = msg.strip("\r\n")
|
||||||
if use_stderr:
|
|
||||||
msg += "\r\n" # we need CRLF for stderr apparently
|
|
||||||
for c in connected_clients:
|
|
||||||
c.stderr.write(msg)
|
|
||||||
else:
|
|
||||||
msg += "\n"
|
msg += "\n"
|
||||||
|
if enable_logging:
|
||||||
|
stderr.write(msg)
|
||||||
for c in connected_clients:
|
for c in connected_clients:
|
||||||
c.stdout.write(msg)
|
c.stdout.write(msg)
|
||||||
|
|
||||||
|
def cleanup(process: asyncssh.SSHServerProcess, username: str):
|
||||||
|
disconnected_msg = f"[disconnected] {username}"
|
||||||
|
process.exit(0)
|
||||||
|
connected_clients.remove(process)
|
||||||
|
broadcast(disconnected_msg)
|
||||||
|
|
||||||
async def handle_connection(process: asyncssh.SSHServerProcess):
|
async def handle_connection(process: asyncssh.SSHServerProcess):
|
||||||
connected_clients.append(process)
|
connected_clients.append(process)
|
||||||
username = process.get_extra_info("username")
|
username = process.get_extra_info("username")
|
||||||
try:
|
|
||||||
# hello there
|
# hello there
|
||||||
connected_msg = f"[connected] {username}\n"
|
connected_msg = f"[connected] {username}"
|
||||||
if enable_logging:
|
broadcast(connected_msg)
|
||||||
stderr.write(connected_msg)
|
|
||||||
broadcast(connected_msg, True)
|
|
||||||
if process.command is not None:
|
if process.command is not None:
|
||||||
# client has provided a command as a ssh commandline argument
|
# client has provided a command as a ssh commandline argument
|
||||||
line = process.command.strip("\r\n")
|
line = process.command.strip("\r\n")
|
||||||
msg = f"{username}: {line}\n"
|
msg = f"{username}: {line}\n"
|
||||||
if enable_logging:
|
|
||||||
stdout.write(msg)
|
|
||||||
broadcast(msg)
|
broadcast(msg)
|
||||||
|
cleanup(process, username)
|
||||||
else:
|
else:
|
||||||
# client wants an interactive session
|
async def listen():
|
||||||
while True:
|
|
||||||
try:
|
try:
|
||||||
async for line in process.stdin:
|
async for line in process.stdin:
|
||||||
if line == "": raise asyncssh.BreakReceived(0)
|
|
||||||
line = line.strip('\r\n')
|
line = line.strip('\r\n')
|
||||||
msg = f"{username}: {line}\n"
|
msg = f"{username}: {line}"
|
||||||
if enable_logging:
|
|
||||||
stdout.write(msg)
|
|
||||||
broadcast(msg)
|
broadcast(msg)
|
||||||
except asyncssh.TerminalSizeChanged:
|
except asyncssh.TerminalSizeChanged:
|
||||||
continue # we don't want to exit when the client changes its terminal size lol
|
await listen() # we don't want to exit yet.
|
||||||
finally: # but otherwise, we do want to
|
|
||||||
break # exit this loop.
|
|
||||||
except asyncssh.BreakReceived:
|
except asyncssh.BreakReceived:
|
||||||
pass # we don't want to write an error message on this exception
|
pass # we don't want to write an error message on this exception
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
stderr.write(f"An error occured: {type(e).__name__} {e}\n")
|
stderr.write(f"An error occured: {type(e).__name__} {e}\n")
|
||||||
stderr.flush()
|
stderr.flush()
|
||||||
finally:
|
await listen()
|
||||||
# exit process, remove client from list, inform other clients
|
cleanup(process, username)
|
||||||
disconnected_msg = f"[disconnected] {username}\n"
|
|
||||||
process.exit(0)
|
|
||||||
connected_clients.remove(process)
|
|
||||||
if enable_logging:
|
|
||||||
stderr.write(disconnected_msg)
|
|
||||||
broadcast(disconnected_msg, True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -102,7 +88,7 @@ if __name__ == "__main__":
|
||||||
argp = ArgumentParser()
|
argp = ArgumentParser()
|
||||||
argp.add_argument("config", type=Path, help="The path to the config file")
|
argp.add_argument("config", type=Path, help="The path to the config file")
|
||||||
argp.add_argument("pkey", type=Path, help="The path to the ssh private key")
|
argp.add_argument("pkey", type=Path, help="The path to the ssh private key")
|
||||||
argp.add_argument("--log", action="store_true", help="Enable logging to stdout and stderr")
|
argp.add_argument("--log", action="store_true", help="Enable logging")
|
||||||
args = argp.parse_args()
|
args = argp.parse_args()
|
||||||
# read config
|
# read config
|
||||||
config = yaml.safe_load(args.config.read_text())
|
config = yaml.safe_load(args.config.read_text())
|
||||||
|
|
Reference in a new issue