Restructure code into a module
This commit is contained in:
parent
58a485f6ee
commit
98c2678cf8
6 changed files with 119 additions and 100 deletions
13
argh/log.py
Normal file
13
argh/log.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright (c) 2025 Julian Müller (ChaoticByte)
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
_loglevels = [
|
||||||
|
"ERROR",
|
||||||
|
"WARN ",
|
||||||
|
"INFO ",
|
||||||
|
]
|
||||||
|
|
||||||
|
def log(*args, level: int = 2):
|
||||||
|
level = min(len(_loglevels)-1, max(0, level))
|
||||||
|
print(datetime.now().isoformat(), _loglevels[level], *args)
|
85
argh/node.py
Normal file
85
argh/node.py
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
# Copyright (c) 2025 Julian Müller (ChaoticByte)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from hashlib import md5
|
||||||
|
|
||||||
|
from websockets.asyncio.client import connect
|
||||||
|
from websockets.asyncio.server import broadcast
|
||||||
|
from websockets.asyncio.server import serve
|
||||||
|
from websockets.asyncio.server import ServerConnection
|
||||||
|
from websockets.exceptions import ConnectionClosed, ConnectionClosedError
|
||||||
|
|
||||||
|
from yaml import safe_load as yml_load
|
||||||
|
|
||||||
|
from .log import log
|
||||||
|
from .nodeaddr import NodeAddr
|
||||||
|
|
||||||
|
|
||||||
|
class Node:
|
||||||
|
|
||||||
|
max_known_hashes_size = 1024 * 1024
|
||||||
|
|
||||||
|
def __init__(self, listen_address: NodeAddr, receivers: list[NodeAddr]):
|
||||||
|
self.listen_address = listen_address
|
||||||
|
self.receivers = list(receivers)
|
||||||
|
for r in receivers:
|
||||||
|
if not isinstance(r, NodeAddr):
|
||||||
|
raise TypeError(f"{r} must be of type NodeAddr")
|
||||||
|
# internal
|
||||||
|
self._connections: dict = {} # identifer: connection object
|
||||||
|
self._known_hashes: list[str] = [] # list of message hashes
|
||||||
|
|
||||||
|
def _relay(self, msg: str):
|
||||||
|
msg_hash = md5(msg.encode()).digest()
|
||||||
|
if not msg_hash in self._known_hashes:
|
||||||
|
self._known_hashes.append(msg_hash)
|
||||||
|
broadcast(self._connections.values(), msg)
|
||||||
|
while len(self._known_hashes) > self.max_known_hashes_size:
|
||||||
|
self._known_hashes.pop(0)
|
||||||
|
|
||||||
|
async def _receiver_connection(self, r: NodeAddr):
|
||||||
|
identifier = r.identifier()
|
||||||
|
async for ws in connect(r.ws_uri()):
|
||||||
|
log(identifier)
|
||||||
|
self._connections[identifier] = ws
|
||||||
|
try:
|
||||||
|
async for msg in ws:
|
||||||
|
log(msg)
|
||||||
|
self._relay(msg)
|
||||||
|
except (ConnectionClosed, ConnectionClosedError):
|
||||||
|
del self._connections[identifier]
|
||||||
|
log(f"Connection from {identifier} closed", 1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
async def _server_connection(self, ws: ServerConnection):
|
||||||
|
host, port = ws.remote_address
|
||||||
|
addr = NodeAddr(host, port)
|
||||||
|
identifier = addr.identifier()
|
||||||
|
self._connections[identifier] = ws
|
||||||
|
log(identifier)
|
||||||
|
async for msg in ws:
|
||||||
|
log(msg)
|
||||||
|
self._relay(msg)
|
||||||
|
del self._connections[identifier]
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
# connect to receivers
|
||||||
|
receiver_tasks = []
|
||||||
|
for r in self.receivers:
|
||||||
|
t = asyncio.create_task(self._receiver_connection(r))
|
||||||
|
receiver_tasks.append(t)
|
||||||
|
# server loop
|
||||||
|
async with serve(self._server_connection, self.listen_address.host, self.listen_address.port) as server:
|
||||||
|
await server.serve_forever()
|
||||||
|
# wait for receivers
|
||||||
|
await asyncio.gather(receiver_tasks)
|
||||||
|
|
||||||
|
|
||||||
|
def node_from_yml(yml_data: str) -> Node:
|
||||||
|
d = yml_load(yml_data)
|
||||||
|
listen_address = NodeAddr(d["listen"]["host"], d["listen"]["port"])
|
||||||
|
receivers = []
|
||||||
|
for e in d["receivers"]:
|
||||||
|
receivers.append(NodeAddr(e["host"], e["port"]))
|
||||||
|
return Node(listen_address, receivers)
|
13
argh/nodeaddr.py
Normal file
13
argh/nodeaddr.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright (c) 2025 Julian Müller (ChaoticByte)
|
||||||
|
|
||||||
|
class NodeAddr:
|
||||||
|
|
||||||
|
def __init__(self, host: str, port: int):
|
||||||
|
self.host = str(host)
|
||||||
|
self.port = int(port)
|
||||||
|
|
||||||
|
def ws_uri(self) -> str:
|
||||||
|
return f"ws://{self.host}:{self.port}"
|
||||||
|
|
||||||
|
def identifier(self) -> str:
|
||||||
|
return f"{self.host}_{self.port}"
|
|
@ -1,5 +1,7 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright (c) 2025 Julian Müller (ChaoticByte)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright (c) 2025 Julian Müller (ChaoticByte)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
@ -13,6 +15,7 @@ async def run(host: str, port: int):
|
||||||
inp = input("> ")
|
inp = input("> ")
|
||||||
await ws.send(inp)
|
await ws.send(inp)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
argp = ArgumentParser()
|
argp = ArgumentParser()
|
||||||
|
|
103
node.py
103
node.py
|
@ -1,110 +1,13 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright (c) 2025 Julian Müller (ChaoticByte)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from datetime import datetime
|
|
||||||
from hashlib import md5
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from websockets.asyncio.client import connect
|
from argh.node import node_from_yml
|
||||||
from websockets.asyncio.server import broadcast
|
|
||||||
from websockets.asyncio.server import serve
|
|
||||||
from websockets.asyncio.server import ServerConnection
|
|
||||||
from websockets.exceptions import ConnectionClosed, ConnectionClosedError
|
|
||||||
|
|
||||||
from yaml import safe_load as yml_load
|
|
||||||
|
|
||||||
|
|
||||||
def log(msg):
|
|
||||||
print(f"{datetime.now().isoformat()} {msg}")
|
|
||||||
|
|
||||||
|
|
||||||
class NodeAddr:
|
|
||||||
|
|
||||||
def __init__(self, host: str, port: int):
|
|
||||||
self.host = str(host)
|
|
||||||
self.port = int(port)
|
|
||||||
|
|
||||||
def ws_uri(self) -> str:
|
|
||||||
return f"ws://{self.host}:{self.port}"
|
|
||||||
|
|
||||||
def identifier(self) -> str:
|
|
||||||
return f"{self.host}_{self.port}"
|
|
||||||
|
|
||||||
|
|
||||||
class Node:
|
|
||||||
|
|
||||||
max_known_hashes_size = 1024 * 1024
|
|
||||||
|
|
||||||
def __init__(self, listen_address: NodeAddr, receivers: list[NodeAddr]):
|
|
||||||
self.listen_address = listen_address
|
|
||||||
self.receivers = list(receivers)
|
|
||||||
for r in receivers:
|
|
||||||
if not isinstance(r, NodeAddr):
|
|
||||||
raise TypeError(f"{r} must be of type NodeAddr")
|
|
||||||
# internal
|
|
||||||
self._connections: dict = {} # identifer: connection object
|
|
||||||
self._known_hashes: list[str] = [] # list of message hashes
|
|
||||||
|
|
||||||
|
|
||||||
def _relay(self, msg: str):
|
|
||||||
msg_hash = md5(msg.encode()).digest()
|
|
||||||
if not msg_hash in self._known_hashes:
|
|
||||||
self._known_hashes.append(msg_hash)
|
|
||||||
broadcast(self._connections.values(), msg)
|
|
||||||
while len(self._known_hashes) > self.max_known_hashes_size:
|
|
||||||
self._known_hashes.pop(0)
|
|
||||||
|
|
||||||
|
|
||||||
async def _receiver_connection(self, r: NodeAddr):
|
|
||||||
identifier = r.identifier()
|
|
||||||
async for ws in connect(r.ws_uri()):
|
|
||||||
log(identifier)
|
|
||||||
self._connections[identifier] = ws
|
|
||||||
try:
|
|
||||||
async for msg in ws:
|
|
||||||
log(msg)
|
|
||||||
self._relay(msg)
|
|
||||||
except (ConnectionClosed, ConnectionClosedError):
|
|
||||||
del self._connections[identifier]
|
|
||||||
log(f"Connection from {identifier} closed")
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
|
||||||
async def _server_connection(self, ws: ServerConnection):
|
|
||||||
host, port = ws.remote_address
|
|
||||||
addr = NodeAddr(host, port)
|
|
||||||
identifier = addr.identifier()
|
|
||||||
self._connections[identifier] = ws
|
|
||||||
log(identifier)
|
|
||||||
async for msg in ws:
|
|
||||||
log(msg)
|
|
||||||
self._relay(msg)
|
|
||||||
del self._connections[identifier]
|
|
||||||
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
# connect to receivers
|
|
||||||
receiver_tasks = []
|
|
||||||
for r in self.receivers:
|
|
||||||
t = asyncio.create_task(self._receiver_connection(r))
|
|
||||||
receiver_tasks.append(t)
|
|
||||||
# server loop
|
|
||||||
async with serve(self._server_connection, self.listen_address.host, self.listen_address.port) as server:
|
|
||||||
await server.serve_forever()
|
|
||||||
# wait for receivers
|
|
||||||
await asyncio.gather(receiver_tasks)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def node_from_yml(yml_data: str) -> Node:
|
|
||||||
d = yml_load(yml_data)
|
|
||||||
listen_address = NodeAddr(d["listen"]["host"], d["listen"]["port"])
|
|
||||||
receivers = []
|
|
||||||
for e in d["receivers"]:
|
|
||||||
receivers.append(NodeAddr(e["host"], e["port"]))
|
|
||||||
return Node(listen_address, receivers)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue