mirror of
https://github.com/IntQuant/noita_entangled_worlds.git
synced 2025-10-19 07:03:16 +00:00
WIP tangled on QUIC
This commit is contained in:
parent
9630a002b3
commit
cf76733005
8 changed files with 543 additions and 892 deletions
106
noita-proxy/Cargo.lock
generated
106
noita-proxy/Cargo.lock
generated
|
@ -93,7 +93,7 @@ dependencies = [
|
|||
"bitflags 2.6.0",
|
||||
"cc",
|
||||
"cesu8",
|
||||
"jni",
|
||||
"jni 0.21.1",
|
||||
"jni-sys",
|
||||
"libc",
|
||||
"log",
|
||||
|
@ -1703,6 +1703,20 @@ version = "1.0.11"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
|
||||
|
||||
[[package]]
|
||||
name = "jni"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec"
|
||||
dependencies = [
|
||||
"cesu8",
|
||||
"combine",
|
||||
"jni-sys",
|
||||
"log",
|
||||
"thiserror",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jni"
|
||||
version = "0.21.1"
|
||||
|
@ -2038,12 +2052,31 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint"
|
||||
version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
|
||||
dependencies = [
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-conv"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.46"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
|
@ -2302,6 +2335,16 @@ dependencies = [
|
|||
"hmac",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pem"
|
||||
version = "3.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e459365e590736a54c3fa561947c84837534b8e9af6fc5bf781307e82658fae"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.3.1"
|
||||
|
@ -2474,6 +2517,7 @@ dependencies = [
|
|||
"ring",
|
||||
"rustc-hash 2.0.0",
|
||||
"rustls",
|
||||
"rustls-platform-verifier",
|
||||
"slab",
|
||||
"thiserror",
|
||||
"tinyvec",
|
||||
|
@ -2544,6 +2588,19 @@ version = "0.6.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539"
|
||||
|
||||
[[package]]
|
||||
name = "rcgen"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779"
|
||||
dependencies = [
|
||||
"pem",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"time",
|
||||
"yasna",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rctree"
|
||||
version = "0.5.0"
|
||||
|
@ -2821,6 +2878,33 @@ version = "1.8.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0"
|
||||
|
||||
[[package]]
|
||||
name = "rustls-platform-verifier"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "afbb878bdfdf63a336a5e63561b1835e7a8c91524f51621db870169eac84b490"
|
||||
dependencies = [
|
||||
"core-foundation",
|
||||
"core-foundation-sys",
|
||||
"jni 0.19.0",
|
||||
"log",
|
||||
"once_cell",
|
||||
"rustls",
|
||||
"rustls-native-certs 0.7.3",
|
||||
"rustls-platform-verifier-android",
|
||||
"rustls-webpki",
|
||||
"security-framework",
|
||||
"security-framework-sys",
|
||||
"webpki-roots",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-platform-verifier-android"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.102.7"
|
||||
|
@ -2878,6 +2962,7 @@ dependencies = [
|
|||
"core-foundation",
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
"num-bigint",
|
||||
"security-framework-sys",
|
||||
]
|
||||
|
||||
|
@ -3215,13 +3300,18 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "tangled"
|
||||
version = "0.2.0"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"crossbeam",
|
||||
"dashmap",
|
||||
"num-bigint",
|
||||
"quinn",
|
||||
"rcgen",
|
||||
"serde",
|
||||
"test-log",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
@ -3476,6 +3566,7 @@ version = "0.1.40"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
|
||||
dependencies = [
|
||||
"log",
|
||||
"pin-project-lite",
|
||||
"tracing-attributes",
|
||||
"tracing-core",
|
||||
|
@ -4006,7 +4097,7 @@ dependencies = [
|
|||
"block2 0.5.1",
|
||||
"core-foundation",
|
||||
"home",
|
||||
"jni",
|
||||
"jni 0.21.1",
|
||||
"log",
|
||||
"ndk-context",
|
||||
"objc2 0.5.2",
|
||||
|
@ -4488,6 +4579,15 @@ version = "0.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec7a2a501ed189703dba8b08142f057e887dfc4b2cc4db2d343ac6376ba3e0b9"
|
||||
|
||||
[[package]]
|
||||
name = "yasna"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd"
|
||||
dependencies = [
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy"
|
||||
version = "0.7.35"
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "tangled"
|
||||
version = "0.2.0"
|
||||
version = "0.3.0"
|
||||
edition = "2021"
|
||||
license = "MIT OR Apache-2.0"
|
||||
repository = "https://github.com/IntQuant/tangled"
|
||||
|
@ -18,6 +18,11 @@ tracing = "0.1.36"
|
|||
dashmap = "6.0.1"
|
||||
serde = {features = ["derive"], version = "1.0.142"}
|
||||
bincode = "1.3.3"
|
||||
quinn = "0.11.5"
|
||||
num-bigint = "0.4.6"
|
||||
rcgen = "0.13.1"
|
||||
thiserror = "1.0.63"
|
||||
tokio = "1.40.0"
|
||||
|
||||
[dev-dependencies]
|
||||
test-log = { version = "0.2.11", default-features = false, features = ["trace"]}
|
||||
|
|
103
noita-proxy/tangled/src/common.rs
Normal file
103
noita-proxy/tangled/src/common.rs
Normal file
|
@ -0,0 +1,103 @@
|
|||
//! Various common public types.
|
||||
|
||||
use std::{fmt::Display, time::Duration};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Per-peer settings. Peers that are connected to the same host, as well as the host itself, should have the same settings.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Settings {
|
||||
/// A single datagram will confirm at most this much messages. Default is 128.
|
||||
pub confirm_max_per_message: usize,
|
||||
/// How much time can elapse before another confirm is sent.
|
||||
/// Confirms are also sent when enough messages are awaiting confirm.
|
||||
/// Note that confirms also double as "heartbeats" and keep the connection alive, so this value should be much less than `connection_timeout`.
|
||||
/// Default: 1 second.
|
||||
pub confirm_max_period: Duration,
|
||||
/// Peers will be disconnected after this much time without any datagrams from them has passed.
|
||||
/// Default: 10 seconds.
|
||||
pub connection_timeout: Duration,
|
||||
}
|
||||
|
||||
/// Tells how reliable a message is.
|
||||
#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Debug)]
|
||||
pub enum Reliability {
|
||||
/// Message will be delivered at most once.
|
||||
Unreliable,
|
||||
/// Message will be resent untill is's arrival will be confirmed.
|
||||
/// Will be delivered at most once.
|
||||
Reliable,
|
||||
}
|
||||
|
||||
pub enum Destination {
|
||||
One(PeerId),
|
||||
Broadcast,
|
||||
}
|
||||
|
||||
/// A value which refers to a specific peer.
|
||||
/// Peer 0 is always the host.
|
||||
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct PeerId(pub u16);
|
||||
|
||||
/// Possible network events, returned by `Peer.recv()`.
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub enum NetworkEvent {
|
||||
/// A new peer has connected.
|
||||
PeerConnected(PeerId),
|
||||
/// Peer has disconnected.
|
||||
PeerDisconnected(PeerId),
|
||||
/// Message has been received.
|
||||
Message(Message),
|
||||
}
|
||||
|
||||
/// A message received from a peer.
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub struct Message {
|
||||
/// Original peer who sent the message.
|
||||
pub src: PeerId,
|
||||
/// The data that has been sent.
|
||||
pub data: Vec<u8>,
|
||||
}
|
||||
|
||||
/// Current peer state
|
||||
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum PeerState {
|
||||
/// Waiting for connection. Switches to `Connected` right after id from the host has been acquired.
|
||||
/// Note: hosts switches to 'Connected' basically instantly.
|
||||
#[default]
|
||||
PendingConnection,
|
||||
/// Connected to host and ready to send/receive messages.
|
||||
Connected,
|
||||
/// No longer connected, won't reconnect.
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
impl Display for PeerState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
PeerState::PendingConnection => write!(f, "Connection pending..."),
|
||||
PeerState::Connected => write!(f, "Connected"),
|
||||
PeerState::Disconnected => write!(f, "Disconnected"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PeerId {
|
||||
pub const HOST: PeerId = PeerId(0);
|
||||
}
|
||||
|
||||
impl Display for PeerId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Settings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
confirm_max_per_message: 128,
|
||||
confirm_max_period: Duration::from_secs(1),
|
||||
connection_timeout: Duration::from_secs(10),
|
||||
}
|
||||
}
|
||||
}
|
250
noita-proxy/tangled/src/connection_manager.rs
Normal file
250
noita-proxy/tangled/src/connection_manager.rs
Normal file
|
@ -0,0 +1,250 @@
|
|||
use std::{
|
||||
io,
|
||||
net::SocketAddr,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use crossbeam::{
|
||||
atomic::AtomicCell,
|
||||
channel::{unbounded, Receiver, Sender},
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use quinn::{
|
||||
crypto::rustls::QuicClientConfig,
|
||||
rustls::{
|
||||
self,
|
||||
pki_types::{CertificateDer, PrivatePkcs8KeyDer},
|
||||
},
|
||||
ClientConfig, ConnectError, Connecting, ConnectionError, Endpoint, Incoming, ServerConfig,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::{
|
||||
common::{Destination, NetworkEvent, PeerId, PeerState, Reliability, Settings},
|
||||
helpers::SkipServerVerification,
|
||||
};
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct RemotePeer;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum DirectConnectionError {
|
||||
#[error("QUIC Connection error: {0}")]
|
||||
QUICConnectionError(#[from] ConnectionError),
|
||||
#[error("Initial exchange failed")]
|
||||
InitialExchangeFailed,
|
||||
}
|
||||
|
||||
struct DirectPeer {
|
||||
my_id: PeerId,
|
||||
remote_id: PeerId,
|
||||
streams: (quinn::SendStream, quinn::RecvStream),
|
||||
}
|
||||
|
||||
impl DirectPeer {
|
||||
async fn accept(
|
||||
incoming: Incoming,
|
||||
assigned_peer_id: PeerId,
|
||||
) -> Result<Self, DirectConnectionError> {
|
||||
let connection = incoming
|
||||
.await
|
||||
.inspect_err(|err| warn!("Failed to accept connection: {err}"))?;
|
||||
|
||||
let mut sender = connection
|
||||
.open_uni()
|
||||
.await
|
||||
.inspect_err(|err| warn!("Failed to get send stream: {err}"))?;
|
||||
sender
|
||||
.write_u16(assigned_peer_id.0)
|
||||
.await
|
||||
.map_err(|_err| DirectConnectionError::InitialExchangeFailed)?;
|
||||
|
||||
let streams = connection.open_bi().await?;
|
||||
|
||||
Ok(Self {
|
||||
my_id: PeerId::HOST,
|
||||
remote_id: assigned_peer_id,
|
||||
streams,
|
||||
})
|
||||
}
|
||||
|
||||
async fn connect(connection: Connecting) -> Result<Self, DirectConnectionError> {
|
||||
let connection = connection
|
||||
.await
|
||||
.inspect_err(|err| warn!("Failed to initiate connection: {err}"))?;
|
||||
|
||||
let mut receiver = connection.accept_uni().await?;
|
||||
let peer_id = receiver
|
||||
.read_u16()
|
||||
.await
|
||||
.map_err(|_err| DirectConnectionError::InitialExchangeFailed)?;
|
||||
|
||||
let streams = connection.accept_bi().await?;
|
||||
|
||||
Ok(Self {
|
||||
my_id: PeerId(peer_id),
|
||||
remote_id: PeerId::HOST,
|
||||
streams,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type SeqId = u16;
|
||||
|
||||
pub(crate) struct OutboundMessage {
|
||||
pub dst: Destination,
|
||||
pub reliability: Reliability,
|
||||
pub data: Vec<u8>,
|
||||
}
|
||||
|
||||
pub(crate) type Channel<T> = (Sender<T>, Receiver<T>);
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TangledInitError {
|
||||
#[error("Could not create endpoint.\nReason: {0}")]
|
||||
CouldNotCreateEndpoint(io::Error),
|
||||
#[error("Could not connect to host.\nReason: {0}")]
|
||||
CouldNotConnectToHost(ConnectError),
|
||||
}
|
||||
|
||||
pub(crate) struct Shared {
|
||||
pub settings: Settings,
|
||||
pub inbound_channel: Channel<NetworkEvent>,
|
||||
pub outbound_channel: Channel<OutboundMessage>,
|
||||
pub keep_alive: AtomicBool,
|
||||
pub peer_state: AtomicCell<PeerState>,
|
||||
pub remote_peers: DashMap<PeerId, RemotePeer>,
|
||||
pub host_addr: Option<SocketAddr>,
|
||||
pub my_id: AtomicCell<Option<PeerId>>,
|
||||
// ConnectionManager-specific stuff
|
||||
direct_peers: DashMap<PeerId, DirectPeer>,
|
||||
}
|
||||
|
||||
impl Shared {
|
||||
pub(crate) fn new(host_addr: Option<SocketAddr>, settings: Option<Settings>) -> Self {
|
||||
Self {
|
||||
inbound_channel: unbounded(),
|
||||
outbound_channel: unbounded(),
|
||||
keep_alive: AtomicBool::new(true),
|
||||
host_addr,
|
||||
peer_state: Default::default(),
|
||||
remote_peers: Default::default(),
|
||||
my_id: AtomicCell::new(if host_addr.is_none() {
|
||||
Some(PeerId(0))
|
||||
} else {
|
||||
None
|
||||
}),
|
||||
settings: settings.unwrap_or_default(),
|
||||
direct_peers: DashMap::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ConnectionManager {
|
||||
shared: Arc<Shared>,
|
||||
endpoint: Endpoint,
|
||||
host_conn: Option<DirectPeer>,
|
||||
is_server: bool,
|
||||
}
|
||||
|
||||
impl ConnectionManager {
|
||||
pub(crate) fn new(shared: Arc<Shared>, addr: SocketAddr) -> Result<Self, TangledInitError> {
|
||||
let is_server = shared.host_addr.is_none();
|
||||
|
||||
let config = default_server_config();
|
||||
|
||||
let mut endpoint = if is_server {
|
||||
Endpoint::server(config, addr).map_err(TangledInitError::CouldNotCreateEndpoint)?
|
||||
} else {
|
||||
Endpoint::client(addr).map_err(TangledInitError::CouldNotCreateEndpoint)?
|
||||
};
|
||||
|
||||
endpoint.set_default_client_config(ClientConfig::new(Arc::new(
|
||||
QuicClientConfig::try_from(
|
||||
rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(SkipServerVerification::new())
|
||||
.with_no_client_auth(),
|
||||
)
|
||||
.unwrap(),
|
||||
)));
|
||||
|
||||
Ok(Self {
|
||||
shared,
|
||||
is_server,
|
||||
endpoint,
|
||||
host_conn: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn accept_connections(shared: Arc<Shared>, endpoint: Endpoint) {
|
||||
let mut peer_id_counter = 1;
|
||||
while shared.keep_alive.load(Ordering::Relaxed) {
|
||||
let Some(incoming) = endpoint.accept().await else {
|
||||
info!("Endpoint closed, stopping connection accepter task.");
|
||||
return;
|
||||
};
|
||||
match DirectPeer::accept(incoming, PeerId(peer_id_counter)).await {
|
||||
Ok(direct_peer) => {
|
||||
shared
|
||||
.direct_peers
|
||||
.insert(PeerId(peer_id_counter), direct_peer);
|
||||
peer_id_counter += 1;
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Failed to accept connection: {err}")
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async fn astart(mut self, host_conn: Option<Connecting>) {
|
||||
if let Some(host_conn) = host_conn {
|
||||
match DirectPeer::connect(host_conn).await {
|
||||
Ok(host_conn) => {
|
||||
self.host_conn = Some(host_conn);
|
||||
}
|
||||
Err(err) => {
|
||||
error!("Could not connect to host: {}", err);
|
||||
self.shared.peer_state.store(PeerState::Disconnected);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
if self.is_server {
|
||||
let endpoint = self.endpoint.clone();
|
||||
tokio::spawn(Self::accept_connections(self.shared.clone(), endpoint));
|
||||
info!("Started connection acceptor task");
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn start(self) -> Result<(), TangledInitError> {
|
||||
let host_conn = self
|
||||
.shared
|
||||
.host_addr
|
||||
.as_ref()
|
||||
.map(|host_addr| {
|
||||
self.endpoint
|
||||
.connect(*host_addr, "tangled")
|
||||
.map_err(TangledInitError::CouldNotConnectToHost)
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
tokio::spawn(self.astart(host_conn));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn default_server_config() -> ServerConfig {
|
||||
let cert = rcgen::generate_simple_self_signed(vec!["tangled".into()]).unwrap();
|
||||
let cert_der = CertificateDer::from(cert.cert);
|
||||
let priv_key = PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der());
|
||||
|
||||
let config = ServerConfig::with_single_cert(vec![cert_der.clone()], priv_key.into()).unwrap();
|
||||
config
|
||||
}
|
60
noita-proxy/tangled/src/helpers.rs
Normal file
60
noita-proxy/tangled/src/helpers.rs
Normal file
|
@ -0,0 +1,60 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use quinn::rustls::{
|
||||
self,
|
||||
pki_types::{CertificateDer, ServerName, UnixTime},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct SkipServerVerification(Arc<rustls::crypto::CryptoProvider>);
|
||||
|
||||
impl SkipServerVerification {
|
||||
pub(crate) fn new() -> Arc<Self> {
|
||||
Arc::new(Self(Arc::new(rustls::crypto::ring::default_provider())))
|
||||
}
|
||||
}
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &ServerName<'_>,
|
||||
_ocsp: &[u8],
|
||||
_now: UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &CertificateDer<'_>,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
rustls::crypto::verify_tls12_signature(
|
||||
message,
|
||||
cert,
|
||||
dss,
|
||||
&self.0.signature_verification_algorithms,
|
||||
)
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &CertificateDer<'_>,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
rustls::crypto::verify_tls13_signature(
|
||||
message,
|
||||
cert,
|
||||
dss,
|
||||
&self.0.signature_verification_algorithms,
|
||||
)
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
self.0.signature_verification_algorithms.supported_schemes()
|
||||
}
|
||||
}
|
|
@ -5,100 +5,32 @@ use std::{
|
|||
io,
|
||||
net::{SocketAddr, UdpSocket},
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use connection_manager::{
|
||||
ConnectionManager, OutboundMessage, RemotePeer, Shared, TangledInitError,
|
||||
};
|
||||
use crossbeam::{
|
||||
self,
|
||||
atomic::AtomicCell,
|
||||
channel::{unbounded, Receiver, Sender},
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
pub use error::NetError;
|
||||
use reactor::{Destination, RemotePeer, Shared};
|
||||
pub use reactor::{Reliability, Settings};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const DATAGRAM_MAX_LEN: usize = 30000; // TODO this probably should be 1500
|
||||
|
||||
/// Maximum size of a message which fits into a single datagram.
|
||||
pub const MAX_MESSAGE_LEN: usize = DATAGRAM_MAX_LEN - 100;
|
||||
|
||||
mod common;
|
||||
mod connection_manager;
|
||||
mod error;
|
||||
mod reactor;
|
||||
mod util;
|
||||
mod helpers;
|
||||
|
||||
struct Datagram {
|
||||
pub size: usize,
|
||||
pub data: [u8; DATAGRAM_MAX_LEN],
|
||||
}
|
||||
|
||||
/// A value which refers to a specific peer.
|
||||
/// Peer 0 is always the host.
|
||||
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct PeerId(pub u16);
|
||||
|
||||
impl PeerId {
|
||||
pub const HOST: PeerId = PeerId(0);
|
||||
}
|
||||
|
||||
impl Display for PeerId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
type SeqId = u16;
|
||||
|
||||
/// Possible network events, returned by `Peer.recv()`.
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub enum NetworkEvent {
|
||||
/// A new peer has connected.
|
||||
PeerConnected(PeerId),
|
||||
/// Peer has disconnected.
|
||||
PeerDisconnected(PeerId),
|
||||
/// Message has been received.
|
||||
Message(Message),
|
||||
}
|
||||
|
||||
/// A message received from a peer.
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub struct Message {
|
||||
/// Original peer who sent the message.
|
||||
pub src: PeerId,
|
||||
/// The data that has been sent.
|
||||
pub data: Vec<u8>,
|
||||
}
|
||||
|
||||
struct OutboundMessage {
|
||||
pub dst: Destination,
|
||||
pub data: Vec<u8>,
|
||||
pub reliability: Reliability,
|
||||
}
|
||||
|
||||
/// Current peer state
|
||||
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum PeerState {
|
||||
/// Waiting for connection. Switches to `Connected` right after id from the host has been acquired.
|
||||
/// Note: hosts switches to 'Connected' basically instantly.
|
||||
#[default]
|
||||
PendingConnection,
|
||||
/// Connected to host and ready to send/receive messages.
|
||||
Connected,
|
||||
/// No longer connected, won't reconnect.
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
impl Display for PeerState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
PeerState::PendingConnection => write!(f, "Connection pending..."),
|
||||
PeerState::Connected => write!(f, "Connected"),
|
||||
PeerState::Disconnected => write!(f, "Disconnected"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Channel<T> = (Sender<T>, Receiver<T>);
|
||||
pub use common::*;
|
||||
|
||||
/// Represents a network endpoint. Can be constructed in either `host` or `client` mode.
|
||||
///
|
||||
|
@ -113,24 +45,8 @@ impl Peer {
|
|||
bind_addr: SocketAddr,
|
||||
host_addr: Option<SocketAddr>,
|
||||
settings: Option<Settings>,
|
||||
) -> io::Result<Self> {
|
||||
let socket = UdpSocket::bind(bind_addr)?;
|
||||
let shared = Arc::new(Shared {
|
||||
socket,
|
||||
inbound_channel: unbounded(),
|
||||
outbound_channel: unbounded(),
|
||||
keep_alive: AtomicBool::new(true),
|
||||
host_addr,
|
||||
peer_state: Default::default(),
|
||||
remote_peers: Default::default(),
|
||||
max_packets_per_second: 256,
|
||||
my_id: AtomicCell::new(if host_addr.is_none() {
|
||||
Some(PeerId(0))
|
||||
} else {
|
||||
None
|
||||
}),
|
||||
settings: settings.unwrap_or_default(),
|
||||
});
|
||||
) -> Result<Self, TangledInitError> {
|
||||
let shared = Arc::new(Shared::new(host_addr, settings));
|
||||
if host_addr.is_none() {
|
||||
shared.remote_peers.insert(PeerId(0), RemotePeer::default());
|
||||
shared
|
||||
|
@ -139,17 +55,23 @@ impl Peer {
|
|||
.send(NetworkEvent::PeerConnected(PeerId(0)))
|
||||
.unwrap();
|
||||
}
|
||||
reactor::Reactor::start(Arc::clone(&shared));
|
||||
ConnectionManager::new(Arc::clone(&shared), bind_addr)?.start()?;
|
||||
Ok(Peer { shared })
|
||||
}
|
||||
|
||||
/// Host at a specified `bind_addr`.
|
||||
pub fn host(bind_addr: SocketAddr, settings: Option<Settings>) -> io::Result<Self> {
|
||||
pub fn host(
|
||||
bind_addr: SocketAddr,
|
||||
settings: Option<Settings>,
|
||||
) -> Result<Self, TangledInitError> {
|
||||
Self::new(bind_addr, None, settings)
|
||||
}
|
||||
|
||||
/// Connect to a specified `host_addr`.
|
||||
pub fn connect(host_addr: SocketAddr, settings: Option<Settings>) -> io::Result<Self> {
|
||||
pub fn connect(
|
||||
host_addr: SocketAddr,
|
||||
settings: Option<Settings>,
|
||||
) -> Result<Self, TangledInitError> {
|
||||
Self::new("0.0.0.0:0".parse().unwrap(), Some(host_addr), settings)
|
||||
}
|
||||
|
||||
|
@ -176,11 +98,6 @@ impl Peer {
|
|||
if data.len() > MAX_MESSAGE_LEN {
|
||||
return Err(NetError::MessageTooLong);
|
||||
}
|
||||
if reliability == Reliability::Unreliable
|
||||
&& self.shared.outbound_channel.0.len() * 2 > self.shared.max_packets_per_second
|
||||
{
|
||||
return Err(NetError::Dropped);
|
||||
}
|
||||
self.shared.outbound_channel.0.send(OutboundMessage {
|
||||
dst: destination,
|
||||
data,
|
||||
|
@ -233,7 +150,7 @@ impl Drop for Peer {
|
|||
mod test {
|
||||
use std::{thread, time::Duration};
|
||||
|
||||
use crate::{reactor::Settings, Message, NetworkEvent, Peer, PeerId, Reliability};
|
||||
use crate::{common::Message, NetworkEvent, Peer, PeerId, Reliability, Settings};
|
||||
|
||||
#[test_log::test]
|
||||
fn test_peer() {
|
||||
|
|
|
@ -1,679 +0,0 @@
|
|||
use crate::{
|
||||
error::NetError,
|
||||
util::{RateLimiter, RingSet},
|
||||
Channel, Message, NetworkEvent, OutboundMessage, PeerId, SeqId,
|
||||
};
|
||||
|
||||
use super::{Datagram, PeerState, DATAGRAM_MAX_LEN};
|
||||
use crossbeam::{
|
||||
atomic::AtomicCell,
|
||||
channel::{bounded, Receiver, Sender},
|
||||
select,
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
error::Error,
|
||||
io::Cursor,
|
||||
net::{SocketAddr, UdpSocket},
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU16, Ordering::SeqCst},
|
||||
Arc,
|
||||
},
|
||||
thread,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tracing::{error, info, trace, warn};
|
||||
|
||||
/// Per-peer settings. Peers that are connected to the same host, as well as the host itself, should have the same settings.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Settings {
|
||||
/// A single datagram will confirm at most this much messages. Default is 128.
|
||||
pub confirm_max_per_message: usize,
|
||||
/// How much time can elapse before another confirm is sent.
|
||||
/// Confirms are also sent when enough messages are awaiting confirm.
|
||||
/// Note that confirms also double as "heartbeats" and keep the connection alive, so this value should be much less than `connection_timeout`.
|
||||
/// Default: 1 second.
|
||||
pub confirm_max_period: Duration,
|
||||
/// Peers will be disconnected after this much time without any datagrams from them has passed.
|
||||
/// Default: 10 seconds.
|
||||
pub connection_timeout: Duration,
|
||||
}
|
||||
|
||||
impl Default for Settings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
confirm_max_per_message: 128,
|
||||
confirm_max_period: Duration::from_secs(1),
|
||||
connection_timeout: Duration::from_secs(10),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct Shared {
|
||||
pub settings: Settings,
|
||||
pub socket: UdpSocket,
|
||||
pub inbound_channel: Channel<NetworkEvent>,
|
||||
pub outbound_channel: Channel<OutboundMessage>,
|
||||
pub keep_alive: AtomicBool,
|
||||
pub peer_state: AtomicCell<PeerState>,
|
||||
pub remote_peers: DashMap<PeerId, RemotePeer>,
|
||||
pub max_packets_per_second: usize,
|
||||
pub host_addr: Option<SocketAddr>,
|
||||
pub my_id: AtomicCell<Option<PeerId>>,
|
||||
}
|
||||
|
||||
struct DirectPeer {
|
||||
addr: SocketAddr,
|
||||
outbound_pending: VecDeque<NetMessageVariant>,
|
||||
resend_pending: VecDeque<(Instant, NetMessageNormal)>,
|
||||
confirmed: RingSet<SeqId>,
|
||||
rate_limit: RateLimiter,
|
||||
seq_counter: AtomicU16,
|
||||
recent_seq: RingSet<SeqId>,
|
||||
pending_confirms: VecDeque<SeqId>,
|
||||
last_confirm_sent: Instant,
|
||||
last_seen: Instant,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct RemotePeer {}
|
||||
|
||||
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
|
||||
pub enum Destination {
|
||||
One(PeerId),
|
||||
Broadcast,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
enum NetMessageVariant {
|
||||
Login,
|
||||
Normal(NetMessageNormal),
|
||||
}
|
||||
|
||||
/// Tells how reliable a message is.
|
||||
#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Debug)]
|
||||
pub enum Reliability {
|
||||
/// Message will be delivered at most once.
|
||||
Unreliable,
|
||||
/// Message will be resent untill is's arrival will be confirmed.
|
||||
/// Will be delivered at most once.
|
||||
Reliable,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
struct NetMessageNormal {
|
||||
// Source that generated sequence id.
|
||||
// Initially the same as origin_src, but can be changed when packet is retransmitted not as-is, e. g. when it is broadcasted.
|
||||
src: PeerId,
|
||||
// Original source.
|
||||
origin_src: PeerId,
|
||||
dst: Destination,
|
||||
seq_id: SeqId,
|
||||
reliability: Reliability,
|
||||
inner: NetMessageInner,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
enum NetMessageInner {
|
||||
RegDone { addr: SocketAddr },
|
||||
AddPeer { id: PeerId },
|
||||
DelPeer { id: PeerId },
|
||||
Confirm { confirmed_ids: Vec<SeqId> },
|
||||
Payload { data: Vec<u8> },
|
||||
}
|
||||
|
||||
impl TryFrom<Datagram> for NetMessageVariant {
|
||||
type Error = bincode::Error;
|
||||
|
||||
fn try_from(datagram: Datagram) -> Result<Self, Self::Error> {
|
||||
bincode::deserialize(&datagram.data[..datagram.size])
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&NetMessageVariant> for Datagram {
|
||||
type Error = bincode::Error;
|
||||
|
||||
fn try_from(value: &NetMessageVariant) -> Result<Self, Self::Error> {
|
||||
let mut data = Cursor::new([0; DATAGRAM_MAX_LEN]);
|
||||
bincode::serialize_into(&mut data, value)?;
|
||||
let size = data.position().try_into().unwrap();
|
||||
let data = data.into_inner();
|
||||
Ok(Datagram { data, size })
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct Reactor {
|
||||
shared: Arc<Shared>,
|
||||
direct_peers: HashMap<PeerId, DirectPeer>,
|
||||
}
|
||||
|
||||
type AddrDatagram = (SocketAddr, Datagram);
|
||||
|
||||
impl Reactor {
|
||||
fn add_peer(&self, id: PeerId) -> Result<(), NetError> {
|
||||
self.shared.remote_peers.insert(id, RemotePeer::default());
|
||||
self.shared
|
||||
.inbound_channel
|
||||
.0
|
||||
.send(NetworkEvent::PeerConnected(id))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn direct_broadcast(
|
||||
&mut self,
|
||||
src_id: PeerId,
|
||||
msg: NetMessageInner,
|
||||
reliability: Reliability,
|
||||
) -> Result<(), NetError> {
|
||||
for (&peer_id, peer) in self.direct_peers.iter_mut() {
|
||||
let new_seq_id = peer.seq_counter.fetch_add(1, SeqCst);
|
||||
let new_msg = Self::wrap_packet_seq_id(
|
||||
src_id,
|
||||
src_id,
|
||||
new_seq_id,
|
||||
Destination::One(peer_id),
|
||||
msg.clone(),
|
||||
reliability,
|
||||
)?;
|
||||
Self::direct_send_peer(peer, new_msg)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn direct_send(&mut self, id: PeerId, msg: NetMessageVariant) -> Result<(), NetError> {
|
||||
let peer = self
|
||||
.direct_peers
|
||||
.get_mut(&id)
|
||||
.ok_or(NetError::UnknownPeer)?;
|
||||
Self::direct_send_peer(peer, msg)
|
||||
}
|
||||
|
||||
fn direct_send_peer(peer: &mut DirectPeer, msg: NetMessageVariant) -> Result<(), NetError> {
|
||||
peer.outbound_pending.push_back(msg);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn gen_peer_id(&mut self) -> Option<PeerId> {
|
||||
(1..=u16::MAX)
|
||||
.map(PeerId)
|
||||
.find(|i| !self.shared.remote_peers.contains_key(i))
|
||||
}
|
||||
|
||||
fn handle_inbound(&mut self, (incoming_addr, msg_raw): AddrDatagram) {
|
||||
let msg = match NetMessageVariant::try_from(msg_raw) {
|
||||
Ok(msg) => msg,
|
||||
Err(err) => {
|
||||
warn!("Error when converting to NetMessage: {}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
match self.shared.my_id.load() {
|
||||
Some(id) => {
|
||||
match msg {
|
||||
NetMessageVariant::Login => {
|
||||
if self.is_host() {
|
||||
//TODO check this addr is not already registered
|
||||
match self.gen_peer_id() {
|
||||
Some(new_id) => {
|
||||
self.add_peer(new_id).ok();
|
||||
let mut peer = DirectPeer::new(
|
||||
incoming_addr,
|
||||
self.shared.max_packets_per_second,
|
||||
);
|
||||
peer.outbound_pending.push_back(NetMessageVariant::Normal(
|
||||
NetMessageNormal {
|
||||
src: id,
|
||||
origin_src: id,
|
||||
dst: Destination::One(new_id),
|
||||
seq_id: u16::MAX,
|
||||
inner: NetMessageInner::RegDone {
|
||||
addr: incoming_addr,
|
||||
},
|
||||
reliability: Reliability::Reliable,
|
||||
},
|
||||
));
|
||||
self.direct_peers.insert(new_id, peer);
|
||||
self.direct_broadcast(
|
||||
id,
|
||||
NetMessageInner::AddPeer { id: new_id },
|
||||
Reliability::Reliable,
|
||||
)
|
||||
.ok();
|
||||
let shared = self.shared.clone();
|
||||
for re in shared.remote_peers.iter() {
|
||||
let id = *re.key();
|
||||
if id != new_id {
|
||||
self.wrap_packet(
|
||||
id,
|
||||
Destination::One(new_id),
|
||||
NetMessageInner::AddPeer { id },
|
||||
Reliability::Reliable,
|
||||
)
|
||||
.and_then(|msg| self.direct_send(new_id, msg))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
None => warn!("Out of ids"),
|
||||
}
|
||||
} else {
|
||||
warn!("Not a host, registration attempt ignored");
|
||||
}
|
||||
}
|
||||
NetMessageVariant::Normal(msg) => {
|
||||
match self.handle_inbound_normal(msg, incoming_addr, id) {
|
||||
Ok(_) => {}
|
||||
Err(NetError::Dropped) => {}
|
||||
Err(err) => {
|
||||
info!("Error while handling normal inbound message: {}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None => match msg {
|
||||
NetMessageVariant::Normal(NetMessageNormal {
|
||||
inner: NetMessageInner::RegDone { addr: _ },
|
||||
dst,
|
||||
src,
|
||||
..
|
||||
}) => {
|
||||
let expected_host_addr = self
|
||||
.shared
|
||||
.host_addr
|
||||
.expect("Can't have both my_id and host_addr be None");
|
||||
if incoming_addr == expected_host_addr && src == PeerId(0) {
|
||||
if let Destination::One(id) = dst {
|
||||
self.shared.my_id.store(Some(id));
|
||||
self.add_peer(PeerId(0)).ok();
|
||||
self.shared.peer_state.store(PeerState::Connected);
|
||||
} else {
|
||||
warn!("Malformed registration message");
|
||||
}
|
||||
} else {
|
||||
warn!("Registration message recieved not from the right address ({}, {} expected)", incoming_addr, expected_host_addr);
|
||||
}
|
||||
}
|
||||
_ => warn!("Message ignored as registration is not done yet"),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_inbound_normal(
|
||||
&mut self,
|
||||
msg: NetMessageNormal,
|
||||
_incoming_addr: SocketAddr,
|
||||
my_id: PeerId,
|
||||
) -> Result<(), NetError> {
|
||||
let peer = self.direct_peers.get_mut(&msg.src);
|
||||
if peer
|
||||
.as_ref()
|
||||
.map_or(true, |peer| peer.recent_seq.contains(&msg.seq_id))
|
||||
{
|
||||
return Err(NetError::Dropped);
|
||||
}
|
||||
{
|
||||
let peer = peer.expect("Expected to exist");
|
||||
peer.recent_seq.add(msg.seq_id); //TODO backpressure
|
||||
peer.pending_confirms.push_back(msg.seq_id);
|
||||
peer.last_seen = Instant::now()
|
||||
}
|
||||
|
||||
if Destination::One(my_id) == msg.dst || msg.dst == Destination::Broadcast {
|
||||
// TODO eliminate this clone
|
||||
match msg.inner.clone() {
|
||||
NetMessageInner::RegDone { addr: _ } => {
|
||||
warn!("Already registered, request ignored");
|
||||
}
|
||||
NetMessageInner::AddPeer { id } => {
|
||||
if !self.is_host() {
|
||||
self.add_peer(id).ok();
|
||||
info!("Peer {} added", id);
|
||||
}
|
||||
}
|
||||
NetMessageInner::DelPeer { id } => {
|
||||
if !self.is_host() {
|
||||
self.del_peer(id).ok();
|
||||
info!("Peer {} removed", id);
|
||||
}
|
||||
}
|
||||
NetMessageInner::Confirm { confirmed_ids } => {
|
||||
if let Some(peer) = self.direct_peers.get_mut(&msg.src) {
|
||||
for id in confirmed_ids {
|
||||
peer.confirmed.add(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
NetMessageInner::Payload { data } => {
|
||||
self.shared
|
||||
.inbound_channel
|
||||
.0
|
||||
.send(NetworkEvent::Message(Message {
|
||||
src: msg.origin_src,
|
||||
data,
|
||||
}))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
if self.is_host() && Destination::One(my_id) != msg.dst {
|
||||
match msg.dst {
|
||||
Destination::One(dst) => {
|
||||
let new_msg =
|
||||
self.wrap_packet(dst, Destination::One(dst), msg.inner, msg.reliability)?;
|
||||
self.direct_send(dst, new_msg)?;
|
||||
}
|
||||
Destination::Broadcast => {
|
||||
let mut buf = Vec::new();
|
||||
for peer in &self.direct_peers {
|
||||
if *peer.0 == msg.src {
|
||||
continue;
|
||||
}
|
||||
let seq_id = self.next_seq_id_for_peer(*peer.0)?;
|
||||
if let Ok(wrapped_msg) = Self::wrap_packet_seq_id(
|
||||
PeerId(0),
|
||||
msg.origin_src,
|
||||
seq_id,
|
||||
Destination::One(*peer.0),
|
||||
msg.inner.clone(),
|
||||
msg.reliability,
|
||||
) {
|
||||
buf.push((*peer.0, wrapped_msg));
|
||||
}
|
||||
}
|
||||
for (peer_id, wrapped_msg) in buf {
|
||||
self.direct_send(peer_id, wrapped_msg).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn del_peer(&mut self, id: PeerId) -> Result<(), NetError> {
|
||||
self.shared.remote_peers.remove(&id);
|
||||
self.shared
|
||||
.inbound_channel
|
||||
.0
|
||||
.send(NetworkEvent::PeerDisconnected(id))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_outbound(&mut self, msg: OutboundMessage) -> Result<(), NetError> {
|
||||
let dst = msg.dst;
|
||||
if self.is_host() {
|
||||
match dst {
|
||||
Destination::One(id) => {
|
||||
let net_msg = self.wrap_packet(
|
||||
id,
|
||||
dst,
|
||||
NetMessageInner::Payload { data: msg.data },
|
||||
msg.reliability,
|
||||
)?;
|
||||
self.direct_send(id, net_msg)?;
|
||||
}
|
||||
Destination::Broadcast => self.direct_broadcast(
|
||||
PeerId(0),
|
||||
NetMessageInner::Payload { data: msg.data },
|
||||
msg.reliability,
|
||||
)?,
|
||||
}
|
||||
} else {
|
||||
let net_msg = self.wrap_packet(
|
||||
PeerId(0),
|
||||
dst,
|
||||
NetMessageInner::Payload { data: msg.data },
|
||||
msg.reliability,
|
||||
)?;
|
||||
self.direct_send(PeerId(0), net_msg)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn is_host(&self) -> bool {
|
||||
self.shared.host_addr.is_none()
|
||||
}
|
||||
|
||||
pub fn next_seq_id_for_peer(&self, peer_id: PeerId) -> Result<SeqId, NetError> {
|
||||
Ok(self
|
||||
.direct_peers
|
||||
.get(&peer_id)
|
||||
.or_else(|| {
|
||||
if !self.is_host() {
|
||||
self.direct_peers.get(&PeerId(0))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok_or(NetError::UnknownPeer)?
|
||||
.seq_counter
|
||||
.fetch_add(1, SeqCst))
|
||||
}
|
||||
|
||||
fn run(mut self, inbound_r: Receiver<AddrDatagram>) -> Result<(), Box<dyn Error>> {
|
||||
while self.shared.keep_alive.load(SeqCst) {
|
||||
select! {
|
||||
recv(inbound_r) -> addr_msg => self.handle_inbound(addr_msg?),
|
||||
recv(self.shared.outbound_channel.1) -> msg => {self.handle_outbound(msg?).ok();}
|
||||
default => {thread::sleep(Duration::from_micros(100));}
|
||||
}
|
||||
let mut dc = Vec::new();
|
||||
self.direct_peers.retain(|&k, v| {
|
||||
let stays = v.last_seen.elapsed() < self.shared.settings.connection_timeout;
|
||||
if !stays {
|
||||
dc.push(k);
|
||||
}
|
||||
stays
|
||||
});
|
||||
if self.is_host() {
|
||||
for peer_id in dc {
|
||||
let src_id = self.shared.my_id.load().unwrap(); // Should always be PeerId(0)
|
||||
assert_eq!(src_id, PeerId(0));
|
||||
self.direct_broadcast(
|
||||
src_id,
|
||||
NetMessageInner::DelPeer { id: peer_id },
|
||||
Reliability::Reliable,
|
||||
)?;
|
||||
self.del_peer(peer_id).ok();
|
||||
info!("[Host] Peer {} removed", peer_id);
|
||||
}
|
||||
}
|
||||
if !self.is_host() && self.direct_peers.is_empty() {
|
||||
self.shared.peer_state.store(PeerState::Disconnected);
|
||||
self.shared.keep_alive.store(false, SeqCst);
|
||||
}
|
||||
'peers: for (&id, peer) in self.direct_peers.iter_mut() {
|
||||
let resend_in = Instant::now() + Duration::from_secs(1);
|
||||
|
||||
if let Some(my_id) = self.shared.my_id.load() {
|
||||
if peer.last_confirm_sent.elapsed() > self.shared.settings.confirm_max_period
|
||||
|| peer.pending_confirms.len()
|
||||
> self.shared.settings.confirm_max_per_message
|
||||
{
|
||||
peer.last_confirm_sent = Instant::now();
|
||||
let max_per_message = self.shared.settings.confirm_max_per_message;
|
||||
let mut confirmed_ids = Vec::with_capacity(max_per_message);
|
||||
while let Some(confirm) = peer.pending_confirms.pop_front() {
|
||||
confirmed_ids.push(confirm);
|
||||
if confirmed_ids.len() == max_per_message {
|
||||
break;
|
||||
}
|
||||
}
|
||||
peer.resend_pending.push_front((
|
||||
Instant::now(),
|
||||
NetMessageNormal {
|
||||
src: my_id,
|
||||
origin_src: my_id,
|
||||
dst: Destination::One(id),
|
||||
seq_id: peer.seq_counter.fetch_add(1, SeqCst),
|
||||
reliability: Reliability::Reliable,
|
||||
inner: NetMessageInner::Confirm { confirmed_ids },
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
while peer
|
||||
.resend_pending
|
||||
.front()
|
||||
.map_or(false, |x| x.0 < Instant::now())
|
||||
{
|
||||
let (moment, msg) = peer
|
||||
.resend_pending
|
||||
.pop_front()
|
||||
.expect("Checked that deque is not empty");
|
||||
|
||||
if !peer.confirmed.contains(&msg.seq_id) {
|
||||
if !peer.rate_limit.get_token() {
|
||||
peer.resend_pending.push_front((moment, msg));
|
||||
continue 'peers;
|
||||
}
|
||||
peer.resend_pending.push_back((resend_in, msg.clone()));
|
||||
trace!("Sent {:?} to {}", msg, peer.addr,);
|
||||
let datagram = Datagram::try_from(&NetMessageVariant::Normal(msg)).unwrap();
|
||||
trace!("size: {}", datagram.size);
|
||||
self.shared
|
||||
.socket
|
||||
.send_to(&datagram.data[..datagram.size], peer.addr)
|
||||
.expect("Could not send");
|
||||
}
|
||||
}
|
||||
|
||||
while !peer.outbound_pending.is_empty() && peer.rate_limit.get_token() {
|
||||
let msg = peer
|
||||
.outbound_pending
|
||||
.pop_front()
|
||||
.expect("Checked that deque is not empty");
|
||||
if let NetMessageVariant::Normal(ref msg) = msg {
|
||||
if msg.reliability == Reliability::Reliable {
|
||||
peer.resend_pending.push_back((resend_in, msg.clone()));
|
||||
}
|
||||
}
|
||||
let datagram = Datagram::try_from(&msg).unwrap();
|
||||
trace!("sent msg size: {}", datagram.size);
|
||||
self.shared
|
||||
.socket
|
||||
.send_to(&datagram.data[..datagram.size], peer.addr)
|
||||
.expect("Could not send");
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_pipe(
|
||||
shared: Arc<Shared>,
|
||||
sender: Sender<(SocketAddr, Datagram)>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
while shared.keep_alive.load(SeqCst) {
|
||||
let mut buf = [0u8; DATAGRAM_MAX_LEN];
|
||||
match shared.socket.recv_from(&mut buf) {
|
||||
Ok((len, addr)) => sender
|
||||
.send((
|
||||
addr,
|
||||
Datagram {
|
||||
size: len,
|
||||
data: buf,
|
||||
},
|
||||
))
|
||||
.map_err(Box::new)?,
|
||||
//Err(err)
|
||||
// if err.kind() == ErrorKind::WouldBlock || err.kind() == ErrorKind::TimedOut => {
|
||||
//}
|
||||
Err(err) => return Err(Box::new(err)),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn start(shared: Arc<Shared>) {
|
||||
let mut me = Reactor {
|
||||
shared,
|
||||
direct_peers: Default::default(),
|
||||
};
|
||||
if !me.is_host() {
|
||||
me.direct_peers.insert(
|
||||
PeerId(0),
|
||||
DirectPeer::new(
|
||||
me.shared
|
||||
.host_addr
|
||||
.expect("Can't be a client without a host addr"),
|
||||
me.shared.max_packets_per_second,
|
||||
),
|
||||
);
|
||||
me.direct_send(PeerId(0), NetMessageVariant::Login).unwrap();
|
||||
}
|
||||
if me.is_host() {
|
||||
me.shared.peer_state.store(PeerState::Connected);
|
||||
}
|
||||
let shared_c = Arc::clone(&me.shared);
|
||||
let (inbound_s, inbound_r) = bounded(16);
|
||||
thread::spawn(move || {
|
||||
let shared_c_2 = Arc::clone(&shared_c);
|
||||
if let Err(err) = Self::run_pipe(shared_c_2, inbound_s) {
|
||||
shared_c.keep_alive.store(false, SeqCst);
|
||||
shared_c.peer_state.store(PeerState::Disconnected);
|
||||
error!("Reactor pipe error: {}", err);
|
||||
}
|
||||
});
|
||||
let shared_c = Arc::clone(&me.shared);
|
||||
thread::spawn(move || {
|
||||
if let Err(err) = me.run(inbound_r) {
|
||||
shared_c.keep_alive.store(false, SeqCst);
|
||||
shared_c.peer_state.store(PeerState::Disconnected);
|
||||
error!("Reactor error: {}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn wrap_packet(
|
||||
&self,
|
||||
id: PeerId,
|
||||
dst: Destination,
|
||||
msg: NetMessageInner,
|
||||
reliability: Reliability,
|
||||
) -> Result<NetMessageVariant, NetError> {
|
||||
let seq_id = self.next_seq_id_for_peer(id)?;
|
||||
let src = self.shared.my_id.load().expect("Should know own id by now");
|
||||
Self::wrap_packet_seq_id(src, src, seq_id, dst, msg, reliability)
|
||||
}
|
||||
|
||||
fn wrap_packet_seq_id(
|
||||
src: PeerId,
|
||||
origin_src: PeerId,
|
||||
seq_id: SeqId,
|
||||
dst: Destination,
|
||||
msg: NetMessageInner,
|
||||
reliability: Reliability,
|
||||
) -> Result<NetMessageVariant, NetError> {
|
||||
Ok(NetMessageVariant::Normal(NetMessageNormal {
|
||||
src,
|
||||
origin_src,
|
||||
dst,
|
||||
seq_id,
|
||||
inner: msg,
|
||||
reliability,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl DirectPeer {
|
||||
fn new(incoming_addr: SocketAddr, rate_limit: usize) -> DirectPeer {
|
||||
let now = Instant::now();
|
||||
DirectPeer {
|
||||
addr: incoming_addr,
|
||||
outbound_pending: Default::default(),
|
||||
resend_pending: Default::default(),
|
||||
confirmed: RingSet::new(1024),
|
||||
rate_limit: RateLimiter::new(rate_limit, Duration::from_secs(1)),
|
||||
seq_counter: AtomicU16::new(0),
|
||||
recent_seq: RingSet::new(1024),
|
||||
pending_confirms: VecDeque::new(),
|
||||
last_confirm_sent: now,
|
||||
last_seen: now,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,105 +0,0 @@
|
|||
use std::{
|
||||
collections::{HashSet, VecDeque},
|
||||
hash::Hash,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
pub struct RateLimiter {
|
||||
moments: VecDeque<Instant>,
|
||||
time: Duration,
|
||||
limit: usize,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
pub fn new(limit: usize, time: Duration) -> Self {
|
||||
Self {
|
||||
moments: VecDeque::with_capacity(limit),
|
||||
time,
|
||||
limit,
|
||||
}
|
||||
}
|
||||
pub fn get_token(&mut self) -> bool {
|
||||
let now = Instant::now();
|
||||
while self
|
||||
.moments
|
||||
.front()
|
||||
.map_or(false, |moment| now - *moment > self.time)
|
||||
{
|
||||
self.moments.pop_front();
|
||||
}
|
||||
if self.moments.len() < self.limit {
|
||||
self.moments.push_back(now);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RingSet<Key: Hash + Eq + Clone> {
|
||||
set: HashSet<Key>,
|
||||
ring: VecDeque<Key>,
|
||||
limit: usize,
|
||||
}
|
||||
|
||||
impl<Key: Hash + Eq + Clone> RingSet<Key> {
|
||||
pub fn new(limit: usize) -> Self {
|
||||
assert!(limit > 0);
|
||||
Self {
|
||||
set: HashSet::new(),
|
||||
ring: VecDeque::with_capacity(limit),
|
||||
limit,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add(&mut self, key: Key) {
|
||||
if !self.contains(&key) {
|
||||
if self.ring.len() >= self.limit {
|
||||
let element = self.ring.pop_front().expect("Deque has elements");
|
||||
self.set.remove(&element);
|
||||
}
|
||||
self.set.insert(key.clone());
|
||||
self.ring.push_back(key);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn contains(&self, key: &Key) -> bool {
|
||||
self.set.contains(key)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{thread, time::Duration};
|
||||
|
||||
use super::{RateLimiter, RingSet};
|
||||
|
||||
#[test]
|
||||
fn rate_limit() {
|
||||
let duration = Duration::from_micros(100);
|
||||
let mut limiter = RateLimiter::new(4, duration);
|
||||
|
||||
for _ in 0..4 {
|
||||
assert!(limiter.get_token())
|
||||
}
|
||||
assert!(!limiter.get_token());
|
||||
thread::sleep(duration * 2);
|
||||
assert!(limiter.get_token());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ring_set() {
|
||||
let mut set = RingSet::new(3);
|
||||
set.add(1);
|
||||
assert!(set.contains(&1));
|
||||
set.add(2);
|
||||
assert!(set.contains(&1));
|
||||
assert!(set.contains(&2));
|
||||
assert!(!set.contains(&3));
|
||||
set.add(3);
|
||||
set.add(3);
|
||||
set.add(4);
|
||||
assert!(!set.contains(&1));
|
||||
assert!(set.contains(&4));
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue