Restructure code into a module

This commit is contained in:
ChaoticByte 2025-09-27 07:51:50 +02:00
parent 58a485f6ee
commit 98c2678cf8
No known key found for this signature in database
6 changed files with 119 additions and 100 deletions

13
argh/log.py Normal file
View file

@ -0,0 +1,13 @@
# Copyright (c) 2025 Julian Müller (ChaoticByte)
from datetime import datetime
_loglevels = [
"ERROR",
"WARN ",
"INFO ",
]
def log(*args, level: int = 2):
level = min(len(_loglevels)-1, max(0, level))
print(datetime.now().isoformat(), _loglevels[level], *args)

85
argh/node.py Normal file
View file

@ -0,0 +1,85 @@
# Copyright (c) 2025 Julian Müller (ChaoticByte)
import asyncio
from hashlib import md5
from websockets.asyncio.client import connect
from websockets.asyncio.server import broadcast
from websockets.asyncio.server import serve
from websockets.asyncio.server import ServerConnection
from websockets.exceptions import ConnectionClosed, ConnectionClosedError
from yaml import safe_load as yml_load
from .log import log
from .nodeaddr import NodeAddr
class Node:
max_known_hashes_size = 1024 * 1024
def __init__(self, listen_address: NodeAddr, receivers: list[NodeAddr]):
self.listen_address = listen_address
self.receivers = list(receivers)
for r in receivers:
if not isinstance(r, NodeAddr):
raise TypeError(f"{r} must be of type NodeAddr")
# internal
self._connections: dict = {} # identifer: connection object
self._known_hashes: list[str] = [] # list of message hashes
def _relay(self, msg: str):
msg_hash = md5(msg.encode()).digest()
if not msg_hash in self._known_hashes:
self._known_hashes.append(msg_hash)
broadcast(self._connections.values(), msg)
while len(self._known_hashes) > self.max_known_hashes_size:
self._known_hashes.pop(0)
async def _receiver_connection(self, r: NodeAddr):
identifier = r.identifier()
async for ws in connect(r.ws_uri()):
log(identifier)
self._connections[identifier] = ws
try:
async for msg in ws:
log(msg)
self._relay(msg)
except (ConnectionClosed, ConnectionClosedError):
del self._connections[identifier]
log(f"Connection from {identifier} closed", 1)
continue
async def _server_connection(self, ws: ServerConnection):
host, port = ws.remote_address
addr = NodeAddr(host, port)
identifier = addr.identifier()
self._connections[identifier] = ws
log(identifier)
async for msg in ws:
log(msg)
self._relay(msg)
del self._connections[identifier]
async def run(self):
# connect to receivers
receiver_tasks = []
for r in self.receivers:
t = asyncio.create_task(self._receiver_connection(r))
receiver_tasks.append(t)
# server loop
async with serve(self._server_connection, self.listen_address.host, self.listen_address.port) as server:
await server.serve_forever()
# wait for receivers
await asyncio.gather(receiver_tasks)
def node_from_yml(yml_data: str) -> Node:
d = yml_load(yml_data)
listen_address = NodeAddr(d["listen"]["host"], d["listen"]["port"])
receivers = []
for e in d["receivers"]:
receivers.append(NodeAddr(e["host"], e["port"]))
return Node(listen_address, receivers)

13
argh/nodeaddr.py Normal file
View file

@ -0,0 +1,13 @@
# Copyright (c) 2025 Julian Müller (ChaoticByte)
class NodeAddr:
def __init__(self, host: str, port: int):
self.host = str(host)
self.port = int(port)
def ws_uri(self) -> str:
return f"ws://{self.host}:{self.port}"
def identifier(self) -> str:
return f"{self.host}_{self.port}"

View file

@ -1,5 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) 2025 Julian Müller (ChaoticByte)
import asyncio import asyncio
from argparse import ArgumentParser from argparse import ArgumentParser

View file

@ -1,5 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) 2025 Julian Müller (ChaoticByte)
import asyncio import asyncio
from argparse import ArgumentParser from argparse import ArgumentParser
@ -13,6 +15,7 @@ async def run(host: str, port: int):
inp = input("> ") inp = input("> ")
await ws.send(inp) await ws.send(inp)
if __name__ == "__main__": if __name__ == "__main__":
argp = ArgumentParser() argp = ArgumentParser()

103
node.py
View file

@ -1,110 +1,13 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) 2025 Julian Müller (ChaoticByte)
import asyncio import asyncio
from argparse import ArgumentParser from argparse import ArgumentParser
from datetime import datetime
from hashlib import md5
from pathlib import Path from pathlib import Path
from websockets.asyncio.client import connect from argh.node import node_from_yml
from websockets.asyncio.server import broadcast
from websockets.asyncio.server import serve
from websockets.asyncio.server import ServerConnection
from websockets.exceptions import ConnectionClosed, ConnectionClosedError
from yaml import safe_load as yml_load
def log(msg):
print(f"{datetime.now().isoformat()} {msg}")
class NodeAddr:
def __init__(self, host: str, port: int):
self.host = str(host)
self.port = int(port)
def ws_uri(self) -> str:
return f"ws://{self.host}:{self.port}"
def identifier(self) -> str:
return f"{self.host}_{self.port}"
class Node:
max_known_hashes_size = 1024 * 1024
def __init__(self, listen_address: NodeAddr, receivers: list[NodeAddr]):
self.listen_address = listen_address
self.receivers = list(receivers)
for r in receivers:
if not isinstance(r, NodeAddr):
raise TypeError(f"{r} must be of type NodeAddr")
# internal
self._connections: dict = {} # identifer: connection object
self._known_hashes: list[str] = [] # list of message hashes
def _relay(self, msg: str):
msg_hash = md5(msg.encode()).digest()
if not msg_hash in self._known_hashes:
self._known_hashes.append(msg_hash)
broadcast(self._connections.values(), msg)
while len(self._known_hashes) > self.max_known_hashes_size:
self._known_hashes.pop(0)
async def _receiver_connection(self, r: NodeAddr):
identifier = r.identifier()
async for ws in connect(r.ws_uri()):
log(identifier)
self._connections[identifier] = ws
try:
async for msg in ws:
log(msg)
self._relay(msg)
except (ConnectionClosed, ConnectionClosedError):
del self._connections[identifier]
log(f"Connection from {identifier} closed")
continue
async def _server_connection(self, ws: ServerConnection):
host, port = ws.remote_address
addr = NodeAddr(host, port)
identifier = addr.identifier()
self._connections[identifier] = ws
log(identifier)
async for msg in ws:
log(msg)
self._relay(msg)
del self._connections[identifier]
async def run(self):
# connect to receivers
receiver_tasks = []
for r in self.receivers:
t = asyncio.create_task(self._receiver_connection(r))
receiver_tasks.append(t)
# server loop
async with serve(self._server_connection, self.listen_address.host, self.listen_address.port) as server:
await server.serve_forever()
# wait for receivers
await asyncio.gather(receiver_tasks)
def node_from_yml(yml_data: str) -> Node:
d = yml_load(yml_data)
listen_address = NodeAddr(d["listen"]["host"], d["listen"]["port"])
receivers = []
for e in d["receivers"]:
receivers.append(NodeAddr(e["host"], e["port"]))
return Node(listen_address, receivers)
if __name__ == "__main__": if __name__ == "__main__":