WIP tangled on QUIC

This commit is contained in:
IQuant 2024-09-05 23:00:12 +03:00
parent 9630a002b3
commit cf76733005
8 changed files with 543 additions and 892 deletions

106
noita-proxy/Cargo.lock generated
View file

@ -93,7 +93,7 @@ dependencies = [
"bitflags 2.6.0", "bitflags 2.6.0",
"cc", "cc",
"cesu8", "cesu8",
"jni", "jni 0.21.1",
"jni-sys", "jni-sys",
"libc", "libc",
"log", "log",
@ -1703,6 +1703,20 @@ version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
[[package]]
name = "jni"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec"
dependencies = [
"cesu8",
"combine",
"jni-sys",
"log",
"thiserror",
"walkdir",
]
[[package]] [[package]]
name = "jni" name = "jni"
version = "0.21.1" version = "0.21.1"
@ -2038,12 +2052,31 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "num-bigint"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
dependencies = [
"num-integer",
"num-traits",
]
[[package]] [[package]]
name = "num-conv" name = "num-conv"
version = "0.1.0" version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "num-traits" name = "num-traits"
version = "0.2.19" version = "0.2.19"
@ -2302,6 +2335,16 @@ dependencies = [
"hmac", "hmac",
] ]
[[package]]
name = "pem"
version = "3.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e459365e590736a54c3fa561947c84837534b8e9af6fc5bf781307e82658fae"
dependencies = [
"base64 0.22.1",
"serde",
]
[[package]] [[package]]
name = "percent-encoding" name = "percent-encoding"
version = "2.3.1" version = "2.3.1"
@ -2474,6 +2517,7 @@ dependencies = [
"ring", "ring",
"rustc-hash 2.0.0", "rustc-hash 2.0.0",
"rustls", "rustls",
"rustls-platform-verifier",
"slab", "slab",
"thiserror", "thiserror",
"tinyvec", "tinyvec",
@ -2544,6 +2588,19 @@ version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539"
[[package]]
name = "rcgen"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779"
dependencies = [
"pem",
"ring",
"rustls-pki-types",
"time",
"yasna",
]
[[package]] [[package]]
name = "rctree" name = "rctree"
version = "0.5.0" version = "0.5.0"
@ -2821,6 +2878,33 @@ version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0"
[[package]]
name = "rustls-platform-verifier"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "afbb878bdfdf63a336a5e63561b1835e7a8c91524f51621db870169eac84b490"
dependencies = [
"core-foundation",
"core-foundation-sys",
"jni 0.19.0",
"log",
"once_cell",
"rustls",
"rustls-native-certs 0.7.3",
"rustls-platform-verifier-android",
"rustls-webpki",
"security-framework",
"security-framework-sys",
"webpki-roots",
"winapi",
]
[[package]]
name = "rustls-platform-verifier-android"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"
[[package]] [[package]]
name = "rustls-webpki" name = "rustls-webpki"
version = "0.102.7" version = "0.102.7"
@ -2878,6 +2962,7 @@ dependencies = [
"core-foundation", "core-foundation",
"core-foundation-sys", "core-foundation-sys",
"libc", "libc",
"num-bigint",
"security-framework-sys", "security-framework-sys",
] ]
@ -3215,13 +3300,18 @@ dependencies = [
[[package]] [[package]]
name = "tangled" name = "tangled"
version = "0.2.0" version = "0.3.0"
dependencies = [ dependencies = [
"bincode", "bincode",
"crossbeam", "crossbeam",
"dashmap", "dashmap",
"num-bigint",
"quinn",
"rcgen",
"serde", "serde",
"test-log", "test-log",
"thiserror",
"tokio",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
] ]
@ -3476,6 +3566,7 @@ version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [ dependencies = [
"log",
"pin-project-lite", "pin-project-lite",
"tracing-attributes", "tracing-attributes",
"tracing-core", "tracing-core",
@ -4006,7 +4097,7 @@ dependencies = [
"block2 0.5.1", "block2 0.5.1",
"core-foundation", "core-foundation",
"home", "home",
"jni", "jni 0.21.1",
"log", "log",
"ndk-context", "ndk-context",
"objc2 0.5.2", "objc2 0.5.2",
@ -4488,6 +4579,15 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec7a2a501ed189703dba8b08142f057e887dfc4b2cc4db2d343ac6376ba3e0b9" checksum = "ec7a2a501ed189703dba8b08142f057e887dfc4b2cc4db2d343ac6376ba3e0b9"
[[package]]
name = "yasna"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd"
dependencies = [
"time",
]
[[package]] [[package]]
name = "zerocopy" name = "zerocopy"
version = "0.7.35" version = "0.7.35"

View file

@ -1,6 +1,6 @@
[package] [package]
name = "tangled" name = "tangled"
version = "0.2.0" version = "0.3.0"
edition = "2021" edition = "2021"
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
repository = "https://github.com/IntQuant/tangled" repository = "https://github.com/IntQuant/tangled"
@ -18,6 +18,11 @@ tracing = "0.1.36"
dashmap = "6.0.1" dashmap = "6.0.1"
serde = {features = ["derive"], version = "1.0.142"} serde = {features = ["derive"], version = "1.0.142"}
bincode = "1.3.3" bincode = "1.3.3"
quinn = "0.11.5"
num-bigint = "0.4.6"
rcgen = "0.13.1"
thiserror = "1.0.63"
tokio = "1.40.0"
[dev-dependencies] [dev-dependencies]
test-log = { version = "0.2.11", default-features = false, features = ["trace"]} test-log = { version = "0.2.11", default-features = false, features = ["trace"]}

View file

@ -0,0 +1,103 @@
//! Various common public types.
use std::{fmt::Display, time::Duration};
use serde::{Deserialize, Serialize};
/// Per-peer settings. Peers that are connected to the same host, as well as the host itself, should have the same settings.
#[derive(Debug, Clone)]
pub struct Settings {
/// A single datagram will confirm at most this much messages. Default is 128.
pub confirm_max_per_message: usize,
/// How much time can elapse before another confirm is sent.
/// Confirms are also sent when enough messages are awaiting confirm.
/// Note that confirms also double as "heartbeats" and keep the connection alive, so this value should be much less than `connection_timeout`.
/// Default: 1 second.
pub confirm_max_period: Duration,
/// Peers will be disconnected after this much time without any datagrams from them has passed.
/// Default: 10 seconds.
pub connection_timeout: Duration,
}
/// Tells how reliable a message is.
#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Debug)]
pub enum Reliability {
/// Message will be delivered at most once.
Unreliable,
/// Message will be resent untill is's arrival will be confirmed.
/// Will be delivered at most once.
Reliable,
}
pub enum Destination {
One(PeerId),
Broadcast,
}
/// A value which refers to a specific peer.
/// Peer 0 is always the host.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct PeerId(pub u16);
/// Possible network events, returned by `Peer.recv()`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum NetworkEvent {
/// A new peer has connected.
PeerConnected(PeerId),
/// Peer has disconnected.
PeerDisconnected(PeerId),
/// Message has been received.
Message(Message),
}
/// A message received from a peer.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Message {
/// Original peer who sent the message.
pub src: PeerId,
/// The data that has been sent.
pub data: Vec<u8>,
}
/// Current peer state
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)]
pub enum PeerState {
/// Waiting for connection. Switches to `Connected` right after id from the host has been acquired.
/// Note: hosts switches to 'Connected' basically instantly.
#[default]
PendingConnection,
/// Connected to host and ready to send/receive messages.
Connected,
/// No longer connected, won't reconnect.
Disconnected,
}
impl Display for PeerState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PeerState::PendingConnection => write!(f, "Connection pending..."),
PeerState::Connected => write!(f, "Connected"),
PeerState::Disconnected => write!(f, "Disconnected"),
}
}
}
impl PeerId {
pub const HOST: PeerId = PeerId(0);
}
impl Display for PeerId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl Default for Settings {
fn default() -> Self {
Self {
confirm_max_per_message: 128,
confirm_max_period: Duration::from_secs(1),
connection_timeout: Duration::from_secs(10),
}
}
}

View file

@ -0,0 +1,250 @@
use std::{
io,
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use crossbeam::{
atomic::AtomicCell,
channel::{unbounded, Receiver, Sender},
};
use dashmap::DashMap;
use quinn::{
crypto::rustls::QuicClientConfig,
rustls::{
self,
pki_types::{CertificateDer, PrivatePkcs8KeyDer},
},
ClientConfig, ConnectError, Connecting, ConnectionError, Endpoint, Incoming, ServerConfig,
};
use thiserror::Error;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{error, info, warn};
use crate::{
common::{Destination, NetworkEvent, PeerId, PeerState, Reliability, Settings},
helpers::SkipServerVerification,
};
#[derive(Default)]
pub(crate) struct RemotePeer;
#[derive(Debug, Error)]
enum DirectConnectionError {
#[error("QUIC Connection error: {0}")]
QUICConnectionError(#[from] ConnectionError),
#[error("Initial exchange failed")]
InitialExchangeFailed,
}
struct DirectPeer {
my_id: PeerId,
remote_id: PeerId,
streams: (quinn::SendStream, quinn::RecvStream),
}
impl DirectPeer {
async fn accept(
incoming: Incoming,
assigned_peer_id: PeerId,
) -> Result<Self, DirectConnectionError> {
let connection = incoming
.await
.inspect_err(|err| warn!("Failed to accept connection: {err}"))?;
let mut sender = connection
.open_uni()
.await
.inspect_err(|err| warn!("Failed to get send stream: {err}"))?;
sender
.write_u16(assigned_peer_id.0)
.await
.map_err(|_err| DirectConnectionError::InitialExchangeFailed)?;
let streams = connection.open_bi().await?;
Ok(Self {
my_id: PeerId::HOST,
remote_id: assigned_peer_id,
streams,
})
}
async fn connect(connection: Connecting) -> Result<Self, DirectConnectionError> {
let connection = connection
.await
.inspect_err(|err| warn!("Failed to initiate connection: {err}"))?;
let mut receiver = connection.accept_uni().await?;
let peer_id = receiver
.read_u16()
.await
.map_err(|_err| DirectConnectionError::InitialExchangeFailed)?;
let streams = connection.accept_bi().await?;
Ok(Self {
my_id: PeerId(peer_id),
remote_id: PeerId::HOST,
streams,
})
}
}
type SeqId = u16;
pub(crate) struct OutboundMessage {
pub dst: Destination,
pub reliability: Reliability,
pub data: Vec<u8>,
}
pub(crate) type Channel<T> = (Sender<T>, Receiver<T>);
#[derive(Debug, Error)]
pub enum TangledInitError {
#[error("Could not create endpoint.\nReason: {0}")]
CouldNotCreateEndpoint(io::Error),
#[error("Could not connect to host.\nReason: {0}")]
CouldNotConnectToHost(ConnectError),
}
pub(crate) struct Shared {
pub settings: Settings,
pub inbound_channel: Channel<NetworkEvent>,
pub outbound_channel: Channel<OutboundMessage>,
pub keep_alive: AtomicBool,
pub peer_state: AtomicCell<PeerState>,
pub remote_peers: DashMap<PeerId, RemotePeer>,
pub host_addr: Option<SocketAddr>,
pub my_id: AtomicCell<Option<PeerId>>,
// ConnectionManager-specific stuff
direct_peers: DashMap<PeerId, DirectPeer>,
}
impl Shared {
pub(crate) fn new(host_addr: Option<SocketAddr>, settings: Option<Settings>) -> Self {
Self {
inbound_channel: unbounded(),
outbound_channel: unbounded(),
keep_alive: AtomicBool::new(true),
host_addr,
peer_state: Default::default(),
remote_peers: Default::default(),
my_id: AtomicCell::new(if host_addr.is_none() {
Some(PeerId(0))
} else {
None
}),
settings: settings.unwrap_or_default(),
direct_peers: DashMap::default(),
}
}
}
pub(crate) struct ConnectionManager {
shared: Arc<Shared>,
endpoint: Endpoint,
host_conn: Option<DirectPeer>,
is_server: bool,
}
impl ConnectionManager {
pub(crate) fn new(shared: Arc<Shared>, addr: SocketAddr) -> Result<Self, TangledInitError> {
let is_server = shared.host_addr.is_none();
let config = default_server_config();
let mut endpoint = if is_server {
Endpoint::server(config, addr).map_err(TangledInitError::CouldNotCreateEndpoint)?
} else {
Endpoint::client(addr).map_err(TangledInitError::CouldNotCreateEndpoint)?
};
endpoint.set_default_client_config(ClientConfig::new(Arc::new(
QuicClientConfig::try_from(
rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(SkipServerVerification::new())
.with_no_client_auth(),
)
.unwrap(),
)));
Ok(Self {
shared,
is_server,
endpoint,
host_conn: None,
})
}
async fn accept_connections(shared: Arc<Shared>, endpoint: Endpoint) {
let mut peer_id_counter = 1;
while shared.keep_alive.load(Ordering::Relaxed) {
let Some(incoming) = endpoint.accept().await else {
info!("Endpoint closed, stopping connection accepter task.");
return;
};
match DirectPeer::accept(incoming, PeerId(peer_id_counter)).await {
Ok(direct_peer) => {
shared
.direct_peers
.insert(PeerId(peer_id_counter), direct_peer);
peer_id_counter += 1;
}
Err(err) => {
warn!("Failed to accept connection: {err}")
}
};
}
}
async fn astart(mut self, host_conn: Option<Connecting>) {
if let Some(host_conn) = host_conn {
match DirectPeer::connect(host_conn).await {
Ok(host_conn) => {
self.host_conn = Some(host_conn);
}
Err(err) => {
error!("Could not connect to host: {}", err);
self.shared.peer_state.store(PeerState::Disconnected);
return;
}
}
}
if self.is_server {
let endpoint = self.endpoint.clone();
tokio::spawn(Self::accept_connections(self.shared.clone(), endpoint));
info!("Started connection acceptor task");
}
}
pub(crate) fn start(self) -> Result<(), TangledInitError> {
let host_conn = self
.shared
.host_addr
.as_ref()
.map(|host_addr| {
self.endpoint
.connect(*host_addr, "tangled")
.map_err(TangledInitError::CouldNotConnectToHost)
})
.transpose()?;
tokio::spawn(self.astart(host_conn));
Ok(())
}
}
fn default_server_config() -> ServerConfig {
let cert = rcgen::generate_simple_self_signed(vec!["tangled".into()]).unwrap();
let cert_der = CertificateDer::from(cert.cert);
let priv_key = PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der());
let config = ServerConfig::with_single_cert(vec![cert_der.clone()], priv_key.into()).unwrap();
config
}

View file

@ -0,0 +1,60 @@
use std::sync::Arc;
use quinn::rustls::{
self,
pki_types::{CertificateDer, ServerName, UnixTime},
};
#[derive(Debug)]
pub(crate) struct SkipServerVerification(Arc<rustls::crypto::CryptoProvider>);
impl SkipServerVerification {
pub(crate) fn new() -> Arc<Self> {
Arc::new(Self(Arc::new(rustls::crypto::ring::default_provider())))
}
}
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp: &[u8],
_now: UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls12_signature(
message,
cert,
dss,
&self.0.signature_verification_algorithms,
)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls13_signature(
message,
cert,
dss,
&self.0.signature_verification_algorithms,
)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.0.signature_verification_algorithms.supported_schemes()
}
}

View file

@ -5,100 +5,32 @@ use std::{
io, io,
net::{SocketAddr, UdpSocket}, net::{SocketAddr, UdpSocket},
sync::{atomic::AtomicBool, Arc}, sync::{atomic::AtomicBool, Arc},
time::Duration,
}; };
use connection_manager::{
ConnectionManager, OutboundMessage, RemotePeer, Shared, TangledInitError,
};
use crossbeam::{ use crossbeam::{
self, self,
atomic::AtomicCell, atomic::AtomicCell,
channel::{unbounded, Receiver, Sender}, channel::{unbounded, Receiver, Sender},
}; };
use dashmap::DashMap;
pub use error::NetError; pub use error::NetError;
use reactor::{Destination, RemotePeer, Shared};
pub use reactor::{Reliability, Settings};
use serde::{Deserialize, Serialize};
const DATAGRAM_MAX_LEN: usize = 30000; // TODO this probably should be 1500 const DATAGRAM_MAX_LEN: usize = 30000; // TODO this probably should be 1500
/// Maximum size of a message which fits into a single datagram. /// Maximum size of a message which fits into a single datagram.
pub const MAX_MESSAGE_LEN: usize = DATAGRAM_MAX_LEN - 100; pub const MAX_MESSAGE_LEN: usize = DATAGRAM_MAX_LEN - 100;
mod common;
mod connection_manager;
mod error; mod error;
mod reactor; mod helpers;
mod util;
struct Datagram { pub use common::*;
pub size: usize,
pub data: [u8; DATAGRAM_MAX_LEN],
}
/// A value which refers to a specific peer.
/// Peer 0 is always the host.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct PeerId(pub u16);
impl PeerId {
pub const HOST: PeerId = PeerId(0);
}
impl Display for PeerId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
type SeqId = u16;
/// Possible network events, returned by `Peer.recv()`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum NetworkEvent {
/// A new peer has connected.
PeerConnected(PeerId),
/// Peer has disconnected.
PeerDisconnected(PeerId),
/// Message has been received.
Message(Message),
}
/// A message received from a peer.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Message {
/// Original peer who sent the message.
pub src: PeerId,
/// The data that has been sent.
pub data: Vec<u8>,
}
struct OutboundMessage {
pub dst: Destination,
pub data: Vec<u8>,
pub reliability: Reliability,
}
/// Current peer state
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)]
pub enum PeerState {
/// Waiting for connection. Switches to `Connected` right after id from the host has been acquired.
/// Note: hosts switches to 'Connected' basically instantly.
#[default]
PendingConnection,
/// Connected to host and ready to send/receive messages.
Connected,
/// No longer connected, won't reconnect.
Disconnected,
}
impl Display for PeerState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PeerState::PendingConnection => write!(f, "Connection pending..."),
PeerState::Connected => write!(f, "Connected"),
PeerState::Disconnected => write!(f, "Disconnected"),
}
}
}
type Channel<T> = (Sender<T>, Receiver<T>);
/// Represents a network endpoint. Can be constructed in either `host` or `client` mode. /// Represents a network endpoint. Can be constructed in either `host` or `client` mode.
/// ///
@ -113,24 +45,8 @@ impl Peer {
bind_addr: SocketAddr, bind_addr: SocketAddr,
host_addr: Option<SocketAddr>, host_addr: Option<SocketAddr>,
settings: Option<Settings>, settings: Option<Settings>,
) -> io::Result<Self> { ) -> Result<Self, TangledInitError> {
let socket = UdpSocket::bind(bind_addr)?; let shared = Arc::new(Shared::new(host_addr, settings));
let shared = Arc::new(Shared {
socket,
inbound_channel: unbounded(),
outbound_channel: unbounded(),
keep_alive: AtomicBool::new(true),
host_addr,
peer_state: Default::default(),
remote_peers: Default::default(),
max_packets_per_second: 256,
my_id: AtomicCell::new(if host_addr.is_none() {
Some(PeerId(0))
} else {
None
}),
settings: settings.unwrap_or_default(),
});
if host_addr.is_none() { if host_addr.is_none() {
shared.remote_peers.insert(PeerId(0), RemotePeer::default()); shared.remote_peers.insert(PeerId(0), RemotePeer::default());
shared shared
@ -139,17 +55,23 @@ impl Peer {
.send(NetworkEvent::PeerConnected(PeerId(0))) .send(NetworkEvent::PeerConnected(PeerId(0)))
.unwrap(); .unwrap();
} }
reactor::Reactor::start(Arc::clone(&shared)); ConnectionManager::new(Arc::clone(&shared), bind_addr)?.start()?;
Ok(Peer { shared }) Ok(Peer { shared })
} }
/// Host at a specified `bind_addr`. /// Host at a specified `bind_addr`.
pub fn host(bind_addr: SocketAddr, settings: Option<Settings>) -> io::Result<Self> { pub fn host(
bind_addr: SocketAddr,
settings: Option<Settings>,
) -> Result<Self, TangledInitError> {
Self::new(bind_addr, None, settings) Self::new(bind_addr, None, settings)
} }
/// Connect to a specified `host_addr`. /// Connect to a specified `host_addr`.
pub fn connect(host_addr: SocketAddr, settings: Option<Settings>) -> io::Result<Self> { pub fn connect(
host_addr: SocketAddr,
settings: Option<Settings>,
) -> Result<Self, TangledInitError> {
Self::new("0.0.0.0:0".parse().unwrap(), Some(host_addr), settings) Self::new("0.0.0.0:0".parse().unwrap(), Some(host_addr), settings)
} }
@ -176,11 +98,6 @@ impl Peer {
if data.len() > MAX_MESSAGE_LEN { if data.len() > MAX_MESSAGE_LEN {
return Err(NetError::MessageTooLong); return Err(NetError::MessageTooLong);
} }
if reliability == Reliability::Unreliable
&& self.shared.outbound_channel.0.len() * 2 > self.shared.max_packets_per_second
{
return Err(NetError::Dropped);
}
self.shared.outbound_channel.0.send(OutboundMessage { self.shared.outbound_channel.0.send(OutboundMessage {
dst: destination, dst: destination,
data, data,
@ -233,7 +150,7 @@ impl Drop for Peer {
mod test { mod test {
use std::{thread, time::Duration}; use std::{thread, time::Duration};
use crate::{reactor::Settings, Message, NetworkEvent, Peer, PeerId, Reliability}; use crate::{common::Message, NetworkEvent, Peer, PeerId, Reliability, Settings};
#[test_log::test] #[test_log::test]
fn test_peer() { fn test_peer() {

View file

@ -1,679 +0,0 @@
use crate::{
error::NetError,
util::{RateLimiter, RingSet},
Channel, Message, NetworkEvent, OutboundMessage, PeerId, SeqId,
};
use super::{Datagram, PeerState, DATAGRAM_MAX_LEN};
use crossbeam::{
atomic::AtomicCell,
channel::{bounded, Receiver, Sender},
select,
};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, VecDeque},
error::Error,
io::Cursor,
net::{SocketAddr, UdpSocket},
sync::{
atomic::{AtomicBool, AtomicU16, Ordering::SeqCst},
Arc,
},
thread,
time::{Duration, Instant},
};
use tracing::{error, info, trace, warn};
/// Per-peer settings. Peers that are connected to the same host, as well as the host itself, should have the same settings.
#[derive(Debug, Clone)]
pub struct Settings {
/// A single datagram will confirm at most this much messages. Default is 128.
pub confirm_max_per_message: usize,
/// How much time can elapse before another confirm is sent.
/// Confirms are also sent when enough messages are awaiting confirm.
/// Note that confirms also double as "heartbeats" and keep the connection alive, so this value should be much less than `connection_timeout`.
/// Default: 1 second.
pub confirm_max_period: Duration,
/// Peers will be disconnected after this much time without any datagrams from them has passed.
/// Default: 10 seconds.
pub connection_timeout: Duration,
}
impl Default for Settings {
fn default() -> Self {
Self {
confirm_max_per_message: 128,
confirm_max_period: Duration::from_secs(1),
connection_timeout: Duration::from_secs(10),
}
}
}
pub(crate) struct Shared {
pub settings: Settings,
pub socket: UdpSocket,
pub inbound_channel: Channel<NetworkEvent>,
pub outbound_channel: Channel<OutboundMessage>,
pub keep_alive: AtomicBool,
pub peer_state: AtomicCell<PeerState>,
pub remote_peers: DashMap<PeerId, RemotePeer>,
pub max_packets_per_second: usize,
pub host_addr: Option<SocketAddr>,
pub my_id: AtomicCell<Option<PeerId>>,
}
struct DirectPeer {
addr: SocketAddr,
outbound_pending: VecDeque<NetMessageVariant>,
resend_pending: VecDeque<(Instant, NetMessageNormal)>,
confirmed: RingSet<SeqId>,
rate_limit: RateLimiter,
seq_counter: AtomicU16,
recent_seq: RingSet<SeqId>,
pending_confirms: VecDeque<SeqId>,
last_confirm_sent: Instant,
last_seen: Instant,
}
#[derive(Default)]
pub struct RemotePeer {}
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
pub enum Destination {
One(PeerId),
Broadcast,
}
#[derive(Serialize, Deserialize, Clone)]
enum NetMessageVariant {
Login,
Normal(NetMessageNormal),
}
/// Tells how reliable a message is.
#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Debug)]
pub enum Reliability {
/// Message will be delivered at most once.
Unreliable,
/// Message will be resent untill is's arrival will be confirmed.
/// Will be delivered at most once.
Reliable,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
struct NetMessageNormal {
// Source that generated sequence id.
// Initially the same as origin_src, but can be changed when packet is retransmitted not as-is, e. g. when it is broadcasted.
src: PeerId,
// Original source.
origin_src: PeerId,
dst: Destination,
seq_id: SeqId,
reliability: Reliability,
inner: NetMessageInner,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
enum NetMessageInner {
RegDone { addr: SocketAddr },
AddPeer { id: PeerId },
DelPeer { id: PeerId },
Confirm { confirmed_ids: Vec<SeqId> },
Payload { data: Vec<u8> },
}
impl TryFrom<Datagram> for NetMessageVariant {
type Error = bincode::Error;
fn try_from(datagram: Datagram) -> Result<Self, Self::Error> {
bincode::deserialize(&datagram.data[..datagram.size])
}
}
impl TryFrom<&NetMessageVariant> for Datagram {
type Error = bincode::Error;
fn try_from(value: &NetMessageVariant) -> Result<Self, Self::Error> {
let mut data = Cursor::new([0; DATAGRAM_MAX_LEN]);
bincode::serialize_into(&mut data, value)?;
let size = data.position().try_into().unwrap();
let data = data.into_inner();
Ok(Datagram { data, size })
}
}
pub(crate) struct Reactor {
shared: Arc<Shared>,
direct_peers: HashMap<PeerId, DirectPeer>,
}
type AddrDatagram = (SocketAddr, Datagram);
impl Reactor {
fn add_peer(&self, id: PeerId) -> Result<(), NetError> {
self.shared.remote_peers.insert(id, RemotePeer::default());
self.shared
.inbound_channel
.0
.send(NetworkEvent::PeerConnected(id))?;
Ok(())
}
fn direct_broadcast(
&mut self,
src_id: PeerId,
msg: NetMessageInner,
reliability: Reliability,
) -> Result<(), NetError> {
for (&peer_id, peer) in self.direct_peers.iter_mut() {
let new_seq_id = peer.seq_counter.fetch_add(1, SeqCst);
let new_msg = Self::wrap_packet_seq_id(
src_id,
src_id,
new_seq_id,
Destination::One(peer_id),
msg.clone(),
reliability,
)?;
Self::direct_send_peer(peer, new_msg)?;
}
Ok(())
}
fn direct_send(&mut self, id: PeerId, msg: NetMessageVariant) -> Result<(), NetError> {
let peer = self
.direct_peers
.get_mut(&id)
.ok_or(NetError::UnknownPeer)?;
Self::direct_send_peer(peer, msg)
}
fn direct_send_peer(peer: &mut DirectPeer, msg: NetMessageVariant) -> Result<(), NetError> {
peer.outbound_pending.push_back(msg);
Ok(())
}
fn gen_peer_id(&mut self) -> Option<PeerId> {
(1..=u16::MAX)
.map(PeerId)
.find(|i| !self.shared.remote_peers.contains_key(i))
}
fn handle_inbound(&mut self, (incoming_addr, msg_raw): AddrDatagram) {
let msg = match NetMessageVariant::try_from(msg_raw) {
Ok(msg) => msg,
Err(err) => {
warn!("Error when converting to NetMessage: {}", err);
return;
}
};
match self.shared.my_id.load() {
Some(id) => {
match msg {
NetMessageVariant::Login => {
if self.is_host() {
//TODO check this addr is not already registered
match self.gen_peer_id() {
Some(new_id) => {
self.add_peer(new_id).ok();
let mut peer = DirectPeer::new(
incoming_addr,
self.shared.max_packets_per_second,
);
peer.outbound_pending.push_back(NetMessageVariant::Normal(
NetMessageNormal {
src: id,
origin_src: id,
dst: Destination::One(new_id),
seq_id: u16::MAX,
inner: NetMessageInner::RegDone {
addr: incoming_addr,
},
reliability: Reliability::Reliable,
},
));
self.direct_peers.insert(new_id, peer);
self.direct_broadcast(
id,
NetMessageInner::AddPeer { id: new_id },
Reliability::Reliable,
)
.ok();
let shared = self.shared.clone();
for re in shared.remote_peers.iter() {
let id = *re.key();
if id != new_id {
self.wrap_packet(
id,
Destination::One(new_id),
NetMessageInner::AddPeer { id },
Reliability::Reliable,
)
.and_then(|msg| self.direct_send(new_id, msg))
.ok();
}
}
}
None => warn!("Out of ids"),
}
} else {
warn!("Not a host, registration attempt ignored");
}
}
NetMessageVariant::Normal(msg) => {
match self.handle_inbound_normal(msg, incoming_addr, id) {
Ok(_) => {}
Err(NetError::Dropped) => {}
Err(err) => {
info!("Error while handling normal inbound message: {}", err)
}
}
}
}
}
None => match msg {
NetMessageVariant::Normal(NetMessageNormal {
inner: NetMessageInner::RegDone { addr: _ },
dst,
src,
..
}) => {
let expected_host_addr = self
.shared
.host_addr
.expect("Can't have both my_id and host_addr be None");
if incoming_addr == expected_host_addr && src == PeerId(0) {
if let Destination::One(id) = dst {
self.shared.my_id.store(Some(id));
self.add_peer(PeerId(0)).ok();
self.shared.peer_state.store(PeerState::Connected);
} else {
warn!("Malformed registration message");
}
} else {
warn!("Registration message recieved not from the right address ({}, {} expected)", incoming_addr, expected_host_addr);
}
}
_ => warn!("Message ignored as registration is not done yet"),
},
}
}
fn handle_inbound_normal(
&mut self,
msg: NetMessageNormal,
_incoming_addr: SocketAddr,
my_id: PeerId,
) -> Result<(), NetError> {
let peer = self.direct_peers.get_mut(&msg.src);
if peer
.as_ref()
.map_or(true, |peer| peer.recent_seq.contains(&msg.seq_id))
{
return Err(NetError::Dropped);
}
{
let peer = peer.expect("Expected to exist");
peer.recent_seq.add(msg.seq_id); //TODO backpressure
peer.pending_confirms.push_back(msg.seq_id);
peer.last_seen = Instant::now()
}
if Destination::One(my_id) == msg.dst || msg.dst == Destination::Broadcast {
// TODO eliminate this clone
match msg.inner.clone() {
NetMessageInner::RegDone { addr: _ } => {
warn!("Already registered, request ignored");
}
NetMessageInner::AddPeer { id } => {
if !self.is_host() {
self.add_peer(id).ok();
info!("Peer {} added", id);
}
}
NetMessageInner::DelPeer { id } => {
if !self.is_host() {
self.del_peer(id).ok();
info!("Peer {} removed", id);
}
}
NetMessageInner::Confirm { confirmed_ids } => {
if let Some(peer) = self.direct_peers.get_mut(&msg.src) {
for id in confirmed_ids {
peer.confirmed.add(id);
}
}
}
NetMessageInner::Payload { data } => {
self.shared
.inbound_channel
.0
.send(NetworkEvent::Message(Message {
src: msg.origin_src,
data,
}))?;
}
}
}
if self.is_host() && Destination::One(my_id) != msg.dst {
match msg.dst {
Destination::One(dst) => {
let new_msg =
self.wrap_packet(dst, Destination::One(dst), msg.inner, msg.reliability)?;
self.direct_send(dst, new_msg)?;
}
Destination::Broadcast => {
let mut buf = Vec::new();
for peer in &self.direct_peers {
if *peer.0 == msg.src {
continue;
}
let seq_id = self.next_seq_id_for_peer(*peer.0)?;
if let Ok(wrapped_msg) = Self::wrap_packet_seq_id(
PeerId(0),
msg.origin_src,
seq_id,
Destination::One(*peer.0),
msg.inner.clone(),
msg.reliability,
) {
buf.push((*peer.0, wrapped_msg));
}
}
for (peer_id, wrapped_msg) in buf {
self.direct_send(peer_id, wrapped_msg).ok();
}
}
}
}
Ok(())
}
fn del_peer(&mut self, id: PeerId) -> Result<(), NetError> {
self.shared.remote_peers.remove(&id);
self.shared
.inbound_channel
.0
.send(NetworkEvent::PeerDisconnected(id))?;
Ok(())
}
fn handle_outbound(&mut self, msg: OutboundMessage) -> Result<(), NetError> {
let dst = msg.dst;
if self.is_host() {
match dst {
Destination::One(id) => {
let net_msg = self.wrap_packet(
id,
dst,
NetMessageInner::Payload { data: msg.data },
msg.reliability,
)?;
self.direct_send(id, net_msg)?;
}
Destination::Broadcast => self.direct_broadcast(
PeerId(0),
NetMessageInner::Payload { data: msg.data },
msg.reliability,
)?,
}
} else {
let net_msg = self.wrap_packet(
PeerId(0),
dst,
NetMessageInner::Payload { data: msg.data },
msg.reliability,
)?;
self.direct_send(PeerId(0), net_msg)?;
}
Ok(())
}
pub fn is_host(&self) -> bool {
self.shared.host_addr.is_none()
}
pub fn next_seq_id_for_peer(&self, peer_id: PeerId) -> Result<SeqId, NetError> {
Ok(self
.direct_peers
.get(&peer_id)
.or_else(|| {
if !self.is_host() {
self.direct_peers.get(&PeerId(0))
} else {
None
}
})
.ok_or(NetError::UnknownPeer)?
.seq_counter
.fetch_add(1, SeqCst))
}
fn run(mut self, inbound_r: Receiver<AddrDatagram>) -> Result<(), Box<dyn Error>> {
while self.shared.keep_alive.load(SeqCst) {
select! {
recv(inbound_r) -> addr_msg => self.handle_inbound(addr_msg?),
recv(self.shared.outbound_channel.1) -> msg => {self.handle_outbound(msg?).ok();}
default => {thread::sleep(Duration::from_micros(100));}
}
let mut dc = Vec::new();
self.direct_peers.retain(|&k, v| {
let stays = v.last_seen.elapsed() < self.shared.settings.connection_timeout;
if !stays {
dc.push(k);
}
stays
});
if self.is_host() {
for peer_id in dc {
let src_id = self.shared.my_id.load().unwrap(); // Should always be PeerId(0)
assert_eq!(src_id, PeerId(0));
self.direct_broadcast(
src_id,
NetMessageInner::DelPeer { id: peer_id },
Reliability::Reliable,
)?;
self.del_peer(peer_id).ok();
info!("[Host] Peer {} removed", peer_id);
}
}
if !self.is_host() && self.direct_peers.is_empty() {
self.shared.peer_state.store(PeerState::Disconnected);
self.shared.keep_alive.store(false, SeqCst);
}
'peers: for (&id, peer) in self.direct_peers.iter_mut() {
let resend_in = Instant::now() + Duration::from_secs(1);
if let Some(my_id) = self.shared.my_id.load() {
if peer.last_confirm_sent.elapsed() > self.shared.settings.confirm_max_period
|| peer.pending_confirms.len()
> self.shared.settings.confirm_max_per_message
{
peer.last_confirm_sent = Instant::now();
let max_per_message = self.shared.settings.confirm_max_per_message;
let mut confirmed_ids = Vec::with_capacity(max_per_message);
while let Some(confirm) = peer.pending_confirms.pop_front() {
confirmed_ids.push(confirm);
if confirmed_ids.len() == max_per_message {
break;
}
}
peer.resend_pending.push_front((
Instant::now(),
NetMessageNormal {
src: my_id,
origin_src: my_id,
dst: Destination::One(id),
seq_id: peer.seq_counter.fetch_add(1, SeqCst),
reliability: Reliability::Reliable,
inner: NetMessageInner::Confirm { confirmed_ids },
},
))
}
}
while peer
.resend_pending
.front()
.map_or(false, |x| x.0 < Instant::now())
{
let (moment, msg) = peer
.resend_pending
.pop_front()
.expect("Checked that deque is not empty");
if !peer.confirmed.contains(&msg.seq_id) {
if !peer.rate_limit.get_token() {
peer.resend_pending.push_front((moment, msg));
continue 'peers;
}
peer.resend_pending.push_back((resend_in, msg.clone()));
trace!("Sent {:?} to {}", msg, peer.addr,);
let datagram = Datagram::try_from(&NetMessageVariant::Normal(msg)).unwrap();
trace!("size: {}", datagram.size);
self.shared
.socket
.send_to(&datagram.data[..datagram.size], peer.addr)
.expect("Could not send");
}
}
while !peer.outbound_pending.is_empty() && peer.rate_limit.get_token() {
let msg = peer
.outbound_pending
.pop_front()
.expect("Checked that deque is not empty");
if let NetMessageVariant::Normal(ref msg) = msg {
if msg.reliability == Reliability::Reliable {
peer.resend_pending.push_back((resend_in, msg.clone()));
}
}
let datagram = Datagram::try_from(&msg).unwrap();
trace!("sent msg size: {}", datagram.size);
self.shared
.socket
.send_to(&datagram.data[..datagram.size], peer.addr)
.expect("Could not send");
}
}
}
Ok(())
}
fn run_pipe(
shared: Arc<Shared>,
sender: Sender<(SocketAddr, Datagram)>,
) -> Result<(), Box<dyn Error>> {
while shared.keep_alive.load(SeqCst) {
let mut buf = [0u8; DATAGRAM_MAX_LEN];
match shared.socket.recv_from(&mut buf) {
Ok((len, addr)) => sender
.send((
addr,
Datagram {
size: len,
data: buf,
},
))
.map_err(Box::new)?,
//Err(err)
// if err.kind() == ErrorKind::WouldBlock || err.kind() == ErrorKind::TimedOut => {
//}
Err(err) => return Err(Box::new(err)),
}
}
Ok(())
}
pub(crate) fn start(shared: Arc<Shared>) {
let mut me = Reactor {
shared,
direct_peers: Default::default(),
};
if !me.is_host() {
me.direct_peers.insert(
PeerId(0),
DirectPeer::new(
me.shared
.host_addr
.expect("Can't be a client without a host addr"),
me.shared.max_packets_per_second,
),
);
me.direct_send(PeerId(0), NetMessageVariant::Login).unwrap();
}
if me.is_host() {
me.shared.peer_state.store(PeerState::Connected);
}
let shared_c = Arc::clone(&me.shared);
let (inbound_s, inbound_r) = bounded(16);
thread::spawn(move || {
let shared_c_2 = Arc::clone(&shared_c);
if let Err(err) = Self::run_pipe(shared_c_2, inbound_s) {
shared_c.keep_alive.store(false, SeqCst);
shared_c.peer_state.store(PeerState::Disconnected);
error!("Reactor pipe error: {}", err);
}
});
let shared_c = Arc::clone(&me.shared);
thread::spawn(move || {
if let Err(err) = me.run(inbound_r) {
shared_c.keep_alive.store(false, SeqCst);
shared_c.peer_state.store(PeerState::Disconnected);
error!("Reactor error: {}", err);
}
});
}
fn wrap_packet(
&self,
id: PeerId,
dst: Destination,
msg: NetMessageInner,
reliability: Reliability,
) -> Result<NetMessageVariant, NetError> {
let seq_id = self.next_seq_id_for_peer(id)?;
let src = self.shared.my_id.load().expect("Should know own id by now");
Self::wrap_packet_seq_id(src, src, seq_id, dst, msg, reliability)
}
fn wrap_packet_seq_id(
src: PeerId,
origin_src: PeerId,
seq_id: SeqId,
dst: Destination,
msg: NetMessageInner,
reliability: Reliability,
) -> Result<NetMessageVariant, NetError> {
Ok(NetMessageVariant::Normal(NetMessageNormal {
src,
origin_src,
dst,
seq_id,
inner: msg,
reliability,
}))
}
}
impl DirectPeer {
fn new(incoming_addr: SocketAddr, rate_limit: usize) -> DirectPeer {
let now = Instant::now();
DirectPeer {
addr: incoming_addr,
outbound_pending: Default::default(),
resend_pending: Default::default(),
confirmed: RingSet::new(1024),
rate_limit: RateLimiter::new(rate_limit, Duration::from_secs(1)),
seq_counter: AtomicU16::new(0),
recent_seq: RingSet::new(1024),
pending_confirms: VecDeque::new(),
last_confirm_sent: now,
last_seen: now,
}
}
}

View file

@ -1,105 +0,0 @@
use std::{
collections::{HashSet, VecDeque},
hash::Hash,
time::{Duration, Instant},
};
pub struct RateLimiter {
moments: VecDeque<Instant>,
time: Duration,
limit: usize,
}
impl RateLimiter {
pub fn new(limit: usize, time: Duration) -> Self {
Self {
moments: VecDeque::with_capacity(limit),
time,
limit,
}
}
pub fn get_token(&mut self) -> bool {
let now = Instant::now();
while self
.moments
.front()
.map_or(false, |moment| now - *moment > self.time)
{
self.moments.pop_front();
}
if self.moments.len() < self.limit {
self.moments.push_back(now);
true
} else {
false
}
}
}
pub struct RingSet<Key: Hash + Eq + Clone> {
set: HashSet<Key>,
ring: VecDeque<Key>,
limit: usize,
}
impl<Key: Hash + Eq + Clone> RingSet<Key> {
pub fn new(limit: usize) -> Self {
assert!(limit > 0);
Self {
set: HashSet::new(),
ring: VecDeque::with_capacity(limit),
limit,
}
}
pub fn add(&mut self, key: Key) {
if !self.contains(&key) {
if self.ring.len() >= self.limit {
let element = self.ring.pop_front().expect("Deque has elements");
self.set.remove(&element);
}
self.set.insert(key.clone());
self.ring.push_back(key);
}
}
pub fn contains(&self, key: &Key) -> bool {
self.set.contains(key)
}
}
#[cfg(test)]
mod tests {
use std::{thread, time::Duration};
use super::{RateLimiter, RingSet};
#[test]
fn rate_limit() {
let duration = Duration::from_micros(100);
let mut limiter = RateLimiter::new(4, duration);
for _ in 0..4 {
assert!(limiter.get_token())
}
assert!(!limiter.get_token());
thread::sleep(duration * 2);
assert!(limiter.get_token());
}
#[test]
fn ring_set() {
let mut set = RingSet::new(3);
set.add(1);
assert!(set.contains(&1));
set.add(2);
assert!(set.contains(&1));
assert!(set.contains(&2));
assert!(!set.contains(&3));
set.add(3);
set.add(3);
set.add(4);
assert!(!set.contains(&1));
assert!(set.contains(&4));
}
}