diff --git a/Cargo.lock b/Cargo.lock index 6fa53d1..82dd375 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -118,13 +118,16 @@ checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" [[package]] name = "backoff" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fe17f59a06fe8b87a6fc8bf53bb70b3aba76d7685f432487a68cd5552853625" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" dependencies = [ + "futures-core", "getrandom 0.2.4", "instant", + "pin-project-lite", "rand", + "tokio", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index d7f1e85..baff0ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,7 +57,7 @@ bincode = "1" lazy_static = "1.4" hex = "0.4" rand = "0.8" -backoff = "0.3" +backoff = { version = "0.4", features = ["tokio"] } tracing = "0.1" tracing-subscriber = "0.2" socket2 = { version = "0.4", features = ["all"] } diff --git a/src/client.rs b/src/client.rs index 60480c1..e806e48 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,6 +1,6 @@ use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType}; use crate::config_watcher::ServiceChange; -use crate::helper::{retry_notify, udp_connect}; +use crate::helper::udp_connect; use crate::protocol::Hello::{self, *}; use crate::protocol::{ self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd, @@ -8,8 +8,8 @@ use crate::protocol::{ }; use crate::transport::{SocketOpts, TcpTransport, Transport}; use anyhow::{anyhow, bail, Context, Result}; -use backoff::backoff::Backoff; use backoff::ExponentialBackoff; +use backoff::{backoff::Backoff, future::retry_notify}; use bytes::{Bytes, BytesMut}; use std::collections::HashMap; use std::net::SocketAddr; @@ -159,21 +159,22 @@ async fn do_data_channel_handshake( args: Arc>, ) -> Result { // Retry at least every 100ms, at most for 10 seconds - let mut backoff = ExponentialBackoff { + let backoff = ExponentialBackoff { max_interval: Duration::from_millis(100), max_elapsed_time: Some(Duration::from_secs(10)), ..Default::default() }; // Connect to remote_addr - let mut conn: T::Stream = retry_notify!( + let mut conn: T::Stream = retry_notify( backoff, - { + || async { match args .connector .connect(&args.remote_addr) .await .with_context(|| format!("Failed to connect to {}", &args.remote_addr)) + .map_err(backoff::Error::transient) { Ok(conn) => { T::hint(&conn, args.socket_opts); @@ -184,8 +185,9 @@ async fn do_data_channel_handshake( }, |e, duration| { warn!("{:#}. Retry in {:?}", e, duration); - } - )?; + }, + ) + .await?; // Send nonce let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap(); diff --git a/src/helper.rs b/src/helper.rs index 70ceaeb..e52de66 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -1,8 +1,11 @@ -use std::{net::SocketAddr, time::Duration}; - use anyhow::{anyhow, Result}; +use backoff::{backoff::Backoff, Notify}; use socket2::{SockRef, TcpKeepalive}; -use tokio::net::{lookup_host, TcpStream, ToSocketAddrs, UdpSocket}; +use std::{future::Future, net::SocketAddr, time::Duration}; +use tokio::{ + net::{lookup_host, TcpStream, ToSocketAddrs, UdpSocket}, + sync::broadcast, +}; use tracing::trace; // Tokio hesitates to expose this option...So we have to do it on our own :( @@ -52,62 +55,26 @@ pub async fn udp_connect(addr: A) -> Result { Ok(s) } -/// Almost same as backoff::future::retry_notify -/// But directly expands to a loop -macro_rules! retry_notify { - ($b: expr, $func: expr, $notify: expr) => { - loop { - match $func { - Ok(v) => break Ok(v), - Err(e) => match $b.next_backoff() { - Some(duration) => { - $notify(e, duration); - tokio::time::sleep(duration).await; - } - None => break Err(e), - }, - } +// Wrapper of retry_notify +pub async fn retry_notify_with_deadline( + backoff: B, + operation: Fn, + notify: N, + deadline: &mut broadcast::Receiver, +) -> Result +where + E: std::error::Error + Send + Sync + 'static, + B: Backoff, + Fn: FnMut() -> Fut, + Fut: Future>>, + N: Notify, +{ + tokio::select! { + v = backoff::future::retry_notify(backoff, operation, notify) => { + v.map_err(anyhow::Error::new) } - }; -} - -pub(crate) use retry_notify; - -#[cfg(test)] -mod test { - use super::*; - use backoff::{backoff::Backoff, ExponentialBackoff}; - #[tokio::test] - async fn test_retry_notify() { - let tests = [(3, Ok(())), (5, Err("try again"))]; - for (try_succ, expected) in tests { - let mut b = ExponentialBackoff { - current_interval: Duration::from_millis(100), - initial_interval: Duration::from_millis(100), - max_elapsed_time: Some(Duration::from_millis(210)), - randomization_factor: 0.0, - multiplier: 1.0, - ..Default::default() - }; - - let mut notify_count = 0; - let mut try_count = 0; - let ret: Result<(), &str> = retry_notify!( - b, - { - try_count += 1; - if try_count == try_succ { - Ok(()) - } else { - Err("try again") - } - }, - |e, duration| { - notify_count += 1; - println!("{}: {}, {:?}", notify_count, e, duration); - } - ); - assert_eq!(ret, expected); + _ = deadline.recv() => { + Err(anyhow!("shutdown")) } } } diff --git a/src/server.rs b/src/server.rs index 02a3e37..abd181b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,7 +1,7 @@ use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType}; use crate::config_watcher::ServiceChange; use crate::constants::{listen_backoff, UDP_BUFFER_SIZE}; -use crate::helper::retry_notify; +use crate::helper::retry_notify_with_deadline; use crate::multi_map::MultiMap; use crate::protocol::Hello::{ControlChannelHello, DataChannelHello}; use crate::protocol::{ @@ -509,21 +509,15 @@ fn tcp_listen_and_send( let (tx, rx) = mpsc::channel(CHAN_SIZE); tokio::spawn(async move { - let l = retry_notify!(listen_backoff(), { - match shutdown_rx.try_recv() { - Err(broadcast::error::TryRecvError::Closed) => Ok(None), - _ => TcpListener::bind(&addr).await.map(Some) - } + let l = retry_notify_with_deadline(listen_backoff(), || async { + Ok(TcpListener::bind(&addr).await?) }, |e, duration| { error!("{:#}. Retry in {:?}", e, duration); - }) + }, &mut shutdown_rx).await .with_context(|| "Failed to listen for the service"); let l: TcpListener = match l { - Ok(v) => match v { - Some(v) => v, - None => return - }, + Ok(v) => v, Err(e) => { error!("{:#}", e); return; @@ -628,27 +622,16 @@ async fn run_udp_connection_pool( ) -> Result<()> { // TODO: Load balance - let l = retry_notify!( + let l = retry_notify_with_deadline( listen_backoff(), - { - match shutdown_rx.try_recv() { - Err(broadcast::error::TryRecvError::Closed) => Ok(None), - _ => UdpSocket::bind(&bind_addr).await.map(Some), - } - }, + || async { Ok(UdpSocket::bind(&bind_addr).await?) }, |e, duration| { warn!("{:#}. Retry in {:?}", e, duration); - } - ) - .with_context(|| "Failed to listen for the service"); - - let l = match l { - Ok(v) => match v { - Some(l) => l, - None => return Ok(()), }, - Err(e) => return Err(e), - }; + &mut shutdown_rx, + ) + .await + .with_context(|| "Failed to listen for the service")?; info!("Listening at {}", &bind_addr);