85 lines
2.9 KiB
Python
85 lines
2.9 KiB
Python
# Copyright (c) 2025 Julian Müller (ChaoticByte)
|
|
|
|
import asyncio
from collections import deque
from hashlib import md5

from websockets.asyncio.client import connect
from websockets.asyncio.server import broadcast
from websockets.asyncio.server import serve
from websockets.asyncio.server import ServerConnection
from websockets.exceptions import ConnectionClosed, ConnectionClosedError

from yaml import safe_load as yml_load

from .log import log
from .nodeaddr import NodeAddr
|
|
|
|
|
|
class Node:
    """A relay node that re-broadcasts every message it receives.

    The node serves websocket connections on ``listen_address`` and keeps
    outgoing connections to each address in ``receivers``. A message received
    on any connection is broadcast to every registered connection, unless a
    message with the same MD5 digest was relayed recently. That deduplication
    keeps a node from relaying the same message twice, so messages do not
    bounce between interconnected nodes.
    """

    # Maximum number of message digests remembered for deduplication;
    # the oldest digests are forgotten first.
    max_known_hashes_size = 1024 * 1024

    def __init__(self, listen_address: NodeAddr, receivers: list[NodeAddr]):
        """Create a node.

        Args:
            listen_address: Address the websocket server binds to in ``run``.
            receivers: Nodes to connect to. Any iterable is accepted; it is
                copied into a list.

        Raises:
            TypeError: If an element of ``receivers`` is not a ``NodeAddr``.
        """
        self.listen_address = listen_address
        # Validate the copy, not the argument: a one-shot iterable would
        # already be exhausted by list() and silently skip validation.
        self.receivers = list(receivers)
        for r in self.receivers:
            if not isinstance(r, NodeAddr):
                raise TypeError(f"{r} must be of type NodeAddr")
        # internal
        # identifier -> websocket connection (outgoing or incoming)
        self._connections: dict = {}
        # Digests of recently relayed messages. The set gives O(1) membership
        # tests; the deque records insertion order for FIFO eviction.
        self._known_hashes: set[bytes] = set()
        self._known_hashes_order: deque[bytes] = deque()

    def _relay(self, msg: "str | bytes"):
        """Broadcast ``msg`` to all connections unless it was relayed recently.

        Text frames arrive as ``str`` and binary frames as ``bytes``; both are
        deduplicated by the MD5 digest of their raw bytes.
        """
        data = msg.encode() if isinstance(msg, str) else msg
        msg_hash = md5(data).digest()
        if msg_hash in self._known_hashes:
            return  # already relayed once; drop it
        self._known_hashes.add(msg_hash)
        self._known_hashes_order.append(msg_hash)
        broadcast(self._connections.values(), msg)
        # Bound the cache; re-read the limit so instance overrides apply.
        while len(self._known_hashes_order) > self.max_known_hashes_size:
            self._known_hashes.discard(self._known_hashes_order.popleft())

    async def _receiver_connection(self, r: NodeAddr):
        """Keep a client connection to receiver ``r`` and relay its messages.

        ``connect`` is used as an async iterator, which yields a new
        connection after the previous one closes, so this coroutine keeps
        reconnecting until it is cancelled or ``connect`` raises.
        """
        identifier = r.identifier()
        async for ws in connect(r.ws_uri()):
            log(identifier)
            self._connections[identifier] = ws
            try:
                async for msg in ws:
                    log(msg)
                    self._relay(msg)
            except (ConnectionClosed, ConnectionClosedError):
                log(f"Connection from {identifier} closed", 1)
            finally:
                # Also runs on a clean close, where the message loop simply
                # ends without raising; never leave a dead socket registered.
                if self._connections.get(identifier) is ws:
                    del self._connections[identifier]

    async def _server_connection(self, ws: ServerConnection):
        """Handle one incoming connection: register it and relay its messages."""
        # remote_address is (host, port) for IPv4 but
        # (host, port, flowinfo, scope_id) for IPv6; keep the first two.
        host, port = ws.remote_address[:2]
        identifier = NodeAddr(host, port).identifier()
        self._connections[identifier] = ws
        log(identifier)
        try:
            async for msg in ws:
                log(msg)
                self._relay(msg)
        finally:
            # Unregister even when the loop raises (e.g. ConnectionClosedError);
            # skip it if the slot now belongs to a different connection.
            if self._connections.get(identifier) is ws:
                del self._connections[identifier]

    async def run(self):
        """Connect to all receivers and serve incoming connections.

        Receiver connections run as background tasks. When the server loop
        stops (for example because this coroutine is cancelled), those tasks
        are cancelled and awaited so none of them outlives the node.
        """
        # connect to receivers
        receiver_tasks = [
            asyncio.create_task(self._receiver_connection(r))
            for r in self.receivers
        ]
        try:
            # server loop
            async with serve(
                self._server_connection,
                self.listen_address.host,
                self.listen_address.port,
            ) as server:
                await server.serve_forever()
        finally:
            for t in receiver_tasks:
                t.cancel()
            # gather() takes awaitables as separate arguments, not a list.
            await asyncio.gather(*receiver_tasks, return_exceptions=True)
|
|
|
|
|
|
def node_from_yml(yml_data: str) -> Node:
    """Build a ``Node`` from a YAML configuration document.

    Expected layout::

        listen:
          host: <host>
          port: <port>
        receivers:        # optional; may be omitted or left empty
          - host: <host>
            port: <port>

    The document is parsed with ``yaml.safe_load``, so it cannot construct
    arbitrary Python objects.

    Raises:
        KeyError: If ``listen`` or a ``host``/``port`` key is missing.
    """
    d = yml_load(yml_data)
    listen_address = NodeAddr(d["listen"]["host"], d["listen"]["port"])
    # Node accepts an empty receiver list, so a missing or null
    # ``receivers`` entry means "no receivers" rather than an error.
    receivers = [NodeAddr(e["host"], e["port"]) for e in d.get("receivers") or []]
    return Node(listen_address, receivers)
|