mirror of
https://github.com/ekzhang/bore.git
synced 2026-04-19 07:00:18 +00:00
Implement automatic client reconnection with exponential backoff and heartbeat timeout
- Add heartbeat timeout to client control connection, using server heartbeats for dead connection detection
- Introduce exponential backoff with jitter for reconnection delays
- Add CLI flags: --no-reconnect to disable auto-reconnect, --max-reconnect-delay to configure the backoff cap
- Classify authentication errors as fatal (never retried); all other errors are retried automatically
- Configure TCP keepalive on control connections for OS-level dead connection detection
- Update documentation (README.md, CLAUDE.md) to describe reconnection behavior and new flags
- Add unit tests for backoff logic and error classification
This commit is contained in:
parent
042fa78742
commit
a13e03372e
9 changed files with 438 additions and 126 deletions
|
|
@ -8,7 +8,10 @@ use tracing::{error, info, info_span, warn, Instrument};
|
|||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::Authenticator;
|
||||
use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT};
|
||||
use crate::shared::{
|
||||
set_tcp_keepalive, ClientMessage, Delimited, ServerMessage, CONTROL_PORT, HEARTBEAT_TIMEOUT,
|
||||
NETWORK_TIMEOUT,
|
||||
};
|
||||
|
||||
/// State structure for the client.
|
||||
pub struct Client {
|
||||
|
|
@ -40,7 +43,9 @@ impl Client {
|
|||
port: u16,
|
||||
secret: Option<&str>,
|
||||
) -> Result<Self> {
|
||||
let mut stream = Delimited::new(connect_with_timeout(to, CONTROL_PORT).await?);
|
||||
let tcp_stream = connect_with_timeout(to, CONTROL_PORT).await?;
|
||||
set_tcp_keepalive(&tcp_stream)?;
|
||||
let mut stream = Delimited::new(tcp_stream);
|
||||
let auth = secret.map(Authenticator::new);
|
||||
if let Some(auth) = &auth {
|
||||
auth.client_handshake(&mut stream).await?;
|
||||
|
|
@ -79,25 +84,32 @@ impl Client {
|
|||
let mut conn = self.conn.take().unwrap();
|
||||
let this = Arc::new(self);
|
||||
loop {
|
||||
match conn.recv().await? {
|
||||
Some(ServerMessage::Hello(_)) => warn!("unexpected hello"),
|
||||
Some(ServerMessage::Challenge(_)) => warn!("unexpected challenge"),
|
||||
Some(ServerMessage::Heartbeat) => (),
|
||||
Some(ServerMessage::Connection(id)) => {
|
||||
let this = Arc::clone(&this);
|
||||
tokio::spawn(
|
||||
async move {
|
||||
info!("new connection");
|
||||
match this.handle_connection(id).await {
|
||||
Ok(_) => info!("connection exited"),
|
||||
Err(err) => warn!(%err, "connection exited with error"),
|
||||
}
|
||||
}
|
||||
.instrument(info_span!("proxy", %id)),
|
||||
);
|
||||
match timeout(HEARTBEAT_TIMEOUT, conn.recv()).await {
|
||||
Err(_elapsed) => {
|
||||
// No message received for HEARTBEAT_TIMEOUT seconds.
|
||||
// Server sends heartbeats every 500ms, so connection is dead.
|
||||
bail!("heartbeat timeout, connection to server lost");
|
||||
}
|
||||
Some(ServerMessage::Error(err)) => error!(%err, "server error"),
|
||||
None => return Ok(()),
|
||||
Ok(msg) => match msg? {
|
||||
Some(ServerMessage::Hello(_)) => warn!("unexpected hello"),
|
||||
Some(ServerMessage::Challenge(_)) => warn!("unexpected challenge"),
|
||||
Some(ServerMessage::Heartbeat) => (),
|
||||
Some(ServerMessage::Connection(id)) => {
|
||||
let this = Arc::clone(&this);
|
||||
tokio::spawn(
|
||||
async move {
|
||||
info!("new connection");
|
||||
match this.handle_connection(id).await {
|
||||
Ok(_) => info!("connection exited"),
|
||||
Err(err) => warn!(%err, "connection exited with error"),
|
||||
}
|
||||
}
|
||||
.instrument(info_span!("proxy", %id)),
|
||||
);
|
||||
}
|
||||
Some(ServerMessage::Error(err)) => error!(%err, "server error"),
|
||||
None => bail!("server closed connection"),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
103
src/main.rs
103
src/main.rs
|
|
@ -1,8 +1,11 @@
|
|||
use std::net::IpAddr;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use bore_cli::shared::ExponentialBackoff;
|
||||
use bore_cli::{client::Client, server::Server};
|
||||
use clap::{error::ErrorKind, CommandFactory, Parser, Subcommand};
|
||||
use tracing::{info, warn};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about)]
|
||||
|
|
@ -34,6 +37,14 @@ enum Command {
|
|||
/// Optional secret for authentication.
|
||||
#[clap(short, long, env = "BORE_SECRET", hide_env_values = true)]
|
||||
secret: Option<String>,
|
||||
|
||||
/// Disable automatic reconnection on connection loss.
|
||||
#[clap(long, default_value_t = false)]
|
||||
no_reconnect: bool,
|
||||
|
||||
/// Maximum delay between reconnection attempts, in seconds.
|
||||
#[clap(long, default_value_t = 64, value_name = "SECONDS")]
|
||||
max_reconnect_delay: u64,
|
||||
},
|
||||
|
||||
/// Runs the remote proxy server.
|
||||
|
|
@ -60,6 +71,15 @@ enum Command {
|
|||
},
|
||||
}
|
||||
|
||||
/// Check if an error is an authentication error that should not be retried.
|
||||
fn is_auth_error(err: &anyhow::Error) -> bool {
|
||||
let msg = format!("{err:#}");
|
||||
msg.contains("server requires authentication")
|
||||
|| msg.contains("invalid secret")
|
||||
|| msg.contains("server requires secret")
|
||||
|| msg.contains("expected authentication challenge")
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn run(command: Command) -> Result<()> {
|
||||
match command {
|
||||
|
|
@ -69,9 +89,56 @@ async fn run(command: Command) -> Result<()> {
|
|||
to,
|
||||
port,
|
||||
secret,
|
||||
no_reconnect,
|
||||
max_reconnect_delay,
|
||||
} => {
|
||||
// First attempt — propagate errors directly for immediate feedback
|
||||
let client = Client::new(&local_host, local_port, &to, port, secret.as_deref()).await?;
|
||||
client.listen().await?;
|
||||
|
||||
if no_reconnect {
|
||||
// Legacy behavior: exit on any disconnection
|
||||
client.listen().await?;
|
||||
} else {
|
||||
// Reconnection mode: retry on transient failures
|
||||
let mut backoff = ExponentialBackoff::new(
|
||||
Duration::from_secs(1),
|
||||
Duration::from_secs(max_reconnect_delay),
|
||||
);
|
||||
|
||||
// Run the first listen (we already have a connected client)
|
||||
if let Err(e) = client.listen().await {
|
||||
warn!("connection lost: {e:#}");
|
||||
}
|
||||
|
||||
// Reconnection loop
|
||||
loop {
|
||||
let delay = backoff.next_delay();
|
||||
info!("reconnecting in {delay:.1?}...");
|
||||
tokio::time::sleep(delay).await;
|
||||
|
||||
match Client::new(&local_host, local_port, &to, port, secret.as_deref()).await {
|
||||
Ok(client) => {
|
||||
backoff.reset();
|
||||
info!("reconnected successfully");
|
||||
match client.listen().await {
|
||||
Ok(()) => unreachable!("listen() now always returns Err"),
|
||||
Err(e) => {
|
||||
if is_auth_error(&e) {
|
||||
return Err(e);
|
||||
}
|
||||
warn!("connection lost: {e:#}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if is_auth_error(&e) {
|
||||
return Err(e);
|
||||
}
|
||||
warn!("reconnection failed: {e:#}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Command::Server {
|
||||
min_port,
|
||||
|
|
@ -100,3 +167,37 @@ fn main() -> Result<()> {
|
|||
tracing_subscriber::fmt::init();
|
||||
run(Args::parse().command)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap `message` in an ad-hoc error and run it through the classifier.
    fn classify(message: &'static str) -> bool {
        is_auth_error(&anyhow::anyhow!(message))
    }

    #[test]
    fn test_auth_error_detection() {
        // Fatal auth errors — should NOT be retried
        let fatal = [
            "server requires authentication, but no client secret was provided",
            "server error: invalid secret",
            "server error: server requires secret, but no secret was provided",
            "expected authentication challenge, but no secret was required",
        ];
        for message in fatal {
            assert!(classify(message));
        }

        // Retriable errors — should be retried
        let retriable = [
            "could not connect to server:7835",
            "heartbeat timeout, connection to server lost",
            "server error: port already in use",
            "server closed connection",
        ];
        for message in retriable {
            assert!(!classify(message));
        }
    }
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ use tracing::{info, info_span, warn, Instrument};
|
|||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::Authenticator;
|
||||
use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
|
||||
use crate::shared::{set_tcp_keepalive, ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
|
||||
|
||||
/// State structure for the server.
|
||||
pub struct Server {
|
||||
|
|
@ -116,6 +116,7 @@ impl Server {
|
|||
}
|
||||
|
||||
async fn handle_connection(&self, stream: TcpStream) -> Result<()> {
|
||||
set_tcp_keepalive(&stream)?;
|
||||
let mut stream = Delimited::new(stream);
|
||||
if let Some(auth) = &self.auth {
|
||||
if let Err(err) = auth.server_handshake(&mut stream).await {
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@ use std::time::Duration;
|
|||
use anyhow::{Context, Result};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use socket2::{SockRef, TcpKeepalive};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::timeout;
|
||||
use tokio_util::codec::{AnyDelimiterCodec, Framed, FramedParts};
|
||||
use tracing::trace;
|
||||
|
|
@ -20,6 +22,12 @@ pub const MAX_FRAME_LENGTH: usize = 256;
|
|||
/// Timeout for network connections and initial protocol messages.
|
||||
pub const NETWORK_TIMEOUT: Duration = Duration::from_secs(3);
|
||||
|
||||
/// Timeout for detecting a dead control connection.
|
||||
///
|
||||
/// The server sends heartbeats every 500ms. If no message is received within
|
||||
/// this duration, the connection is considered dead.
|
||||
pub const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(8);
|
||||
|
||||
/// A message from the client on the control connection.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub enum ClientMessage {
|
||||
|
|
@ -92,8 +100,95 @@ impl<U: AsyncRead + AsyncWrite + Unpin> Delimited<U> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a reference to the underlying transport stream.
|
||||
pub fn get_ref(&self) -> &U {
|
||||
self.0.get_ref()
|
||||
}
|
||||
|
||||
/// Consume this object, returning current buffers and the inner transport.
|
||||
pub fn into_parts(self) -> FramedParts<U, AnyDelimiterCodec> {
|
||||
self.0.into_parts()
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple exponential backoff with jitter for reconnection delays.
///
/// The state is three durations: the delay to hand out next, the delay to
/// return to after a reset, and the cap applied when the delay is doubled.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
    /// Un-jittered delay that the next call to `next_delay` is based on.
    current: Duration,
    /// Delay restored by `reset`.
    base: Duration,
    /// Upper bound applied to `current` each time it is doubled.
    max: Duration,
}
|
||||
|
||||
impl ExponentialBackoff {
|
||||
/// Create a new exponential backoff starting at `base` delay, capped at `max`.
|
||||
pub fn new(base: Duration, max: Duration) -> Self {
|
||||
Self {
|
||||
current: base,
|
||||
base,
|
||||
max,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the next delay and advance the backoff state.
|
||||
/// Includes random jitter of +/- 25% to prevent thundering herd.
|
||||
pub fn next_delay(&mut self) -> Duration {
|
||||
let delay = self.current;
|
||||
self.current = (self.current * 2).min(self.max);
|
||||
// Add jitter: multiply by random factor between 0.75 and 1.25
|
||||
let jitter_factor = 0.75 + fastrand::f64() * 0.5;
|
||||
delay.mul_f64(jitter_factor)
|
||||
}
|
||||
|
||||
/// Reset backoff to initial delay (call after successful connection).
|
||||
pub fn reset(&mut self) {
|
||||
self.current = self.base;
|
||||
}
|
||||
}
|
||||
|
||||
/// Configure TCP keepalive on a stream for faster dead connection detection.
|
||||
///
|
||||
/// This sets the OS to start probing after 30s of idle, probe every 10s,
|
||||
/// and give up after 3 failed probes (~60s total to detect a dead connection).
|
||||
pub fn set_tcp_keepalive(stream: &TcpStream) -> Result<()> {
|
||||
let sock_ref = SockRef::from(stream);
|
||||
let keepalive = TcpKeepalive::new()
|
||||
.with_time(Duration::from_secs(30))
|
||||
.with_interval(Duration::from_secs(10))
|
||||
.with_retries(3);
|
||||
sock_ref
|
||||
.set_tcp_keepalive(&keepalive)
|
||||
.context("failed to set TCP keepalive")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_backoff_sequence() {
        let mut backoff = ExponentialBackoff::new(Duration::from_secs(1), Duration::from_secs(30));
        // Nominal delays roughly double until they hit the 30s cap:
        // 1, 2, 4, 8, 16, 30 (capped), 30, ...
        let nominal_secs = [1u64, 2, 4, 8, 16, 30, 30];
        for secs in nominal_secs {
            let delay = backoff.next_delay();
            // Jitter keeps each delay between 0.75x and 1.25x the nominal value.
            let nominal = Duration::from_secs(secs);
            let (min, max) = (nominal.mul_f64(0.75), nominal.mul_f64(1.25));
            assert!(
                (min..=max).contains(&delay),
                "delay {delay:?} out of range [{min:?}, {max:?}]"
            );
        }
    }

    #[test]
    fn test_backoff_reset() {
        let mut backoff = ExponentialBackoff::new(Duration::from_secs(1), Duration::from_secs(60));
        // Advance through the ~1s, ~2s and ~4s delays.
        for _ in 0..3 {
            backoff.next_delay();
        }
        backoff.reset();
        // After reset, should be back to ~1s (with jitter)
        let after_reset = backoff.next_delay();
        assert!(after_reset < Duration::from_secs(2));
    }
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue