From cf76733005abfe020a9403088a985c7eaff51ba8 Mon Sep 17 00:00:00 2001 From: IQuant Date: Thu, 5 Sep 2024 23:00:12 +0300 Subject: [PATCH] WIP tangled on QUIC --- noita-proxy/Cargo.lock | 106 ++- noita-proxy/tangled/Cargo.toml | 7 +- noita-proxy/tangled/src/common.rs | 103 +++ noita-proxy/tangled/src/connection_manager.rs | 250 +++++++ noita-proxy/tangled/src/helpers.rs | 60 ++ noita-proxy/tangled/src/lib.rs | 125 +--- noita-proxy/tangled/src/reactor.rs | 679 ------------------ noita-proxy/tangled/src/util.rs | 105 --- 8 files changed, 543 insertions(+), 892 deletions(-) create mode 100644 noita-proxy/tangled/src/common.rs create mode 100644 noita-proxy/tangled/src/connection_manager.rs create mode 100644 noita-proxy/tangled/src/helpers.rs delete mode 100644 noita-proxy/tangled/src/reactor.rs delete mode 100644 noita-proxy/tangled/src/util.rs diff --git a/noita-proxy/Cargo.lock b/noita-proxy/Cargo.lock index f7fa99dc..7246e664 100644 --- a/noita-proxy/Cargo.lock +++ b/noita-proxy/Cargo.lock @@ -93,7 +93,7 @@ dependencies = [ "bitflags 2.6.0", "cc", "cesu8", - "jni", + "jni 0.21.1", "jni-sys", "libc", "log", @@ -1703,6 +1703,20 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "jni" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec" +dependencies = [ + "cesu8", + "combine", + "jni-sys", + "log", + "thiserror", + "walkdir", +] + [[package]] name = "jni" version = "0.21.1" @@ -2038,12 +2052,31 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -2302,6 +2335,16 @@ dependencies = [ "hmac", ] +[[package]] +name = "pem" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e459365e590736a54c3fa561947c84837534b8e9af6fc5bf781307e82658fae" +dependencies = [ + "base64 0.22.1", + "serde", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -2474,6 +2517,7 @@ dependencies = [ "ring", "rustc-hash 2.0.0", "rustls", + "rustls-platform-verifier", "slab", "thiserror", "tinyvec", @@ -2544,6 +2588,19 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" +[[package]] +name = "rcgen" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "yasna", +] + [[package]] name = "rctree" version = "0.5.0" @@ -2821,6 +2878,33 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" +[[package]] +name = "rustls-platform-verifier" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afbb878bdfdf63a336a5e63561b1835e7a8c91524f51621db870169eac84b490" +dependencies = [ + "core-foundation", + "core-foundation-sys", + "jni 0.19.0", + "log", + "once_cell", + "rustls", + "rustls-native-certs 0.7.3", + "rustls-platform-verifier-android", + "rustls-webpki", + "security-framework", + "security-framework-sys", + "webpki-roots", + "winapi", +] + +[[package]] +name = "rustls-platform-verifier-android" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" + [[package]] name = "rustls-webpki" version = "0.102.7" @@ -2878,6 +2962,7 @@ dependencies = [ "core-foundation", "core-foundation-sys", "libc", + "num-bigint", "security-framework-sys", ] @@ -3215,13 +3300,18 @@ dependencies = [ [[package]] name = "tangled" -version = "0.2.0" +version = "0.3.0" dependencies = [ "bincode", "crossbeam", "dashmap", + "num-bigint", + "quinn", + "rcgen", "serde", "test-log", + "thiserror", + "tokio", "tracing", "tracing-subscriber", ] @@ -3476,6 +3566,7 @@ version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -4006,7 +4097,7 @@ dependencies = [ "block2 0.5.1", "core-foundation", "home", - "jni", + "jni 0.21.1", "log", "ndk-context", "objc2 0.5.2", @@ -4488,6 +4579,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec7a2a501ed189703dba8b08142f057e887dfc4b2cc4db2d343ac6376ba3e0b9" +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "zerocopy" version = "0.7.35" diff --git a/noita-proxy/tangled/Cargo.toml b/noita-proxy/tangled/Cargo.toml index 8e807359..70f8c292 100644 --- a/noita-proxy/tangled/Cargo.toml +++ b/noita-proxy/tangled/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tangled" -version = "0.2.0" +version = "0.3.0" edition = "2021" license = "MIT OR Apache-2.0" repository = "https://github.com/IntQuant/tangled" @@ -18,6 +18,11 @@ tracing = "0.1.36" dashmap = "6.0.1" serde = {features = ["derive"], version = "1.0.142"} bincode = "1.3.3" +quinn = "0.11.5" +num-bigint = "0.4.6" +rcgen = "0.13.1" +thiserror = "1.0.63" +tokio = "1.40.0" [dev-dependencies] test-log = { version = "0.2.11", default-features = false, features = ["trace"]} diff --git a/noita-proxy/tangled/src/common.rs b/noita-proxy/tangled/src/common.rs new file mode 100644 index 00000000..79036292 --- /dev/null +++ b/noita-proxy/tangled/src/common.rs @@ -0,0 +1,103 @@ +//! Various common public types. + +use std::{fmt::Display, time::Duration}; + +use serde::{Deserialize, Serialize}; + +/// Per-peer settings. Peers that are connected to the same host, as well as the host itself, should have the same settings. +#[derive(Debug, Clone)] +pub struct Settings { + /// A single datagram will confirm at most this much messages. Default is 128. + pub confirm_max_per_message: usize, + /// How much time can elapse before another confirm is sent. + /// Confirms are also sent when enough messages are awaiting confirm. + /// Note that confirms also double as "heartbeats" and keep the connection alive, so this value should be much less than `connection_timeout`. + /// Default: 1 second. + pub confirm_max_period: Duration, + /// Peers will be disconnected after this much time without any datagrams from them has passed. + /// Default: 10 seconds. + pub connection_timeout: Duration, +} + +/// Tells how reliable a message is. +#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Debug)] +pub enum Reliability { + /// Message will be delivered at most once. + Unreliable, + /// Message will be resent untill is's arrival will be confirmed. + /// Will be delivered at most once. + Reliable, +} + +pub enum Destination { + One(PeerId), + Broadcast, +} + +/// A value which refers to a specific peer. +/// Peer 0 is always the host. +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] +pub struct PeerId(pub u16); + +/// Possible network events, returned by `Peer.recv()`. +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum NetworkEvent { + /// A new peer has connected. + PeerConnected(PeerId), + /// Peer has disconnected. + PeerDisconnected(PeerId), + /// Message has been received. + Message(Message), +} + +/// A message received from a peer. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Message { + /// Original peer who sent the message. + pub src: PeerId, + /// The data that has been sent. + pub data: Vec, +} + +/// Current peer state +#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)] +pub enum PeerState { + /// Waiting for connection. Switches to `Connected` right after id from the host has been acquired. + /// Note: hosts switches to 'Connected' basically instantly. + #[default] + PendingConnection, + /// Connected to host and ready to send/receive messages. + Connected, + /// No longer connected, won't reconnect. + Disconnected, +} + +impl Display for PeerState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PeerState::PendingConnection => write!(f, "Connection pending..."), + PeerState::Connected => write!(f, "Connected"), + PeerState::Disconnected => write!(f, "Disconnected"), + } + } +} + +impl PeerId { + pub const HOST: PeerId = PeerId(0); +} + +impl Display for PeerId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl Default for Settings { + fn default() -> Self { + Self { + confirm_max_per_message: 128, + confirm_max_period: Duration::from_secs(1), + connection_timeout: Duration::from_secs(10), + } + } +} diff --git a/noita-proxy/tangled/src/connection_manager.rs b/noita-proxy/tangled/src/connection_manager.rs new file mode 100644 index 00000000..4db3bb8d --- /dev/null +++ b/noita-proxy/tangled/src/connection_manager.rs @@ -0,0 +1,250 @@ +use std::{ + io, + net::SocketAddr, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; + +use crossbeam::{ + atomic::AtomicCell, + channel::{unbounded, Receiver, Sender}, +}; +use dashmap::DashMap; +use quinn::{ + crypto::rustls::QuicClientConfig, + rustls::{ + self, + pki_types::{CertificateDer, PrivatePkcs8KeyDer}, + }, + ClientConfig, ConnectError, Connecting, ConnectionError, Endpoint, Incoming, ServerConfig, +}; +use thiserror::Error; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tracing::{error, info, warn}; + +use crate::{ + common::{Destination, NetworkEvent, PeerId, PeerState, Reliability, Settings}, + helpers::SkipServerVerification, +}; + +#[derive(Default)] +pub(crate) struct RemotePeer; + +#[derive(Debug, Error)] +enum DirectConnectionError { + #[error("QUIC Connection error: {0}")] + QUICConnectionError(#[from] ConnectionError), + #[error("Initial exchange failed")] + InitialExchangeFailed, +} + +struct DirectPeer { + my_id: PeerId, + remote_id: PeerId, + streams: (quinn::SendStream, quinn::RecvStream), +} + +impl DirectPeer { + async fn accept( + incoming: Incoming, + assigned_peer_id: PeerId, + ) -> Result { + let connection = incoming + .await + .inspect_err(|err| warn!("Failed to accept connection: {err}"))?; + + let mut sender = connection + .open_uni() + .await + .inspect_err(|err| warn!("Failed to get send stream: {err}"))?; + sender + .write_u16(assigned_peer_id.0) + .await + .map_err(|_err| DirectConnectionError::InitialExchangeFailed)?; + + let streams = connection.open_bi().await?; + + Ok(Self { + my_id: PeerId::HOST, + remote_id: assigned_peer_id, + streams, + }) + } + + async fn connect(connection: Connecting) -> Result { + let connection = connection + .await + .inspect_err(|err| warn!("Failed to initiate connection: {err}"))?; + + let mut receiver = connection.accept_uni().await?; + let peer_id = receiver + .read_u16() + .await + .map_err(|_err| DirectConnectionError::InitialExchangeFailed)?; + + let streams = connection.accept_bi().await?; + + Ok(Self { + my_id: PeerId(peer_id), + remote_id: PeerId::HOST, + streams, + }) + } +} + +type SeqId = u16; + +pub(crate) struct OutboundMessage { + pub dst: Destination, + pub reliability: Reliability, + pub data: Vec, +} + +pub(crate) type Channel = (Sender, Receiver); + +#[derive(Debug, Error)] +pub enum TangledInitError { + #[error("Could not create endpoint.\nReason: {0}")] + CouldNotCreateEndpoint(io::Error), + #[error("Could not connect to host.\nReason: {0}")] + CouldNotConnectToHost(ConnectError), +} + +pub(crate) struct Shared { + pub settings: Settings, + pub inbound_channel: Channel, + pub outbound_channel: Channel, + pub keep_alive: AtomicBool, + pub peer_state: AtomicCell, + pub remote_peers: DashMap, + pub host_addr: Option, + pub my_id: AtomicCell>, + // ConnectionManager-specific stuff + direct_peers: DashMap, +} + +impl Shared { + pub(crate) fn new(host_addr: Option, settings: Option) -> Self { + Self { + inbound_channel: unbounded(), + outbound_channel: unbounded(), + keep_alive: AtomicBool::new(true), + host_addr, + peer_state: Default::default(), + remote_peers: Default::default(), + my_id: AtomicCell::new(if host_addr.is_none() { + Some(PeerId(0)) + } else { + None + }), + settings: settings.unwrap_or_default(), + direct_peers: DashMap::default(), + } + } +} + +pub(crate) struct ConnectionManager { + shared: Arc, + endpoint: Endpoint, + host_conn: Option, + is_server: bool, +} + +impl ConnectionManager { + pub(crate) fn new(shared: Arc, addr: SocketAddr) -> Result { + let is_server = shared.host_addr.is_none(); + + let config = default_server_config(); + + let mut endpoint = if is_server { + Endpoint::server(config, addr).map_err(TangledInitError::CouldNotCreateEndpoint)? + } else { + Endpoint::client(addr).map_err(TangledInitError::CouldNotCreateEndpoint)? + }; + + endpoint.set_default_client_config(ClientConfig::new(Arc::new( + QuicClientConfig::try_from( + rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(SkipServerVerification::new()) + .with_no_client_auth(), + ) + .unwrap(), + ))); + + Ok(Self { + shared, + is_server, + endpoint, + host_conn: None, + }) + } + + async fn accept_connections(shared: Arc, endpoint: Endpoint) { + let mut peer_id_counter = 1; + while shared.keep_alive.load(Ordering::Relaxed) { + let Some(incoming) = endpoint.accept().await else { + info!("Endpoint closed, stopping connection accepter task."); + return; + }; + match DirectPeer::accept(incoming, PeerId(peer_id_counter)).await { + Ok(direct_peer) => { + shared + .direct_peers + .insert(PeerId(peer_id_counter), direct_peer); + peer_id_counter += 1; + } + Err(err) => { + warn!("Failed to accept connection: {err}") + } + }; + } + } + + async fn astart(mut self, host_conn: Option) { + if let Some(host_conn) = host_conn { + match DirectPeer::connect(host_conn).await { + Ok(host_conn) => { + self.host_conn = Some(host_conn); + } + Err(err) => { + error!("Could not connect to host: {}", err); + self.shared.peer_state.store(PeerState::Disconnected); + return; + } + } + } + if self.is_server { + let endpoint = self.endpoint.clone(); + tokio::spawn(Self::accept_connections(self.shared.clone(), endpoint)); + info!("Started connection acceptor task"); + } + } + + pub(crate) fn start(self) -> Result<(), TangledInitError> { + let host_conn = self + .shared + .host_addr + .as_ref() + .map(|host_addr| { + self.endpoint + .connect(*host_addr, "tangled") + .map_err(TangledInitError::CouldNotConnectToHost) + }) + .transpose()?; + + tokio::spawn(self.astart(host_conn)); + Ok(()) + } +} + +fn default_server_config() -> ServerConfig { + let cert = rcgen::generate_simple_self_signed(vec!["tangled".into()]).unwrap(); + let cert_der = CertificateDer::from(cert.cert); + let priv_key = PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()); + + let config = ServerConfig::with_single_cert(vec![cert_der.clone()], priv_key.into()).unwrap(); + config +} diff --git a/noita-proxy/tangled/src/helpers.rs b/noita-proxy/tangled/src/helpers.rs new file mode 100644 index 00000000..363ffc3f --- /dev/null +++ b/noita-proxy/tangled/src/helpers.rs @@ -0,0 +1,60 @@ +use std::sync::Arc; + +use quinn::rustls::{ + self, + pki_types::{CertificateDer, ServerName, UnixTime}, +}; + +#[derive(Debug)] +pub(crate) struct SkipServerVerification(Arc); + +impl SkipServerVerification { + pub(crate) fn new() -> Arc { + Arc::new(Self(Arc::new(rustls::crypto::ring::default_provider()))) + } +} + +impl rustls::client::danger::ServerCertVerifier for SkipServerVerification { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp: &[u8], + _now: UnixTime, + ) -> Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls12_signature( + message, + cert, + dss, + &self.0.signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls13_signature( + message, + cert, + dss, + &self.0.signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec { + self.0.signature_verification_algorithms.supported_schemes() + } +} diff --git a/noita-proxy/tangled/src/lib.rs b/noita-proxy/tangled/src/lib.rs index 8ee997c1..2cc24707 100644 --- a/noita-proxy/tangled/src/lib.rs +++ b/noita-proxy/tangled/src/lib.rs @@ -5,100 +5,32 @@ use std::{ io, net::{SocketAddr, UdpSocket}, sync::{atomic::AtomicBool, Arc}, + time::Duration, }; +use connection_manager::{ + ConnectionManager, OutboundMessage, RemotePeer, Shared, TangledInitError, +}; use crossbeam::{ self, atomic::AtomicCell, channel::{unbounded, Receiver, Sender}, }; +use dashmap::DashMap; pub use error::NetError; -use reactor::{Destination, RemotePeer, Shared}; -pub use reactor::{Reliability, Settings}; -use serde::{Deserialize, Serialize}; const DATAGRAM_MAX_LEN: usize = 30000; // TODO this probably should be 1500 /// Maximum size of a message which fits into a single datagram. pub const MAX_MESSAGE_LEN: usize = DATAGRAM_MAX_LEN - 100; +mod common; +mod connection_manager; mod error; -mod reactor; -mod util; +mod helpers; -struct Datagram { - pub size: usize, - pub data: [u8; DATAGRAM_MAX_LEN], -} - -/// A value which refers to a specific peer. -/// Peer 0 is always the host. -#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] -pub struct PeerId(pub u16); - -impl PeerId { - pub const HOST: PeerId = PeerId(0); -} - -impl Display for PeerId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -type SeqId = u16; - -/// Possible network events, returned by `Peer.recv()`. -#[derive(Debug, PartialEq, Eq, Clone)] -pub enum NetworkEvent { - /// A new peer has connected. - PeerConnected(PeerId), - /// Peer has disconnected. - PeerDisconnected(PeerId), - /// Message has been received. - Message(Message), -} - -/// A message received from a peer. -#[derive(Debug, PartialEq, Eq, Clone)] -pub struct Message { - /// Original peer who sent the message. - pub src: PeerId, - /// The data that has been sent. - pub data: Vec, -} - -struct OutboundMessage { - pub dst: Destination, - pub data: Vec, - pub reliability: Reliability, -} - -/// Current peer state -#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)] -pub enum PeerState { - /// Waiting for connection. Switches to `Connected` right after id from the host has been acquired. - /// Note: hosts switches to 'Connected' basically instantly. - #[default] - PendingConnection, - /// Connected to host and ready to send/receive messages. - Connected, - /// No longer connected, won't reconnect. - Disconnected, -} - -impl Display for PeerState { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - PeerState::PendingConnection => write!(f, "Connection pending..."), - PeerState::Connected => write!(f, "Connected"), - PeerState::Disconnected => write!(f, "Disconnected"), - } - } -} - -type Channel = (Sender, Receiver); +pub use common::*; /// Represents a network endpoint. Can be constructed in either `host` or `client` mode. /// @@ -113,24 +45,8 @@ impl Peer { bind_addr: SocketAddr, host_addr: Option, settings: Option, - ) -> io::Result { - let socket = UdpSocket::bind(bind_addr)?; - let shared = Arc::new(Shared { - socket, - inbound_channel: unbounded(), - outbound_channel: unbounded(), - keep_alive: AtomicBool::new(true), - host_addr, - peer_state: Default::default(), - remote_peers: Default::default(), - max_packets_per_second: 256, - my_id: AtomicCell::new(if host_addr.is_none() { - Some(PeerId(0)) - } else { - None - }), - settings: settings.unwrap_or_default(), - }); + ) -> Result { + let shared = Arc::new(Shared::new(host_addr, settings)); if host_addr.is_none() { shared.remote_peers.insert(PeerId(0), RemotePeer::default()); shared @@ -139,17 +55,23 @@ impl Peer { .send(NetworkEvent::PeerConnected(PeerId(0))) .unwrap(); } - reactor::Reactor::start(Arc::clone(&shared)); + ConnectionManager::new(Arc::clone(&shared), bind_addr)?.start()?; Ok(Peer { shared }) } /// Host at a specified `bind_addr`. - pub fn host(bind_addr: SocketAddr, settings: Option) -> io::Result { + pub fn host( + bind_addr: SocketAddr, + settings: Option, + ) -> Result { Self::new(bind_addr, None, settings) } /// Connect to a specified `host_addr`. - pub fn connect(host_addr: SocketAddr, settings: Option) -> io::Result { + pub fn connect( + host_addr: SocketAddr, + settings: Option, + ) -> Result { Self::new("0.0.0.0:0".parse().unwrap(), Some(host_addr), settings) } @@ -176,11 +98,6 @@ impl Peer { if data.len() > MAX_MESSAGE_LEN { return Err(NetError::MessageTooLong); } - if reliability == Reliability::Unreliable - && self.shared.outbound_channel.0.len() * 2 > self.shared.max_packets_per_second - { - return Err(NetError::Dropped); - } self.shared.outbound_channel.0.send(OutboundMessage { dst: destination, data, @@ -233,7 +150,7 @@ impl Drop for Peer { mod test { use std::{thread, time::Duration}; - use crate::{reactor::Settings, Message, NetworkEvent, Peer, PeerId, Reliability}; + use crate::{common::Message, NetworkEvent, Peer, PeerId, Reliability, Settings}; #[test_log::test] fn test_peer() { diff --git a/noita-proxy/tangled/src/reactor.rs b/noita-proxy/tangled/src/reactor.rs deleted file mode 100644 index 455d2913..00000000 --- a/noita-proxy/tangled/src/reactor.rs +++ /dev/null @@ -1,679 +0,0 @@ -use crate::{ - error::NetError, - util::{RateLimiter, RingSet}, - Channel, Message, NetworkEvent, OutboundMessage, PeerId, SeqId, -}; - -use super::{Datagram, PeerState, DATAGRAM_MAX_LEN}; -use crossbeam::{ - atomic::AtomicCell, - channel::{bounded, Receiver, Sender}, - select, -}; - -use dashmap::DashMap; -use serde::{Deserialize, Serialize}; -use std::{ - collections::{HashMap, VecDeque}, - error::Error, - io::Cursor, - net::{SocketAddr, UdpSocket}, - sync::{ - atomic::{AtomicBool, AtomicU16, Ordering::SeqCst}, - Arc, - }, - thread, - time::{Duration, Instant}, -}; -use tracing::{error, info, trace, warn}; - -/// Per-peer settings. Peers that are connected to the same host, as well as the host itself, should have the same settings. -#[derive(Debug, Clone)] -pub struct Settings { - /// A single datagram will confirm at most this much messages. Default is 128. - pub confirm_max_per_message: usize, - /// How much time can elapse before another confirm is sent. - /// Confirms are also sent when enough messages are awaiting confirm. - /// Note that confirms also double as "heartbeats" and keep the connection alive, so this value should be much less than `connection_timeout`. - /// Default: 1 second. - pub confirm_max_period: Duration, - /// Peers will be disconnected after this much time without any datagrams from them has passed. - /// Default: 10 seconds. - pub connection_timeout: Duration, -} - -impl Default for Settings { - fn default() -> Self { - Self { - confirm_max_per_message: 128, - confirm_max_period: Duration::from_secs(1), - connection_timeout: Duration::from_secs(10), - } - } -} - -pub(crate) struct Shared { - pub settings: Settings, - pub socket: UdpSocket, - pub inbound_channel: Channel, - pub outbound_channel: Channel, - pub keep_alive: AtomicBool, - pub peer_state: AtomicCell, - pub remote_peers: DashMap, - pub max_packets_per_second: usize, - pub host_addr: Option, - pub my_id: AtomicCell>, -} - -struct DirectPeer { - addr: SocketAddr, - outbound_pending: VecDeque, - resend_pending: VecDeque<(Instant, NetMessageNormal)>, - confirmed: RingSet, - rate_limit: RateLimiter, - seq_counter: AtomicU16, - recent_seq: RingSet, - pending_confirms: VecDeque, - last_confirm_sent: Instant, - last_seen: Instant, -} - -#[derive(Default)] -pub struct RemotePeer {} - -#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)] -pub enum Destination { - One(PeerId), - Broadcast, -} - -#[derive(Serialize, Deserialize, Clone)] -enum NetMessageVariant { - Login, - Normal(NetMessageNormal), -} - -/// Tells how reliable a message is. -#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Debug)] -pub enum Reliability { - /// Message will be delivered at most once. - Unreliable, - /// Message will be resent untill is's arrival will be confirmed. - /// Will be delivered at most once. - Reliable, -} - -#[derive(Serialize, Deserialize, Clone, Debug)] -struct NetMessageNormal { - // Source that generated sequence id. - // Initially the same as origin_src, but can be changed when packet is retransmitted not as-is, e. g. when it is broadcasted. - src: PeerId, - // Original source. - origin_src: PeerId, - dst: Destination, - seq_id: SeqId, - reliability: Reliability, - inner: NetMessageInner, -} - -#[derive(Serialize, Deserialize, Clone, Debug)] -enum NetMessageInner { - RegDone { addr: SocketAddr }, - AddPeer { id: PeerId }, - DelPeer { id: PeerId }, - Confirm { confirmed_ids: Vec }, - Payload { data: Vec }, -} - -impl TryFrom for NetMessageVariant { - type Error = bincode::Error; - - fn try_from(datagram: Datagram) -> Result { - bincode::deserialize(&datagram.data[..datagram.size]) - } -} - -impl TryFrom<&NetMessageVariant> for Datagram { - type Error = bincode::Error; - - fn try_from(value: &NetMessageVariant) -> Result { - let mut data = Cursor::new([0; DATAGRAM_MAX_LEN]); - bincode::serialize_into(&mut data, value)?; - let size = data.position().try_into().unwrap(); - let data = data.into_inner(); - Ok(Datagram { data, size }) - } -} - -pub(crate) struct Reactor { - shared: Arc, - direct_peers: HashMap, -} - -type AddrDatagram = (SocketAddr, Datagram); - -impl Reactor { - fn add_peer(&self, id: PeerId) -> Result<(), NetError> { - self.shared.remote_peers.insert(id, RemotePeer::default()); - self.shared - .inbound_channel - .0 - .send(NetworkEvent::PeerConnected(id))?; - Ok(()) - } - - fn direct_broadcast( - &mut self, - src_id: PeerId, - msg: NetMessageInner, - reliability: Reliability, - ) -> Result<(), NetError> { - for (&peer_id, peer) in self.direct_peers.iter_mut() { - let new_seq_id = peer.seq_counter.fetch_add(1, SeqCst); - let new_msg = Self::wrap_packet_seq_id( - src_id, - src_id, - new_seq_id, - Destination::One(peer_id), - msg.clone(), - reliability, - )?; - Self::direct_send_peer(peer, new_msg)?; - } - Ok(()) - } - - fn direct_send(&mut self, id: PeerId, msg: NetMessageVariant) -> Result<(), NetError> { - let peer = self - .direct_peers - .get_mut(&id) - .ok_or(NetError::UnknownPeer)?; - Self::direct_send_peer(peer, msg) - } - - fn direct_send_peer(peer: &mut DirectPeer, msg: NetMessageVariant) -> Result<(), NetError> { - peer.outbound_pending.push_back(msg); - Ok(()) - } - - fn gen_peer_id(&mut self) -> Option { - (1..=u16::MAX) - .map(PeerId) - .find(|i| !self.shared.remote_peers.contains_key(i)) - } - - fn handle_inbound(&mut self, (incoming_addr, msg_raw): AddrDatagram) { - let msg = match NetMessageVariant::try_from(msg_raw) { - Ok(msg) => msg, - Err(err) => { - warn!("Error when converting to NetMessage: {}", err); - return; - } - }; - match self.shared.my_id.load() { - Some(id) => { - match msg { - NetMessageVariant::Login => { - if self.is_host() { - //TODO check this addr is not already registered - match self.gen_peer_id() { - Some(new_id) => { - self.add_peer(new_id).ok(); - let mut peer = DirectPeer::new( - incoming_addr, - self.shared.max_packets_per_second, - ); - peer.outbound_pending.push_back(NetMessageVariant::Normal( - NetMessageNormal { - src: id, - origin_src: id, - dst: Destination::One(new_id), - seq_id: u16::MAX, - inner: NetMessageInner::RegDone { - addr: incoming_addr, - }, - reliability: Reliability::Reliable, - }, - )); - self.direct_peers.insert(new_id, peer); - self.direct_broadcast( - id, - NetMessageInner::AddPeer { id: new_id }, - Reliability::Reliable, - ) - .ok(); - let shared = self.shared.clone(); - for re in shared.remote_peers.iter() { - let id = *re.key(); - if id != new_id { - self.wrap_packet( - id, - Destination::One(new_id), - NetMessageInner::AddPeer { id }, - Reliability::Reliable, - ) - .and_then(|msg| self.direct_send(new_id, msg)) - .ok(); - } - } - } - None => warn!("Out of ids"), - } - } else { - warn!("Not a host, registration attempt ignored"); - } - } - NetMessageVariant::Normal(msg) => { - match self.handle_inbound_normal(msg, incoming_addr, id) { - Ok(_) => {} - Err(NetError::Dropped) => {} - Err(err) => { - info!("Error while handling normal inbound message: {}", err) - } - } - } - } - } - None => match msg { - NetMessageVariant::Normal(NetMessageNormal { - inner: NetMessageInner::RegDone { addr: _ }, - dst, - src, - .. - }) => { - let expected_host_addr = self - .shared - .host_addr - .expect("Can't have both my_id and host_addr be None"); - if incoming_addr == expected_host_addr && src == PeerId(0) { - if let Destination::One(id) = dst { - self.shared.my_id.store(Some(id)); - self.add_peer(PeerId(0)).ok(); - self.shared.peer_state.store(PeerState::Connected); - } else { - warn!("Malformed registration message"); - } - } else { - warn!("Registration message recieved not from the right address ({}, {} expected)", incoming_addr, expected_host_addr); - } - } - _ => warn!("Message ignored as registration is not done yet"), - }, - } - } - - fn handle_inbound_normal( - &mut self, - msg: NetMessageNormal, - _incoming_addr: SocketAddr, - my_id: PeerId, - ) -> Result<(), NetError> { - let peer = self.direct_peers.get_mut(&msg.src); - if peer - .as_ref() - .map_or(true, |peer| peer.recent_seq.contains(&msg.seq_id)) - { - return Err(NetError::Dropped); - } - { - let peer = peer.expect("Expected to exist"); - peer.recent_seq.add(msg.seq_id); //TODO backpressure - peer.pending_confirms.push_back(msg.seq_id); - peer.last_seen = Instant::now() - } - - if Destination::One(my_id) == msg.dst || msg.dst == Destination::Broadcast { - // TODO eliminate this clone - match msg.inner.clone() { - NetMessageInner::RegDone { addr: _ } => { - warn!("Already registered, request ignored"); - } - NetMessageInner::AddPeer { id } => { - if !self.is_host() { - self.add_peer(id).ok(); - info!("Peer {} added", id); - } - } - NetMessageInner::DelPeer { id } => { - if !self.is_host() { - self.del_peer(id).ok(); - info!("Peer {} removed", id); - } - } - NetMessageInner::Confirm { confirmed_ids } => { - if let Some(peer) = self.direct_peers.get_mut(&msg.src) { - for id in confirmed_ids { - peer.confirmed.add(id); - } - } - } - NetMessageInner::Payload { data } => { - self.shared - .inbound_channel - .0 - .send(NetworkEvent::Message(Message { - src: msg.origin_src, - data, - }))?; - } - } - } - if self.is_host() && Destination::One(my_id) != msg.dst { - match msg.dst { - Destination::One(dst) => { - let new_msg = - self.wrap_packet(dst, Destination::One(dst), msg.inner, msg.reliability)?; - self.direct_send(dst, new_msg)?; - } - Destination::Broadcast => { - let mut buf = Vec::new(); - for peer in &self.direct_peers { - if *peer.0 == msg.src { - continue; - } - let seq_id = self.next_seq_id_for_peer(*peer.0)?; - if let Ok(wrapped_msg) = Self::wrap_packet_seq_id( - PeerId(0), - msg.origin_src, - seq_id, - Destination::One(*peer.0), - msg.inner.clone(), - msg.reliability, - ) { - buf.push((*peer.0, wrapped_msg)); - } - } - for (peer_id, wrapped_msg) in buf { - self.direct_send(peer_id, wrapped_msg).ok(); - } - } - } - } - - Ok(()) - } - - fn del_peer(&mut self, id: PeerId) -> Result<(), NetError> { - self.shared.remote_peers.remove(&id); - self.shared - .inbound_channel - .0 - .send(NetworkEvent::PeerDisconnected(id))?; - Ok(()) - } - - fn handle_outbound(&mut self, msg: OutboundMessage) -> Result<(), NetError> { - let dst = msg.dst; - if self.is_host() { - match dst { - Destination::One(id) => { - let net_msg = self.wrap_packet( - id, - dst, - NetMessageInner::Payload { data: msg.data }, - msg.reliability, - )?; - self.direct_send(id, net_msg)?; - } - Destination::Broadcast => self.direct_broadcast( - PeerId(0), - NetMessageInner::Payload { data: msg.data }, - msg.reliability, - )?, - } - } else { - let net_msg = self.wrap_packet( - PeerId(0), - dst, - NetMessageInner::Payload { data: msg.data }, - msg.reliability, - )?; - self.direct_send(PeerId(0), net_msg)?; - } - Ok(()) - } - - pub fn is_host(&self) -> bool { - self.shared.host_addr.is_none() - } - - pub fn next_seq_id_for_peer(&self, peer_id: PeerId) -> Result { - Ok(self - .direct_peers - .get(&peer_id) - .or_else(|| { - if !self.is_host() { - self.direct_peers.get(&PeerId(0)) - } else { - None - } - }) - .ok_or(NetError::UnknownPeer)? - .seq_counter - .fetch_add(1, SeqCst)) - } - - fn run(mut self, inbound_r: Receiver) -> Result<(), Box> { - while self.shared.keep_alive.load(SeqCst) { - select! { - recv(inbound_r) -> addr_msg => self.handle_inbound(addr_msg?), - recv(self.shared.outbound_channel.1) -> msg => {self.handle_outbound(msg?).ok();} - default => {thread::sleep(Duration::from_micros(100));} - } - let mut dc = Vec::new(); - self.direct_peers.retain(|&k, v| { - let stays = v.last_seen.elapsed() < self.shared.settings.connection_timeout; - if !stays { - dc.push(k); - } - stays - }); - if self.is_host() { - for peer_id in dc { - let src_id = self.shared.my_id.load().unwrap(); // Should always be PeerId(0) - assert_eq!(src_id, PeerId(0)); - self.direct_broadcast( - src_id, - NetMessageInner::DelPeer { id: peer_id }, - Reliability::Reliable, - )?; - self.del_peer(peer_id).ok(); - info!("[Host] Peer {} removed", peer_id); - } - } - if !self.is_host() && self.direct_peers.is_empty() { - self.shared.peer_state.store(PeerState::Disconnected); - self.shared.keep_alive.store(false, SeqCst); - } - 'peers: for (&id, peer) in self.direct_peers.iter_mut() { - let resend_in = Instant::now() + Duration::from_secs(1); - - if let Some(my_id) = self.shared.my_id.load() { - if peer.last_confirm_sent.elapsed() > self.shared.settings.confirm_max_period - || peer.pending_confirms.len() - > self.shared.settings.confirm_max_per_message - { - peer.last_confirm_sent = Instant::now(); - let max_per_message = self.shared.settings.confirm_max_per_message; - let mut confirmed_ids = Vec::with_capacity(max_per_message); - while let Some(confirm) = peer.pending_confirms.pop_front() { - confirmed_ids.push(confirm); - if confirmed_ids.len() == max_per_message { - break; - } - } - peer.resend_pending.push_front(( - Instant::now(), - NetMessageNormal { - src: my_id, - origin_src: my_id, - dst: Destination::One(id), - seq_id: peer.seq_counter.fetch_add(1, SeqCst), - reliability: Reliability::Reliable, - inner: NetMessageInner::Confirm { confirmed_ids }, - }, - )) - } - } - - while peer - .resend_pending - .front() - .map_or(false, |x| x.0 < Instant::now()) - { - let (moment, msg) = peer - .resend_pending - .pop_front() - .expect("Checked that deque is not empty"); - - if !peer.confirmed.contains(&msg.seq_id) { - if !peer.rate_limit.get_token() { - peer.resend_pending.push_front((moment, msg)); - continue 'peers; - } - peer.resend_pending.push_back((resend_in, msg.clone())); - trace!("Sent {:?} to {}", msg, peer.addr,); - let datagram = Datagram::try_from(&NetMessageVariant::Normal(msg)).unwrap(); - trace!("size: {}", datagram.size); - self.shared - .socket - .send_to(&datagram.data[..datagram.size], peer.addr) - .expect("Could not send"); - } - } - - while !peer.outbound_pending.is_empty() && peer.rate_limit.get_token() { - let msg = peer - .outbound_pending - .pop_front() - .expect("Checked that deque is not empty"); - if let NetMessageVariant::Normal(ref msg) = msg { - if msg.reliability == Reliability::Reliable { - peer.resend_pending.push_back((resend_in, msg.clone())); - } - } - let datagram = Datagram::try_from(&msg).unwrap(); - trace!("sent msg size: {}", datagram.size); - self.shared - .socket - .send_to(&datagram.data[..datagram.size], peer.addr) - .expect("Could not send"); - } - } - } - Ok(()) - } - - fn run_pipe( - shared: Arc, - sender: Sender<(SocketAddr, Datagram)>, - ) -> Result<(), Box> { - while shared.keep_alive.load(SeqCst) { - let mut buf = [0u8; DATAGRAM_MAX_LEN]; - match shared.socket.recv_from(&mut buf) { - Ok((len, addr)) => sender - .send(( - addr, - Datagram { - size: len, - data: buf, - }, - )) - .map_err(Box::new)?, - //Err(err) - // if err.kind() == ErrorKind::WouldBlock || err.kind() == ErrorKind::TimedOut => { - //} - Err(err) => return Err(Box::new(err)), - } - } - Ok(()) - } - - pub(crate) fn start(shared: Arc) { - let mut me = Reactor { - shared, - direct_peers: Default::default(), - }; - if !me.is_host() { - me.direct_peers.insert( - PeerId(0), - DirectPeer::new( - me.shared - .host_addr - .expect("Can't be a client without a host addr"), - me.shared.max_packets_per_second, - ), - ); - me.direct_send(PeerId(0), NetMessageVariant::Login).unwrap(); - } - if me.is_host() { - me.shared.peer_state.store(PeerState::Connected); - } - let shared_c = Arc::clone(&me.shared); - let (inbound_s, inbound_r) = bounded(16); - thread::spawn(move || { - let shared_c_2 = Arc::clone(&shared_c); - if let Err(err) = Self::run_pipe(shared_c_2, inbound_s) { - shared_c.keep_alive.store(false, SeqCst); - shared_c.peer_state.store(PeerState::Disconnected); - error!("Reactor pipe error: {}", err); - } - }); - let shared_c = Arc::clone(&me.shared); - thread::spawn(move || { - if let Err(err) = me.run(inbound_r) { - shared_c.keep_alive.store(false, SeqCst); - shared_c.peer_state.store(PeerState::Disconnected); - error!("Reactor error: {}", err); - } - }); - } - - fn wrap_packet( - &self, - id: PeerId, - dst: Destination, - msg: NetMessageInner, - reliability: Reliability, - ) -> Result { - let seq_id = self.next_seq_id_for_peer(id)?; - let src = self.shared.my_id.load().expect("Should know own id by now"); - Self::wrap_packet_seq_id(src, src, seq_id, dst, msg, reliability) - } - - fn wrap_packet_seq_id( - src: PeerId, - origin_src: PeerId, - seq_id: SeqId, - dst: Destination, - msg: NetMessageInner, - reliability: Reliability, - ) -> Result { - Ok(NetMessageVariant::Normal(NetMessageNormal { - src, - origin_src, - dst, - seq_id, - inner: msg, - reliability, - })) - } -} - -impl DirectPeer { - fn new(incoming_addr: SocketAddr, rate_limit: usize) -> DirectPeer { - let now = Instant::now(); - DirectPeer { - addr: incoming_addr, - outbound_pending: Default::default(), - resend_pending: Default::default(), - confirmed: RingSet::new(1024), - rate_limit: RateLimiter::new(rate_limit, Duration::from_secs(1)), - seq_counter: AtomicU16::new(0), - recent_seq: RingSet::new(1024), - pending_confirms: VecDeque::new(), - last_confirm_sent: now, - last_seen: now, - } - } -} diff --git a/noita-proxy/tangled/src/util.rs b/noita-proxy/tangled/src/util.rs deleted file mode 100644 index 37c3e8c3..00000000 --- a/noita-proxy/tangled/src/util.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::{ - collections::{HashSet, VecDeque}, - hash::Hash, - time::{Duration, Instant}, -}; - -pub struct RateLimiter { - moments: VecDeque, - time: Duration, - limit: usize, -} - -impl RateLimiter { - pub fn new(limit: usize, time: Duration) -> Self { - Self { - moments: VecDeque::with_capacity(limit), - time, - limit, - } - } - pub fn get_token(&mut self) -> bool { - let now = Instant::now(); - while self - .moments - .front() - .map_or(false, |moment| now - *moment > self.time) - { - self.moments.pop_front(); - } - if self.moments.len() < self.limit { - self.moments.push_back(now); - true - } else { - false - } - } -} - -pub struct RingSet { - set: HashSet, - ring: VecDeque, - limit: usize, -} - -impl RingSet { - pub fn new(limit: usize) -> Self { - assert!(limit > 0); - Self { - set: HashSet::new(), - ring: VecDeque::with_capacity(limit), - limit, - } - } - - pub fn add(&mut self, key: Key) { - if !self.contains(&key) { - if self.ring.len() >= self.limit { - let element = self.ring.pop_front().expect("Deque has elements"); - self.set.remove(&element); - } - self.set.insert(key.clone()); - self.ring.push_back(key); - } - } - - pub fn contains(&self, key: &Key) -> bool { - self.set.contains(key) - } -} - -#[cfg(test)] -mod tests { - use std::{thread, time::Duration}; - - use super::{RateLimiter, RingSet}; - - #[test] - fn rate_limit() { - let duration = Duration::from_micros(100); - let mut limiter = RateLimiter::new(4, duration); - - for _ in 0..4 { - assert!(limiter.get_token()) - } - assert!(!limiter.get_token()); - thread::sleep(duration * 2); - assert!(limiter.get_token()); - } - - #[test] - fn ring_set() { - let mut set = RingSet::new(3); - set.add(1); - assert!(set.contains(&1)); - set.add(2); - assert!(set.contains(&1)); - assert!(set.contains(&2)); - assert!(!set.contains(&3)); - set.add(3); - set.add(3); - set.add(4); - assert!(!set.contains(&1)); - assert!(set.contains(&4)); - } -}