Restructure code into a module
This commit is contained in:
parent
58a485f6ee
commit
98c2678cf8
6 changed files with 119 additions and 100 deletions
103
node.py
103
node.py
|
@ -1,110 +1,13 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) 2025 Julian Müller (ChaoticByte)
|
||||
|
||||
import asyncio
|
||||
|
||||
from argparse import ArgumentParser
|
||||
from datetime import datetime
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
|
||||
from websockets.asyncio.client import connect
|
||||
from websockets.asyncio.server import broadcast
|
||||
from websockets.asyncio.server import serve
|
||||
from websockets.asyncio.server import ServerConnection
|
||||
from websockets.exceptions import ConnectionClosed, ConnectionClosedError
|
||||
|
||||
from yaml import safe_load as yml_load
|
||||
|
||||
|
||||
def log(msg):
|
||||
print(f"{datetime.now().isoformat()} {msg}")
|
||||
|
||||
|
||||
class NodeAddr:
|
||||
|
||||
def __init__(self, host: str, port: int):
|
||||
self.host = str(host)
|
||||
self.port = int(port)
|
||||
|
||||
def ws_uri(self) -> str:
|
||||
return f"ws://{self.host}:{self.port}"
|
||||
|
||||
def identifier(self) -> str:
|
||||
return f"{self.host}_{self.port}"
|
||||
|
||||
|
||||
class Node:
|
||||
|
||||
max_known_hashes_size = 1024 * 1024
|
||||
|
||||
def __init__(self, listen_address: NodeAddr, receivers: list[NodeAddr]):
|
||||
self.listen_address = listen_address
|
||||
self.receivers = list(receivers)
|
||||
for r in receivers:
|
||||
if not isinstance(r, NodeAddr):
|
||||
raise TypeError(f"{r} must be of type NodeAddr")
|
||||
# internal
|
||||
self._connections: dict = {} # identifer: connection object
|
||||
self._known_hashes: list[str] = [] # list of message hashes
|
||||
|
||||
|
||||
def _relay(self, msg: str):
|
||||
msg_hash = md5(msg.encode()).digest()
|
||||
if not msg_hash in self._known_hashes:
|
||||
self._known_hashes.append(msg_hash)
|
||||
broadcast(self._connections.values(), msg)
|
||||
while len(self._known_hashes) > self.max_known_hashes_size:
|
||||
self._known_hashes.pop(0)
|
||||
|
||||
|
||||
async def _receiver_connection(self, r: NodeAddr):
|
||||
identifier = r.identifier()
|
||||
async for ws in connect(r.ws_uri()):
|
||||
log(identifier)
|
||||
self._connections[identifier] = ws
|
||||
try:
|
||||
async for msg in ws:
|
||||
log(msg)
|
||||
self._relay(msg)
|
||||
except (ConnectionClosed, ConnectionClosedError):
|
||||
del self._connections[identifier]
|
||||
log(f"Connection from {identifier} closed")
|
||||
continue
|
||||
|
||||
|
||||
async def _server_connection(self, ws: ServerConnection):
|
||||
host, port = ws.remote_address
|
||||
addr = NodeAddr(host, port)
|
||||
identifier = addr.identifier()
|
||||
self._connections[identifier] = ws
|
||||
log(identifier)
|
||||
async for msg in ws:
|
||||
log(msg)
|
||||
self._relay(msg)
|
||||
del self._connections[identifier]
|
||||
|
||||
|
||||
async def run(self):
|
||||
# connect to receivers
|
||||
receiver_tasks = []
|
||||
for r in self.receivers:
|
||||
t = asyncio.create_task(self._receiver_connection(r))
|
||||
receiver_tasks.append(t)
|
||||
# server loop
|
||||
async with serve(self._server_connection, self.listen_address.host, self.listen_address.port) as server:
|
||||
await server.serve_forever()
|
||||
# wait for receivers
|
||||
await asyncio.gather(receiver_tasks)
|
||||
|
||||
|
||||
|
||||
def node_from_yml(yml_data: str) -> Node:
|
||||
d = yml_load(yml_data)
|
||||
listen_address = NodeAddr(d["listen"]["host"], d["listen"]["port"])
|
||||
receivers = []
|
||||
for e in d["receivers"]:
|
||||
receivers.append(NodeAddr(e["host"], e["port"]))
|
||||
return Node(listen_address, receivers)
|
||||
from argh.node import node_from_yml
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue