Make most tests pass

This commit is contained in:
IQuant 2024-09-06 20:39:24 +03:00
parent 868cc1ffaa
commit f4b5eef5fb
3 changed files with 14 additions and 21 deletions

View file

@ -78,7 +78,6 @@ impl DirectPeer {
shared shared
.internal_events_s .internal_events_s
.send(InternalEvent::Disconnected(remote_id)) .send(InternalEvent::Disconnected(remote_id))
.await
.ok(); .ok();
} }
@ -126,7 +125,7 @@ impl DirectPeer {
.map_err(|_err| DirectConnectionError::InitialExchangeFailed)?; .map_err(|_err| DirectConnectionError::InitialExchangeFailed)?;
debug!("Got peer id {peer_id}"); debug!("Got peer id {peer_id}");
let (send_stream, recv_stream) = connection.open_bi().await?; let (send_stream, recv_stream) = connection.accept_bi().await?;
tokio::spawn(Self::recv_task(shared, recv_stream, PeerId::HOST)); tokio::spawn(Self::recv_task(shared, recv_stream, PeerId::HOST));
debug!("Client: spawned recv task"); debug!("Client: spawned recv task");
@ -163,7 +162,7 @@ enum InternalEvent {
pub(crate) struct Shared { pub(crate) struct Shared {
pub inbound_channel: Channel<NetworkEvent>, pub inbound_channel: Channel<NetworkEvent>,
pub outbound_messages_s: tokio::sync::mpsc::Sender<OutboundMessage>, pub outbound_messages_s: tokio::sync::mpsc::UnboundedSender<OutboundMessage>,
pub keep_alive: AtomicBool, pub keep_alive: AtomicBool,
pub peer_state: AtomicCell<PeerState>, pub peer_state: AtomicCell<PeerState>,
pub remote_peers: DashMap<PeerId, RemotePeer>, pub remote_peers: DashMap<PeerId, RemotePeer>,
@ -172,7 +171,7 @@ pub(crate) struct Shared {
// ConnectionManager-specific stuff // ConnectionManager-specific stuff
direct_peers: DashMap<PeerId, DirectPeer>, direct_peers: DashMap<PeerId, DirectPeer>,
internal_incoming_messages_s: tokio::sync::mpsc::Sender<(PeerId, InternalMessage)>, internal_incoming_messages_s: tokio::sync::mpsc::Sender<(PeerId, InternalMessage)>,
internal_events_s: tokio::sync::mpsc::Sender<InternalEvent>, internal_events_s: tokio::sync::mpsc::UnboundedSender<InternalEvent>,
} }
pub(crate) struct ConnectionManager { pub(crate) struct ConnectionManager {
@ -181,8 +180,8 @@ pub(crate) struct ConnectionManager {
host_conn: Option<DirectPeer>, host_conn: Option<DirectPeer>,
is_server: bool, is_server: bool,
incoming_messages_r: tokio::sync::mpsc::Receiver<(PeerId, InternalMessage)>, incoming_messages_r: tokio::sync::mpsc::Receiver<(PeerId, InternalMessage)>,
outbound_messages_r: tokio::sync::mpsc::Receiver<OutboundMessage>, outbound_messages_r: tokio::sync::mpsc::UnboundedReceiver<OutboundMessage>,
internal_events_r: tokio::sync::mpsc::Receiver<InternalEvent>, internal_events_r: tokio::sync::mpsc::UnboundedReceiver<InternalEvent>,
} }
impl ConnectionManager { impl ConnectionManager {
@ -194,8 +193,8 @@ impl ConnectionManager {
let is_server = host_addr.is_none(); let is_server = host_addr.is_none();
let (internal_incoming_messages_s, incoming_messages_r) = tokio::sync::mpsc::channel(512); let (internal_incoming_messages_s, incoming_messages_r) = tokio::sync::mpsc::channel(512);
let (outbound_messages_s, outbound_messages_r) = tokio::sync::mpsc::channel(512); let (outbound_messages_s, outbound_messages_r) = tokio::sync::mpsc::unbounded_channel();
let (internal_events_s, internal_events_r) = tokio::sync::mpsc::channel(512); let (internal_events_s, internal_events_r) = tokio::sync::mpsc::unbounded_channel();
let shared = Arc::new(Shared { let shared = Arc::new(Shared {
inbound_channel: unbounded(), inbound_channel: unbounded(),
@ -254,7 +253,6 @@ impl ConnectionManager {
shared shared
.internal_events_s .internal_events_s
.send(InternalEvent::Connected(PeerId(peer_id_counter))) .send(InternalEvent::Connected(PeerId(peer_id_counter)))
.await
.expect("channel to be open"); .expect("channel to be open");
peer_id_counter += 1; peer_id_counter += 1;
} }
@ -280,20 +278,17 @@ impl ConnectionManager {
})) }))
.expect("channel to be open"); .expect("channel to be open");
} }
// TODO this might deadlock if internal_events_s is full.
InternalMessage::RemoteConnected(peer_id) => { InternalMessage::RemoteConnected(peer_id) => {
debug!("Got notified of peer {peer_id}"); debug!("Got notified of peer {peer_id}");
self.shared self.shared
.internal_events_s .internal_events_s
.send(InternalEvent::Connected(peer_id)) .send(InternalEvent::Connected(peer_id))
.await
.expect("channel to be open"); .expect("channel to be open");
} }
InternalMessage::RemoteDisconnected(peer_id) => self InternalMessage::RemoteDisconnected(peer_id) => self
.shared .shared
.internal_events_s .internal_events_s
.send(InternalEvent::Disconnected(peer_id)) .send(InternalEvent::Disconnected(peer_id))
.await
.expect("channel to be open"), .expect("channel to be open"),
} }
} }
@ -399,7 +394,6 @@ impl ConnectionManager {
self.shared self.shared
.internal_events_s .internal_events_s
.send(InternalEvent::Connected(host_conn.remote_id)) .send(InternalEvent::Connected(host_conn.remote_id))
.await
.expect("channel to be open"); .expect("channel to be open");
self.host_conn = Some(host_conn); self.host_conn = Some(host_conn);
} }

View file

@ -3,7 +3,7 @@ use std::marker::PhantomData;
use bitcode::{DecodeOwned, Encode}; use bitcode::{DecodeOwned, Encode};
use quinn::{RecvStream, SendStream}; use quinn::{RecvStream, SendStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{debug, trace}; use tracing::trace;
use super::DirectConnectionError; use super::DirectConnectionError;
@ -42,7 +42,6 @@ impl<Msg: Encode> SendMessageStream<Msg> {
} }
pub(crate) async fn send(&mut self, msg: &Msg) -> Result<(), DirectConnectionError> { pub(crate) async fn send(&mut self, msg: &Msg) -> Result<(), DirectConnectionError> {
trace!("Sending message");
let msg = bitcode::encode(msg); let msg = bitcode::encode(msg);
self.send_raw(&msg).await self.send_raw(&msg).await
} }
@ -62,7 +61,7 @@ impl<Msg: DecodeOwned> RecvMessageStream<Msg> {
.read_u32() .read_u32()
.await .await
.map_err(|_err| DirectConnectionError::MessageIoFailed)?; .map_err(|_err| DirectConnectionError::MessageIoFailed)?;
trace!("Expecting message of {len}"); trace!("Expecting message of len {len}");
let mut buf = vec![0; len as usize]; let mut buf = vec![0; len as usize];
self.inner self.inner
.read_exact(&mut buf) .read_exact(&mut buf)

View file

@ -91,7 +91,7 @@ impl Peer {
} }
self.shared self.shared
.outbound_messages_s .outbound_messages_s
.blocking_send(OutboundMessage { .send(OutboundMessage {
src: self.my_id().expect("expected to know my_id by this point"), src: self.my_id().expect("expected to know my_id by this point"),
dst: destination, dst: destination,
data, data,
@ -143,7 +143,7 @@ impl Drop for Peer {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::{thread, time::Duration}; use std::time::Duration;
use tracing::info; use tracing::info;
@ -203,14 +203,14 @@ mod test {
assert_eq!(host.shared.remote_peers.len(), 1); assert_eq!(host.shared.remote_peers.len(), 1);
let peer1 = Peer::connect(addr, settings.clone()).unwrap(); let peer1 = Peer::connect(addr, settings.clone()).unwrap();
let peer2 = Peer::connect(addr, settings.clone()).unwrap(); let peer2 = Peer::connect(addr, settings.clone()).unwrap();
thread::sleep(Duration::from_millis(10)); tokio::time::sleep(Duration::from_millis(10)).await;
assert_eq!(host.shared.remote_peers.len(), 3); assert_eq!(host.shared.remote_peers.len(), 3);
let data = vec![123, 112, 51, 23]; let data = vec![123, 112, 51, 23];
peer1 peer1
.broadcast(data.clone(), Reliability::Reliable) .broadcast(data.clone(), Reliability::Reliable)
.unwrap(); .unwrap();
thread::sleep(Duration::from_millis(10)); tokio::time::sleep(Duration::from_millis(10)).await;
let host_events: Vec<_> = dbg!(host.recv().collect()); let host_events: Vec<_> = dbg!(host.recv().collect());
let peer1_events: Vec<_> = dbg!(peer1.recv().collect()); let peer1_events: Vec<_> = dbg!(peer1.recv().collect());
@ -239,7 +239,7 @@ mod test {
}); });
let addr = "127.0.0.1:56003".parse().unwrap(); let addr = "127.0.0.1:56003".parse().unwrap();
let host = Peer::host(addr, settings.clone()).unwrap(); let host = Peer::host(addr, settings.clone()).unwrap();
thread::sleep(Duration::from_millis(10)); tokio::time::sleep(Duration::from_millis(10)).await;
assert_eq!( assert_eq!(
host.recv().next(), host.recv().next(),
Some(NetworkEvent::PeerConnected(PeerId(0))) Some(NetworkEvent::PeerConnected(PeerId(0)))