Make most tests pass

This commit is contained in:
IQuant 2024-09-06 20:39:24 +03:00
parent 868cc1ffaa
commit f4b5eef5fb
3 changed files with 14 additions and 21 deletions

View file

@ -78,7 +78,6 @@ impl DirectPeer {
shared
.internal_events_s
.send(InternalEvent::Disconnected(remote_id))
.await
.ok();
}
@ -126,7 +125,7 @@ impl DirectPeer {
.map_err(|_err| DirectConnectionError::InitialExchangeFailed)?;
debug!("Got peer id {peer_id}");
let (send_stream, recv_stream) = connection.open_bi().await?;
let (send_stream, recv_stream) = connection.accept_bi().await?;
tokio::spawn(Self::recv_task(shared, recv_stream, PeerId::HOST));
debug!("Client: spawned recv task");
@ -163,7 +162,7 @@ enum InternalEvent {
pub(crate) struct Shared {
pub inbound_channel: Channel<NetworkEvent>,
pub outbound_messages_s: tokio::sync::mpsc::Sender<OutboundMessage>,
pub outbound_messages_s: tokio::sync::mpsc::UnboundedSender<OutboundMessage>,
pub keep_alive: AtomicBool,
pub peer_state: AtomicCell<PeerState>,
pub remote_peers: DashMap<PeerId, RemotePeer>,
@ -172,7 +171,7 @@ pub(crate) struct Shared {
// ConnectionManager-specific stuff
direct_peers: DashMap<PeerId, DirectPeer>,
internal_incoming_messages_s: tokio::sync::mpsc::Sender<(PeerId, InternalMessage)>,
internal_events_s: tokio::sync::mpsc::Sender<InternalEvent>,
internal_events_s: tokio::sync::mpsc::UnboundedSender<InternalEvent>,
}
pub(crate) struct ConnectionManager {
@ -181,8 +180,8 @@ pub(crate) struct ConnectionManager {
host_conn: Option<DirectPeer>,
is_server: bool,
incoming_messages_r: tokio::sync::mpsc::Receiver<(PeerId, InternalMessage)>,
outbound_messages_r: tokio::sync::mpsc::Receiver<OutboundMessage>,
internal_events_r: tokio::sync::mpsc::Receiver<InternalEvent>,
outbound_messages_r: tokio::sync::mpsc::UnboundedReceiver<OutboundMessage>,
internal_events_r: tokio::sync::mpsc::UnboundedReceiver<InternalEvent>,
}
impl ConnectionManager {
@ -194,8 +193,8 @@ impl ConnectionManager {
let is_server = host_addr.is_none();
let (internal_incoming_messages_s, incoming_messages_r) = tokio::sync::mpsc::channel(512);
let (outbound_messages_s, outbound_messages_r) = tokio::sync::mpsc::channel(512);
let (internal_events_s, internal_events_r) = tokio::sync::mpsc::channel(512);
let (outbound_messages_s, outbound_messages_r) = tokio::sync::mpsc::unbounded_channel();
let (internal_events_s, internal_events_r) = tokio::sync::mpsc::unbounded_channel();
let shared = Arc::new(Shared {
inbound_channel: unbounded(),
@ -254,7 +253,6 @@ impl ConnectionManager {
shared
.internal_events_s
.send(InternalEvent::Connected(PeerId(peer_id_counter)))
.await
.expect("channel to be open");
peer_id_counter += 1;
}
@ -280,20 +278,17 @@ impl ConnectionManager {
}))
.expect("channel to be open");
}
// TODO this might deadlock if internal_events_s is full.
InternalMessage::RemoteConnected(peer_id) => {
debug!("Got notified of peer {peer_id}");
self.shared
.internal_events_s
.send(InternalEvent::Connected(peer_id))
.await
.expect("channel to be open");
}
InternalMessage::RemoteDisconnected(peer_id) => self
.shared
.internal_events_s
.send(InternalEvent::Disconnected(peer_id))
.await
.expect("channel to be open"),
}
}
@ -399,7 +394,6 @@ impl ConnectionManager {
self.shared
.internal_events_s
.send(InternalEvent::Connected(host_conn.remote_id))
.await
.expect("channel to be open");
self.host_conn = Some(host_conn);
}

View file

@ -3,7 +3,7 @@ use std::marker::PhantomData;
use bitcode::{DecodeOwned, Encode};
use quinn::{RecvStream, SendStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{debug, trace};
use tracing::trace;
use super::DirectConnectionError;
@ -42,7 +42,6 @@ impl<Msg: Encode> SendMessageStream<Msg> {
}
pub(crate) async fn send(&mut self, msg: &Msg) -> Result<(), DirectConnectionError> {
trace!("Sending message");
let msg = bitcode::encode(msg);
self.send_raw(&msg).await
}
@ -62,7 +61,7 @@ impl<Msg: DecodeOwned> RecvMessageStream<Msg> {
.read_u32()
.await
.map_err(|_err| DirectConnectionError::MessageIoFailed)?;
trace!("Expecting message of {len}");
trace!("Expecting message of len {len}");
let mut buf = vec![0; len as usize];
self.inner
.read_exact(&mut buf)

View file

@ -91,7 +91,7 @@ impl Peer {
}
self.shared
.outbound_messages_s
.blocking_send(OutboundMessage {
.send(OutboundMessage {
src: self.my_id().expect("expected to know my_id by this point"),
dst: destination,
data,
@ -143,7 +143,7 @@ impl Drop for Peer {
#[cfg(test)]
mod test {
use std::{thread, time::Duration};
use std::time::Duration;
use tracing::info;
@ -203,14 +203,14 @@ mod test {
assert_eq!(host.shared.remote_peers.len(), 1);
let peer1 = Peer::connect(addr, settings.clone()).unwrap();
let peer2 = Peer::connect(addr, settings.clone()).unwrap();
thread::sleep(Duration::from_millis(10));
tokio::time::sleep(Duration::from_millis(10)).await;
assert_eq!(host.shared.remote_peers.len(), 3);
let data = vec![123, 112, 51, 23];
peer1
.broadcast(data.clone(), Reliability::Reliable)
.unwrap();
thread::sleep(Duration::from_millis(10));
tokio::time::sleep(Duration::from_millis(10)).await;
let host_events: Vec<_> = dbg!(host.recv().collect());
let peer1_events: Vec<_> = dbg!(peer1.recv().collect());
@ -239,7 +239,7 @@ mod test {
});
let addr = "127.0.0.1:56003".parse().unwrap();
let host = Peer::host(addr, settings.clone()).unwrap();
thread::sleep(Duration::from_millis(10));
tokio::time::sleep(Duration::from_millis(10)).await;
assert_eq!(
host.recv().next(),
Some(NetworkEvent::PeerConnected(PeerId(0)))