Implement automatic client reconnection with exponential backoff and heartbeat timeout

- Add heartbeat timeout to client control connection using server heartbeats for dead connection detection
- Introduce exponential backoff with jitter for reconnection delays
- Add CLI flags: --no-reconnect to disable auto-reconnect, --max-reconnect-delay to configure backoff cap
- Classify authentication errors as fatal (never retried), all others retried automatically
- Configure TCP keepalive on control connections for OS-level dead connection detection
- Update documentation (README.md, CLAUDE.md) to describe reconnection behavior and new flags
- Add unit tests for backoff logic and error classification
This commit is contained in:
kfirfer 2026-02-17 14:35:36 +07:00
parent 042fa78742
commit a13e03372e
No known key found for this signature in database
GPG key ID: B2103FE1471D8A5E
9 changed files with 438 additions and 126 deletions

View file

@ -8,7 +8,10 @@ use tracing::{error, info, info_span, warn, Instrument};
use uuid::Uuid;
use crate::auth::Authenticator;
use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT};
use crate::shared::{
set_tcp_keepalive, ClientMessage, Delimited, ServerMessage, CONTROL_PORT, HEARTBEAT_TIMEOUT,
NETWORK_TIMEOUT,
};
/// State structure for the client.
pub struct Client {
@ -40,7 +43,9 @@ impl Client {
port: u16,
secret: Option<&str>,
) -> Result<Self> {
let mut stream = Delimited::new(connect_with_timeout(to, CONTROL_PORT).await?);
let tcp_stream = connect_with_timeout(to, CONTROL_PORT).await?;
set_tcp_keepalive(&tcp_stream)?;
let mut stream = Delimited::new(tcp_stream);
let auth = secret.map(Authenticator::new);
if let Some(auth) = &auth {
auth.client_handshake(&mut stream).await?;
@ -79,25 +84,32 @@ impl Client {
let mut conn = self.conn.take().unwrap();
let this = Arc::new(self);
loop {
match conn.recv().await? {
Some(ServerMessage::Hello(_)) => warn!("unexpected hello"),
Some(ServerMessage::Challenge(_)) => warn!("unexpected challenge"),
Some(ServerMessage::Heartbeat) => (),
Some(ServerMessage::Connection(id)) => {
let this = Arc::clone(&this);
tokio::spawn(
async move {
info!("new connection");
match this.handle_connection(id).await {
Ok(_) => info!("connection exited"),
Err(err) => warn!(%err, "connection exited with error"),
}
}
.instrument(info_span!("proxy", %id)),
);
match timeout(HEARTBEAT_TIMEOUT, conn.recv()).await {
Err(_elapsed) => {
// No message received for HEARTBEAT_TIMEOUT seconds.
// Server sends heartbeats every 500ms, so connection is dead.
bail!("heartbeat timeout, connection to server lost");
}
Some(ServerMessage::Error(err)) => error!(%err, "server error"),
None => return Ok(()),
Ok(msg) => match msg? {
Some(ServerMessage::Hello(_)) => warn!("unexpected hello"),
Some(ServerMessage::Challenge(_)) => warn!("unexpected challenge"),
Some(ServerMessage::Heartbeat) => (),
Some(ServerMessage::Connection(id)) => {
let this = Arc::clone(&this);
tokio::spawn(
async move {
info!("new connection");
match this.handle_connection(id).await {
Ok(_) => info!("connection exited"),
Err(err) => warn!(%err, "connection exited with error"),
}
}
.instrument(info_span!("proxy", %id)),
);
}
Some(ServerMessage::Error(err)) => error!(%err, "server error"),
None => bail!("server closed connection"),
},
}
}
}

View file

@ -1,8 +1,11 @@
use std::net::IpAddr;
use std::time::Duration;
use anyhow::Result;
use bore_cli::shared::ExponentialBackoff;
use bore_cli::{client::Client, server::Server};
use clap::{error::ErrorKind, CommandFactory, Parser, Subcommand};
use tracing::{info, warn};
#[derive(Parser, Debug)]
#[clap(author, version, about)]
@ -34,6 +37,14 @@ enum Command {
/// Optional secret for authentication.
#[clap(short, long, env = "BORE_SECRET", hide_env_values = true)]
secret: Option<String>,
/// Disable automatic reconnection on connection loss.
#[clap(long, default_value_t = false)]
no_reconnect: bool,
/// Maximum delay between reconnection attempts, in seconds.
#[clap(long, default_value_t = 64, value_name = "SECONDS")]
max_reconnect_delay: u64,
},
/// Runs the remote proxy server.
@ -60,6 +71,15 @@ enum Command {
},
}
/// Check if an error is an authentication error that should not be retried.
fn is_auth_error(err: &anyhow::Error) -> bool {
let msg = format!("{err:#}");
msg.contains("server requires authentication")
|| msg.contains("invalid secret")
|| msg.contains("server requires secret")
|| msg.contains("expected authentication challenge")
}
#[tokio::main]
async fn run(command: Command) -> Result<()> {
match command {
@ -69,9 +89,56 @@ async fn run(command: Command) -> Result<()> {
to,
port,
secret,
no_reconnect,
max_reconnect_delay,
} => {
// First attempt — propagate errors directly for immediate feedback
let client = Client::new(&local_host, local_port, &to, port, secret.as_deref()).await?;
client.listen().await?;
if no_reconnect {
// Legacy behavior: exit on any disconnection
client.listen().await?;
} else {
// Reconnection mode: retry on transient failures
let mut backoff = ExponentialBackoff::new(
Duration::from_secs(1),
Duration::from_secs(max_reconnect_delay),
);
// Run the first listen (we already have a connected client)
if let Err(e) = client.listen().await {
warn!("connection lost: {e:#}");
}
// Reconnection loop
loop {
let delay = backoff.next_delay();
info!("reconnecting in {delay:.1?}...");
tokio::time::sleep(delay).await;
match Client::new(&local_host, local_port, &to, port, secret.as_deref()).await {
Ok(client) => {
backoff.reset();
info!("reconnected successfully");
match client.listen().await {
Ok(()) => unreachable!("listen() now always returns Err"),
Err(e) => {
if is_auth_error(&e) {
return Err(e);
}
warn!("connection lost: {e:#}");
}
}
}
Err(e) => {
if is_auth_error(&e) {
return Err(e);
}
warn!("reconnection failed: {e:#}");
}
}
}
}
}
Command::Server {
min_port,
@ -100,3 +167,37 @@ fn main() -> Result<()> {
tracing_subscriber::fmt::init();
run(Args::parse().command)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_auth_error_detection() {
// Fatal auth errors — should NOT be retried
assert!(is_auth_error(&anyhow::anyhow!(
"server requires authentication, but no client secret was provided"
)));
assert!(is_auth_error(&anyhow::anyhow!(
"server error: invalid secret"
)));
assert!(is_auth_error(&anyhow::anyhow!(
"server error: server requires secret, but no secret was provided"
)));
assert!(is_auth_error(&anyhow::anyhow!(
"expected authentication challenge, but no secret was required"
)));
// Retriable errors — should be retried
assert!(!is_auth_error(&anyhow::anyhow!(
"could not connect to server:7835"
)));
assert!(!is_auth_error(&anyhow::anyhow!(
"heartbeat timeout, connection to server lost"
)));
assert!(!is_auth_error(&anyhow::anyhow!(
"server error: port already in use"
)));
assert!(!is_auth_error(&anyhow::anyhow!("server closed connection")));
}
}

View file

@ -12,7 +12,7 @@ use tracing::{info, info_span, warn, Instrument};
use uuid::Uuid;
use crate::auth::Authenticator;
use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
use crate::shared::{set_tcp_keepalive, ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
/// State structure for the server.
pub struct Server {
@ -116,6 +116,7 @@ impl Server {
}
async fn handle_connection(&self, stream: TcpStream) -> Result<()> {
set_tcp_keepalive(&stream)?;
let mut stream = Delimited::new(stream);
if let Some(auth) = &self.auth {
if let Err(err) = auth.server_handshake(&mut stream).await {

View file

@ -5,7 +5,9 @@ use std::time::Duration;
use anyhow::{Context, Result};
use futures_util::{SinkExt, StreamExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use socket2::{SockRef, TcpKeepalive};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::time::timeout;
use tokio_util::codec::{AnyDelimiterCodec, Framed, FramedParts};
use tracing::trace;
@ -20,6 +22,12 @@ pub const MAX_FRAME_LENGTH: usize = 256;
/// Timeout for network connections and initial protocol messages.
pub const NETWORK_TIMEOUT: Duration = Duration::from_secs(3);
/// Timeout for detecting a dead control connection.
///
/// The server sends heartbeats every 500ms. If no message is received within
/// this duration, the connection is considered dead.
pub const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(8);
/// A message from the client on the control connection.
#[derive(Debug, Serialize, Deserialize)]
pub enum ClientMessage {
@ -92,8 +100,95 @@ impl<U: AsyncRead + AsyncWrite + Unpin> Delimited<U> {
Ok(())
}
/// Get a reference to the underlying transport stream.
pub fn get_ref(&self) -> &U {
self.0.get_ref()
}
/// Consume this object, returning current buffers and the inner transport.
pub fn into_parts(self) -> FramedParts<U, AnyDelimiterCodec> {
self.0.into_parts()
}
}
/// Simple exponential backoff with jitter for reconnection delays.
pub struct ExponentialBackoff {
current: Duration,
base: Duration,
max: Duration,
}
impl ExponentialBackoff {
/// Create a new exponential backoff starting at `base` delay, capped at `max`.
pub fn new(base: Duration, max: Duration) -> Self {
Self {
current: base,
base,
max,
}
}
/// Get the next delay and advance the backoff state.
/// Includes random jitter of +/- 25% to prevent thundering herd.
pub fn next_delay(&mut self) -> Duration {
let delay = self.current;
self.current = (self.current * 2).min(self.max);
// Add jitter: multiply by random factor between 0.75 and 1.25
let jitter_factor = 0.75 + fastrand::f64() * 0.5;
delay.mul_f64(jitter_factor)
}
/// Reset backoff to initial delay (call after successful connection).
pub fn reset(&mut self) {
self.current = self.base;
}
}
/// Configure TCP keepalive on a stream for faster dead connection detection.
///
/// This sets the OS to start probing after 30s of idle, probe every 10s,
/// and give up after 3 failed probes (~60s total to detect a dead connection).
pub fn set_tcp_keepalive(stream: &TcpStream) -> Result<()> {
let sock_ref = SockRef::from(stream);
let keepalive = TcpKeepalive::new()
.with_time(Duration::from_secs(30))
.with_interval(Duration::from_secs(10))
.with_retries(3);
sock_ref
.set_tcp_keepalive(&keepalive)
.context("failed to set TCP keepalive")?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_backoff_sequence() {
let mut backoff = ExponentialBackoff::new(Duration::from_secs(1), Duration::from_secs(30));
// Delays should roughly double: 1, 2, 4, 8, 16, 30 (capped), 30, ...
// With jitter, each delay is between 0.75x and 1.25x the base
for expected_base in [1, 2, 4, 8, 16, 30, 30] {
let delay = backoff.next_delay();
let min = Duration::from_secs(expected_base).mul_f64(0.75);
let max = Duration::from_secs(expected_base).mul_f64(1.25);
assert!(
delay >= min && delay <= max,
"delay {delay:?} out of range [{min:?}, {max:?}]"
);
}
}
#[test]
fn test_backoff_reset() {
let mut backoff = ExponentialBackoff::new(Duration::from_secs(1), Duration::from_secs(60));
backoff.next_delay(); // 1s
backoff.next_delay(); // 2s
backoff.next_delay(); // 4s
backoff.reset();
let delay = backoff.next_delay();
// After reset, should be back to ~1s (with jitter)
assert!(delay < Duration::from_secs(2));
}
}