Got the thing to compile

This commit is contained in:
IQuant 2024-09-06 20:24:39 +03:00
parent 7386180a25
commit 868cc1ffaa
8 changed files with 333 additions and 85 deletions

27
noita-proxy/Cargo.lock generated
View file

@ -331,9 +331,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
[[package]] [[package]]
name = "bytemuck" name = "bytemuck"
version = "1.17.1" version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773d90827bc3feecfb67fab12e24de0749aad83c74b9504ecde46237b5cd24e2" checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae"
dependencies = [ dependencies = [
"bytemuck_derive", "bytemuck_derive",
] ]
@ -592,9 +592,9 @@ dependencies = [
[[package]] [[package]]
name = "cpufeatures" name = "cpufeatures"
version = "0.2.13" version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51e852e6dc9a5bed1fae92dd2375037bf2b768725bf3be87811edee3249d09ad" checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0"
dependencies = [ dependencies = [
"libc", "libc",
] ]
@ -2025,6 +2025,7 @@ dependencies = [
"steamworks", "steamworks",
"tangled", "tangled",
"thiserror", "thiserror",
"tokio",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"tungstenite", "tungstenite",
@ -2810,9 +2811,9 @@ checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152"
[[package]] [[package]]
name = "rustix" name = "rustix"
version = "0.38.35" version = "0.38.36"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a85d50532239da68e9addb745ba38ff4612a242c1c7ceea689c4bc7c2f43c36f" checksum = "3f55e80d50763938498dd5ebb18647174e0c76dc38c5505294bb224624f30f36"
dependencies = [ dependencies = [
"bitflags 2.6.0", "bitflags 2.6.0",
"errno", "errno",
@ -3302,7 +3303,7 @@ dependencies = [
name = "tangled" name = "tangled"
version = "0.3.0" version = "0.3.0"
dependencies = [ dependencies = [
"bincode", "bitcode",
"crossbeam", "crossbeam",
"dashmap", "dashmap",
"num-bigint", "num-bigint",
@ -3461,9 +3462,21 @@ dependencies = [
"mio", "mio",
"pin-project-lite", "pin-project-lite",
"socket2", "socket2",
"tokio-macros",
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "tokio-macros"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "tokio-rustls" name = "tokio-rustls"
version = "0.26.0" version = "0.26.0"

View file

@ -47,6 +47,7 @@ shlex = "1.3.0"
quick-xml = { version = "0.36.0", features = ["serialize"] } quick-xml = { version = "0.36.0", features = ["serialize"] }
dashmap = "6.0.1" dashmap = "6.0.1"
eyre = "0.6.12" eyre = "0.6.12"
tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] }
[build-dependencies] [build-dependencies]
winresource = "0.1.17" winresource = "0.1.17"

View file

@ -5,7 +5,9 @@ use eframe::{
use noita_proxy::{args::Args, recorder::replay_file, App}; use noita_proxy::{args::Args, recorder::replay_file, App};
use tracing::{info, level_filters::LevelFilter}; use tracing::{info, level_filters::LevelFilter};
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
fn main() -> Result<(), eframe::Error> {
#[tokio::main(worker_threads = 2)]
async fn main() -> Result<(), eframe::Error> {
let my_subscriber = tracing_subscriber::FmtSubscriber::builder() let my_subscriber = tracing_subscriber::FmtSubscriber::builder()
.with_env_filter( .with_env_filter(
EnvFilter::builder() EnvFilter::builder()

View file

@ -17,13 +17,13 @@ crossbeam = "0.8.2"
tracing = "0.1.36" tracing = "0.1.36"
dashmap = "6.0.1" dashmap = "6.0.1"
serde = {features = ["derive"], version = "1.0.142"} serde = {features = ["derive"], version = "1.0.142"}
bincode = "1.3.3"
quinn = "0.11.5" quinn = "0.11.5"
num-bigint = "0.4.6" num-bigint = "0.4.6"
rcgen = "0.13.1" rcgen = "0.13.1"
thiserror = "1.0.63" thiserror = "1.0.63"
tokio = "1.40.0" tokio = { version = "1.40.0", features = ["macros", "io-util", "sync"] }
bitcode = "0.6.3"
[dev-dependencies] [dev-dependencies]
test-log = { version = "0.2.11", default-features = false, features = ["trace"]} test-log = { version = "0.2.16", default-features = false, features = ["trace"]}
tracing-subscriber = {version = "0.3", features = ["env-filter", "fmt"]} tracing-subscriber = {version = "0.3", features = ["env-filter", "fmt"]}

View file

@ -2,7 +2,7 @@
use std::{fmt::Display, time::Duration}; use std::{fmt::Display, time::Duration};
use serde::{Deserialize, Serialize}; use bitcode::{Decode, Encode};
/// Per-peer settings. Peers that are connected to the same host, as well as the host itself, should have the same settings. /// Per-peer settings. Peers that are connected to the same host, as well as the host itself, should have the same settings.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -20,7 +20,7 @@ pub struct Settings {
} }
/// Tells how reliable a message is. /// Tells how reliable a message is.
#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Debug)] #[derive(Encode, Decode, Clone, Copy, PartialEq, Debug)]
pub enum Reliability { pub enum Reliability {
/// Message will be delivered at most once. /// Message will be delivered at most once.
Unreliable, Unreliable,
@ -29,6 +29,7 @@ pub enum Reliability {
Reliable, Reliable,
} }
#[derive(Debug, Encode, Decode, Clone, Copy)]
pub enum Destination { pub enum Destination {
One(PeerId), One(PeerId),
Broadcast, Broadcast,
@ -36,7 +37,7 @@ pub enum Destination {
/// A value which refers to a specific peer. /// A value which refers to a specific peer.
/// Peer 0 is always the host. /// Peer 0 is always the host.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Encode, Decode)]
pub struct PeerId(pub u16); pub struct PeerId(pub u16);
/// Possible network events, returned by `Peer.recv()`. /// Possible network events, returned by `Peer.recv()`.
@ -101,3 +102,9 @@ impl Default for Settings {
} }
} }
} }
impl Destination {
pub(crate) fn is_broadcast(self) -> bool {
matches!(self, Destination::Broadcast)
}
}

View file

@ -7,6 +7,7 @@ use std::{
}, },
}; };
use bitcode::{Decode, Encode};
use crossbeam::{ use crossbeam::{
atomic::AtomicCell, atomic::AtomicCell,
channel::{unbounded, Receiver, Sender}, channel::{unbounded, Receiver, Sender},
@ -23,7 +24,7 @@ use quinn::{
}; };
use thiserror::Error; use thiserror::Error;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{error, info, warn}; use tracing::{debug, error, info, trace, warn};
use crate::{ use crate::{
common::{Destination, NetworkEvent, PeerId, PeerState, Reliability, Settings}, common::{Destination, NetworkEvent, PeerId, PeerState, Reliability, Settings},
@ -32,6 +33,13 @@ use crate::{
mod message_stream; mod message_stream;
#[derive(Debug, Encode, Decode)]
enum InternalMessage {
Normal(OutboundMessage),
RemoteConnected(PeerId),
RemoteDisconnected(PeerId),
}
#[derive(Default)] #[derive(Default)]
pub(crate) struct RemotePeer; pub(crate) struct RemotePeer;
@ -43,23 +51,35 @@ enum DirectConnectionError {
InitialExchangeFailed, InitialExchangeFailed,
#[error("Message read failed")] #[error("Message read failed")]
MessageIoFailed, MessageIoFailed,
#[error("Failed to decode message")]
DecodeError,
} }
struct DirectPeer { struct DirectPeer {
my_id: PeerId, my_id: PeerId,
remote_id: PeerId, remote_id: PeerId,
send_stream: message_stream::SendMessageStream, send_stream: message_stream::SendMessageStream<InternalMessage>,
} }
impl DirectPeer { impl DirectPeer {
async fn recv_task(shared: Arc<Shared>, recv_stream: RecvStream, remote_id: PeerId) { async fn recv_task(shared: Arc<Shared>, recv_stream: RecvStream, remote_id: PeerId) {
let mut recv_stream = message_stream::RecvMessageStream::new(recv_stream); let mut recv_stream = message_stream::RecvMessageStream::new(recv_stream);
while let Ok(msg) = recv_stream.recv().await { while let Ok(msg) = recv_stream.recv().await {
if let Err(err) = shared.incoming_messages.0.send((remote_id, msg)) { trace!("Received message from {remote_id}");
if let Err(err) = shared
.internal_incoming_messages_s
.send((remote_id, msg))
.await
{
warn!("Could not send message to channel: {err}. Stopping."); warn!("Could not send message to channel: {err}. Stopping.");
break; break;
} }
} }
shared
.internal_events_s
.send(InternalEvent::Disconnected(remote_id))
.await
.ok();
} }
async fn accept( async fn accept(
@ -82,6 +102,7 @@ impl DirectPeer {
let (send_stream, recv_stream) = connection.open_bi().await?; let (send_stream, recv_stream) = connection.open_bi().await?;
tokio::spawn(Self::recv_task(shared, recv_stream, assigned_peer_id)); tokio::spawn(Self::recv_task(shared, recv_stream, assigned_peer_id));
debug!("Server: spawned recv task");
Ok(Self { Ok(Self {
my_id: PeerId::HOST, my_id: PeerId::HOST,
@ -103,9 +124,11 @@ impl DirectPeer {
.read_u16() .read_u16()
.await .await
.map_err(|_err| DirectConnectionError::InitialExchangeFailed)?; .map_err(|_err| DirectConnectionError::InitialExchangeFailed)?;
debug!("Got peer id {peer_id}");
let (send_stream, recv_stream) = connection.open_bi().await?; let (send_stream, recv_stream) = connection.open_bi().await?;
tokio::spawn(Self::recv_task(shared, recv_stream, PeerId::HOST)); tokio::spawn(Self::recv_task(shared, recv_stream, PeerId::HOST));
debug!("Client: spawned recv task");
Ok(Self { Ok(Self {
my_id: PeerId(peer_id), my_id: PeerId(peer_id),
@ -115,7 +138,9 @@ impl DirectPeer {
} }
} }
#[derive(Debug, Encode, Decode, Clone)]
pub(crate) struct OutboundMessage { pub(crate) struct OutboundMessage {
pub src: PeerId,
pub dst: Destination, pub dst: Destination,
pub reliability: Reliability, pub reliability: Reliability,
pub data: Vec<u8>, pub data: Vec<u8>,
@ -131,10 +156,14 @@ pub enum TangledInitError {
CouldNotConnectToHost(ConnectError), CouldNotConnectToHost(ConnectError),
} }
enum InternalEvent {
Connected(PeerId),
Disconnected(PeerId),
}
pub(crate) struct Shared { pub(crate) struct Shared {
pub settings: Settings,
pub inbound_channel: Channel<NetworkEvent>, pub inbound_channel: Channel<NetworkEvent>,
pub outbound_channel: Channel<OutboundMessage>, pub outbound_messages_s: tokio::sync::mpsc::Sender<OutboundMessage>,
pub keep_alive: AtomicBool, pub keep_alive: AtomicBool,
pub peer_state: AtomicCell<PeerState>, pub peer_state: AtomicCell<PeerState>,
pub remote_peers: DashMap<PeerId, RemotePeer>, pub remote_peers: DashMap<PeerId, RemotePeer>,
@ -142,28 +171,8 @@ pub(crate) struct Shared {
pub my_id: AtomicCell<Option<PeerId>>, pub my_id: AtomicCell<Option<PeerId>>,
// ConnectionManager-specific stuff // ConnectionManager-specific stuff
direct_peers: DashMap<PeerId, DirectPeer>, direct_peers: DashMap<PeerId, DirectPeer>,
incoming_messages: Channel<(PeerId, Vec<u8>)>, internal_incoming_messages_s: tokio::sync::mpsc::Sender<(PeerId, InternalMessage)>,
} internal_events_s: tokio::sync::mpsc::Sender<InternalEvent>,
impl Shared {
pub(crate) fn new(host_addr: Option<SocketAddr>, settings: Option<Settings>) -> Self {
Self {
inbound_channel: unbounded(),
outbound_channel: unbounded(),
keep_alive: AtomicBool::new(true),
host_addr,
peer_state: Default::default(),
remote_peers: Default::default(),
my_id: AtomicCell::new(if host_addr.is_none() {
Some(PeerId(0))
} else {
None
}),
settings: settings.unwrap_or_default(),
direct_peers: DashMap::default(),
incoming_messages: unbounded(),
}
}
} }
pub(crate) struct ConnectionManager { pub(crate) struct ConnectionManager {
@ -171,18 +180,42 @@ pub(crate) struct ConnectionManager {
endpoint: Endpoint, endpoint: Endpoint,
host_conn: Option<DirectPeer>, host_conn: Option<DirectPeer>,
is_server: bool, is_server: bool,
incoming_messages_r: tokio::sync::mpsc::Receiver<(PeerId, InternalMessage)>,
outbound_messages_r: tokio::sync::mpsc::Receiver<OutboundMessage>,
internal_events_r: tokio::sync::mpsc::Receiver<InternalEvent>,
} }
impl ConnectionManager { impl ConnectionManager {
pub(crate) fn new(shared: Arc<Shared>, addr: SocketAddr) -> Result<Self, TangledInitError> { pub(crate) fn new(
let is_server = shared.host_addr.is_none(); host_addr: Option<SocketAddr>,
_settings: Option<Settings>,
bind_addr: SocketAddr,
) -> Result<Self, TangledInitError> {
let is_server = host_addr.is_none();
let (internal_incoming_messages_s, incoming_messages_r) = tokio::sync::mpsc::channel(512);
let (outbound_messages_s, outbound_messages_r) = tokio::sync::mpsc::channel(512);
let (internal_events_s, internal_events_r) = tokio::sync::mpsc::channel(512);
let shared = Arc::new(Shared {
inbound_channel: unbounded(),
outbound_messages_s,
keep_alive: AtomicBool::new(true),
host_addr,
peer_state: Default::default(),
remote_peers: Default::default(),
my_id: AtomicCell::new(is_server.then_some(PeerId(0))),
direct_peers: DashMap::default(),
internal_incoming_messages_s,
internal_events_s,
});
let config = default_server_config(); let config = default_server_config();
let mut endpoint = if is_server { let mut endpoint = if is_server {
Endpoint::server(config, addr).map_err(TangledInitError::CouldNotCreateEndpoint)? Endpoint::server(config, bind_addr).map_err(TangledInitError::CouldNotCreateEndpoint)?
} else { } else {
Endpoint::client(addr).map_err(TangledInitError::CouldNotCreateEndpoint)? Endpoint::client(bind_addr).map_err(TangledInitError::CouldNotCreateEndpoint)?
}; };
endpoint.set_default_client_config(ClientConfig::new(Arc::new( endpoint.set_default_client_config(ClientConfig::new(Arc::new(
@ -200,6 +233,9 @@ impl ConnectionManager {
is_server, is_server,
endpoint, endpoint,
host_conn: None, host_conn: None,
incoming_messages_r,
outbound_messages_r,
internal_events_r,
}) })
} }
@ -215,6 +251,11 @@ impl ConnectionManager {
shared shared
.direct_peers .direct_peers
.insert(PeerId(peer_id_counter), direct_peer); .insert(PeerId(peer_id_counter), direct_peer);
shared
.internal_events_s
.send(InternalEvent::Connected(PeerId(peer_id_counter)))
.await
.expect("channel to be open");
peer_id_counter += 1; peer_id_counter += 1;
} }
Err(err) => { Err(err) => {
@ -224,10 +265,142 @@ impl ConnectionManager {
} }
} }
async fn handle_incoming_message(&mut self, msg: InternalMessage) {
match msg {
InternalMessage::Normal(msg) => {
if self.is_server && msg.dst.is_broadcast() {
self.server_send_to_peers(msg.clone()).await;
}
self.shared
.inbound_channel
.0
.send(NetworkEvent::Message(crate::Message {
src: msg.src,
data: msg.data,
}))
.expect("channel to be open");
}
// TODO this might deadlock if internal_events_s is full.
InternalMessage::RemoteConnected(peer_id) => {
debug!("Got notified of peer {peer_id}");
self.shared
.internal_events_s
.send(InternalEvent::Connected(peer_id))
.await
.expect("channel to be open");
}
InternalMessage::RemoteDisconnected(peer_id) => self
.shared
.internal_events_s
.send(InternalEvent::Disconnected(peer_id))
.await
.expect("channel to be open"),
}
}
async fn handle_internal_event(&mut self, ev: InternalEvent) {
match ev {
InternalEvent::Connected(peer_id) => {
info!("Peer {} connected", peer_id);
self.shared
.inbound_channel
.0
.send(NetworkEvent::PeerConnected(peer_id))
.expect("channel to be open");
self.shared
.remote_peers
.insert(peer_id, RemotePeer::default());
if self.is_server {
self.server_broadcast_internal_message(
PeerId::HOST,
InternalMessage::RemoteConnected(peer_id),
)
.await;
let peers = self
.shared
.remote_peers
.iter()
.map(|i| *i.key())
.collect::<Vec<_>>();
for conn_peer in peers {
debug!("Notifying peer of {conn_peer}");
self.server_send_internal_message(
peer_id,
&InternalMessage::RemoteConnected(conn_peer),
)
.await;
}
}
}
InternalEvent::Disconnected(peer_id) => {
info!("Peer {} disconnected", peer_id);
self.shared.direct_peers.remove(&peer_id);
self.shared
.inbound_channel
.0
.send(NetworkEvent::PeerDisconnected(peer_id))
.expect("channel to be open");
self.shared.remote_peers.remove(&peer_id);
if self.is_server {
self.server_broadcast_internal_message(
PeerId::HOST,
InternalMessage::RemoteDisconnected(peer_id),
)
.await;
}
}
}
}
async fn server_send_to_peers(&mut self, msg: OutboundMessage) {
match msg.dst {
Destination::One(peer_id) => {
self.server_send_internal_message(peer_id, &InternalMessage::Normal(msg))
.await;
}
Destination::Broadcast => {
let msg_src = msg.src;
let value = InternalMessage::Normal(msg);
self.server_broadcast_internal_message(msg_src, value).await;
}
}
}
async fn server_send_internal_message(&mut self, peer_id: PeerId, msg: &InternalMessage) {
let peer = self.shared.direct_peers.get_mut(&peer_id);
// TODO handle lack of peer?
if let Some(mut peer) = peer {
// TODO handle errors
peer.send_stream.send(msg).await.ok();
}
}
async fn server_broadcast_internal_message(
&mut self,
excluded: PeerId,
value: InternalMessage,
) {
for mut peer in self.shared.direct_peers.iter_mut() {
let peer_id = *peer.key();
if peer_id != excluded {
// TODO handle errors
peer.send_stream.send(&value).await.ok();
}
}
}
async fn astart(mut self, host_conn: Option<Connecting>) { async fn astart(mut self, host_conn: Option<Connecting>) {
debug!("astart running");
if let Some(host_conn) = host_conn { if let Some(host_conn) = host_conn {
match DirectPeer::connect(self.shared.clone(), host_conn).await { match DirectPeer::connect(self.shared.clone(), host_conn).await {
Ok(host_conn) => { Ok(host_conn) => {
self.shared.my_id.store(Some(host_conn.my_id));
self.shared
.internal_events_s
.send(InternalEvent::Connected(host_conn.remote_id))
.await
.expect("channel to be open");
self.host_conn = Some(host_conn); self.host_conn = Some(host_conn);
} }
Err(err) => { Err(err) => {
@ -240,7 +413,30 @@ impl ConnectionManager {
if self.is_server { if self.is_server {
let endpoint = self.endpoint.clone(); let endpoint = self.endpoint.clone();
tokio::spawn(Self::accept_connections(self.shared.clone(), endpoint)); tokio::spawn(Self::accept_connections(self.shared.clone(), endpoint));
info!("Started connection acceptor task"); debug!("Started connection acceptor task");
}
while self.shared.keep_alive.load(Ordering::Relaxed) {
tokio::select! {
msg = self.incoming_messages_r.recv() => {
let msg = msg.expect("channel to not be closed");
self.handle_incoming_message(msg.1).await;
}
msg = self.outbound_messages_r.recv() => {
let msg = msg.expect("channel to not be closed");
if self.is_server {
self.server_send_to_peers(msg).await;
} else {
// TODO handle error
self.host_conn.as_mut().unwrap().send_stream.send(&InternalMessage::Normal(msg)).await.ok();
}
}
ev = self.internal_events_r.recv() => {
let ev = ev.expect("channel to not be closed");
self.handle_internal_event(ev).await;
}
}
} }
} }
@ -256,9 +452,14 @@ impl ConnectionManager {
}) })
.transpose()?; .transpose()?;
debug!("Spawning astart task");
tokio::spawn(self.astart(host_conn)); tokio::spawn(self.astart(host_conn));
Ok(()) Ok(())
} }
pub(crate) fn shared(&self) -> Arc<Shared> {
self.shared.clone()
}
} }
fn default_server_config() -> ServerConfig { fn default_server_config() -> ServerConfig {

View file

@ -1,24 +1,31 @@
use std::io; use std::marker::PhantomData;
use bitcode::{DecodeOwned, Encode};
use quinn::{RecvStream, SendStream}; use quinn::{RecvStream, SendStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{debug, trace};
use super::DirectConnectionError; use super::DirectConnectionError;
pub(crate) struct SendMessageStream { pub(crate) struct SendMessageStream<Msg> {
inner: SendStream, inner: SendStream,
_phantom: PhantomData<fn(Msg)>,
} }
pub(crate) struct RecvMessageStream { pub(crate) struct RecvMessageStream<Msg> {
inner: RecvStream, inner: RecvStream,
_phantom: PhantomData<fn() -> Msg>,
} }
impl SendMessageStream { impl<Msg: Encode> SendMessageStream<Msg> {
pub(crate) fn new(inner: SendStream) -> Self { pub(crate) fn new(inner: SendStream) -> Self {
Self { inner } Self {
inner,
_phantom: PhantomData,
}
} }
pub(crate) async fn send(&mut self, msg: &[u8]) -> Result<(), DirectConnectionError> { async fn send_raw(&mut self, msg: &[u8]) -> Result<(), DirectConnectionError> {
self.inner self.inner
.write_u32( .write_u32(
msg.len() msg.len()
@ -33,19 +40,29 @@ impl SendMessageStream {
.map_err(|_err| DirectConnectionError::MessageIoFailed)?; .map_err(|_err| DirectConnectionError::MessageIoFailed)?;
Ok(()) Ok(())
} }
pub(crate) async fn send(&mut self, msg: &Msg) -> Result<(), DirectConnectionError> {
trace!("Sending message");
let msg = bitcode::encode(msg);
self.send_raw(&msg).await
}
} }
impl RecvMessageStream { impl<Msg: DecodeOwned> RecvMessageStream<Msg> {
pub(crate) fn new(inner: RecvStream) -> Self { pub(crate) fn new(inner: RecvStream) -> Self {
Self { inner } Self {
inner,
_phantom: PhantomData,
}
} }
pub(crate) async fn recv(&mut self) -> Result<Vec<u8>, DirectConnectionError> { async fn recv_raw(&mut self) -> Result<Vec<u8>, DirectConnectionError> {
let len = self let len = self
.inner .inner
.read_u32() .read_u32()
.await .await
.map_err(|_err| DirectConnectionError::MessageIoFailed)?; .map_err(|_err| DirectConnectionError::MessageIoFailed)?;
trace!("Expecting message of {len}");
let mut buf = vec![0; len as usize]; let mut buf = vec![0; len as usize];
self.inner self.inner
.read_exact(&mut buf) .read_exact(&mut buf)
@ -53,4 +70,8 @@ impl RecvMessageStream {
.map_err(|_err| DirectConnectionError::MessageIoFailed)?; .map_err(|_err| DirectConnectionError::MessageIoFailed)?;
Ok(buf) Ok(buf)
} }
pub(crate) async fn recv(&mut self) -> Result<Msg, DirectConnectionError> {
let raw = self.recv_raw().await?;
bitcode::decode(&raw).map_err(|_| DirectConnectionError::DecodeError)
}
} }

View file

@ -1,23 +1,11 @@
//! Tangled - a work-in-progress UDP networking crate. //! Tangled - a work-in-progress UDP networking crate.
use std::{ use std::{net::SocketAddr, sync::Arc};
fmt::Display,
io,
net::{SocketAddr, UdpSocket},
sync::{atomic::AtomicBool, Arc},
time::Duration,
};
use connection_manager::{ use connection_manager::{
ConnectionManager, OutboundMessage, RemotePeer, Shared, TangledInitError, ConnectionManager, OutboundMessage, RemotePeer, Shared, TangledInitError,
}; };
use crossbeam::{
self,
atomic::AtomicCell,
channel::{unbounded, Receiver, Sender},
};
use dashmap::DashMap;
pub use error::NetError; pub use error::NetError;
const DATAGRAM_MAX_LEN: usize = 30000; // TODO this probably should be 1500 const DATAGRAM_MAX_LEN: usize = 30000; // TODO this probably should be 1500
@ -31,6 +19,7 @@ mod error;
mod helpers; mod helpers;
pub use common::*; pub use common::*;
use tracing::debug;
/// Represents a network endpoint. Can be constructed in either `host` or `client` mode. /// Represents a network endpoint. Can be constructed in either `host` or `client` mode.
/// ///
@ -46,7 +35,8 @@ impl Peer {
host_addr: Option<SocketAddr>, host_addr: Option<SocketAddr>,
settings: Option<Settings>, settings: Option<Settings>,
) -> Result<Self, TangledInitError> { ) -> Result<Self, TangledInitError> {
let shared = Arc::new(Shared::new(host_addr, settings)); let connection_manager = ConnectionManager::new(host_addr, settings, bind_addr)?;
let shared = connection_manager.shared();
if host_addr.is_none() { if host_addr.is_none() {
shared.remote_peers.insert(PeerId(0), RemotePeer::default()); shared.remote_peers.insert(PeerId(0), RemotePeer::default());
shared shared
@ -55,7 +45,8 @@ impl Peer {
.send(NetworkEvent::PeerConnected(PeerId(0))) .send(NetworkEvent::PeerConnected(PeerId(0)))
.unwrap(); .unwrap();
} }
ConnectionManager::new(Arc::clone(&shared), bind_addr)?.start()?; debug!("Starting connection manager");
connection_manager.start()?;
Ok(Peer { shared }) Ok(Peer { shared })
} }
@ -98,11 +89,15 @@ impl Peer {
if data.len() > MAX_MESSAGE_LEN { if data.len() > MAX_MESSAGE_LEN {
return Err(NetError::MessageTooLong); return Err(NetError::MessageTooLong);
} }
self.shared.outbound_channel.0.send(OutboundMessage { self.shared
dst: destination, .outbound_messages_s
data, .blocking_send(OutboundMessage {
reliability, src: self.my_id().expect("expected to know my_id by this point"),
})?; dst: destination,
data,
reliability,
})
.expect("channel to be open");
Ok(()) Ok(())
} }
@ -150,10 +145,19 @@ impl Drop for Peer {
mod test { mod test {
use std::{thread, time::Duration}; use std::{thread, time::Duration};
use tracing::info;
use crate::{common::Message, NetworkEvent, Peer, PeerId, Reliability, Settings}; use crate::{common::Message, NetworkEvent, Peer, PeerId, Reliability, Settings};
#[test_log::test] #[test_log::test(tokio::test)]
fn test_peer() { async fn test_create_host() {
let addr = "127.0.0.1:55999".parse().unwrap();
let _host = Peer::host(addr, None).unwrap();
}
#[test_log::test(tokio::test)]
async fn test_peer() {
info!("Starting test_peer");
let settings = Some(Settings { let settings = Some(Settings {
confirm_max_period: Duration::from_millis(100), confirm_max_period: Duration::from_millis(100),
connection_timeout: Duration::from_millis(1000), connection_timeout: Duration::from_millis(1000),
@ -163,13 +167,12 @@ mod test {
let host = Peer::host(addr, settings.clone()).unwrap(); let host = Peer::host(addr, settings.clone()).unwrap();
assert_eq!(host.shared.remote_peers.len(), 1); assert_eq!(host.shared.remote_peers.len(), 1);
let peer = Peer::connect(addr, settings.clone()).unwrap(); let peer = Peer::connect(addr, settings.clone()).unwrap();
thread::sleep(Duration::from_millis(100)); tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(peer.shared.remote_peers.len(), 2); assert_eq!(peer.shared.remote_peers.len(), 2);
assert_eq!(host.shared.remote_peers.len(), 2);
let data = vec![128, 51, 32]; let data = vec![128, 51, 32];
peer.send(PeerId(0), data.clone(), Reliability::Reliable) peer.send(PeerId(0), data.clone(), Reliability::Reliable)
.unwrap(); .unwrap();
thread::sleep(Duration::from_millis(10)); tokio::time::sleep(Duration::from_millis(10)).await;
let host_events: Vec<_> = host.recv().collect(); let host_events: Vec<_> = host.recv().collect();
assert!(host_events.contains(&NetworkEvent::PeerConnected(PeerId(1)))); assert!(host_events.contains(&NetworkEvent::PeerConnected(PeerId(1))));
assert!(host_events.contains(&NetworkEvent::Message(Message { assert!(host_events.contains(&NetworkEvent::Message(Message {
@ -180,7 +183,7 @@ mod test {
assert!(peer_events.contains(&NetworkEvent::PeerConnected(PeerId(0)))); assert!(peer_events.contains(&NetworkEvent::PeerConnected(PeerId(0))));
assert!(peer_events.contains(&NetworkEvent::PeerConnected(PeerId(1)))); assert!(peer_events.contains(&NetworkEvent::PeerConnected(PeerId(1))));
drop(peer); drop(peer);
thread::sleep(Duration::from_millis(1200)); tokio::time::sleep(Duration::from_millis(1200)).await;
assert_eq!( assert_eq!(
host.recv().next(), host.recv().next(),
Some(NetworkEvent::PeerDisconnected(PeerId(1))) Some(NetworkEvent::PeerDisconnected(PeerId(1)))
@ -188,8 +191,8 @@ mod test {
assert_eq!(host.shared.remote_peers.len(), 1); assert_eq!(host.shared.remote_peers.len(), 1);
} }
#[test_log::test] #[test_log::test(tokio::test)]
fn test_broadcast() { async fn test_broadcast() {
let settings = Some(Settings { let settings = Some(Settings {
confirm_max_period: Duration::from_millis(100), confirm_max_period: Duration::from_millis(100),
connection_timeout: Duration::from_millis(1000), connection_timeout: Duration::from_millis(1000),
@ -227,8 +230,8 @@ mod test {
}))); })));
} }
#[test_log::test] #[test_log::test(tokio::test)]
fn test_host_has_conn() { async fn test_host_has_conn() {
let settings = Some(Settings { let settings = Some(Settings {
confirm_max_period: Duration::from_millis(100), confirm_max_period: Duration::from_millis(100),
connection_timeout: Duration::from_millis(1000), connection_timeout: Duration::from_millis(1000),