Handle SIGINT and SIGTERM (more) gracefully

This commit is contained in:
ChaoticByte 2025-09-27 11:48:07 +02:00
parent 1b244b0c21
commit 3348c53aea
No known key found for this signature in database
2 changed files with 28 additions and 7 deletions

View file

@ -1,6 +1,7 @@
# Copyright (c) 2025 Julian Müller (ChaoticByte) # Copyright (c) 2025 Julian Müller (ChaoticByte)
import asyncio import asyncio
import signal
from hashlib import md5 from hashlib import md5
@ -58,10 +59,13 @@ class Node:
identifier = addr.identifier() identifier = addr.identifier()
self._connections[identifier] = ws self._connections[identifier] = ws
log("Accepted connection", identifier) log("Accepted connection", identifier)
async for msg in ws: try:
self._relay(msg) async for msg in ws:
log("Lost connection", identifier) self._relay(msg)
del self._connections[identifier] except:
log("Lost connection", identifier)
finally:
del self._connections[identifier]
async def run(self): async def run(self):
# connect to receivers # connect to receivers
@ -71,9 +75,17 @@ class Node:
receiver_tasks.append(t) receiver_tasks.append(t)
# server loop # server loop
async with serve(self._conn_from_node_or_client, self.listen_address.host, self.listen_address.port) as server: async with serve(self._conn_from_node_or_client, self.listen_address.host, self.listen_address.port) as server:
await server.serve_forever() loop = asyncio.get_running_loop()
# wait for receivers
await asyncio.gather(receiver_tasks) def _close():
log("Cancelling receiver connections")
for r in receiver_tasks:
r.cancel()
log("Closing server...")
server.close()
loop.add_signal_handler(signal.SIGTERM, _close)
await server.wait_closed()
def node_from_yml(yml_data: str) -> Node: def node_from_yml(yml_data: str) -> Node:

View file

@ -3,10 +3,12 @@
# Copyright (c) 2025 Julian Müller (ChaoticByte) # Copyright (c) 2025 Julian Müller (ChaoticByte)
import asyncio import asyncio
import signal
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
from argh.log import log
from argh.node import node_from_yml from argh.node import node_from_yml
@ -18,4 +20,11 @@ if __name__ == "__main__":
node = node_from_yml(args.config.read_text()) node = node_from_yml(args.config.read_text())
def _sigint(*args):
# don't need special sigint behaviour
# -> send SIGTERM
signal.raise_signal(signal.SIGTERM)
signal.signal(signal.SIGINT, _sigint)
asyncio.run(node.run()) asyncio.run(node.run())
log("Bye")