diff --git a/src/client.rs b/src/client.rs index dc9e75b..7591ab1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,5 @@ use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType}; +use crate::config_watcher::ServiceChangeEvent; use crate::helper::udp_connect; use crate::protocol::Hello::{self, *}; use crate::protocol::{ @@ -16,7 +17,7 @@ use tokio::io::{self, copy_bidirectional, AsyncWriteExt}; use tokio::net::{TcpStream, UdpSocket}; use tokio::sync::{broadcast, mpsc, oneshot, RwLock}; use tokio::time::{self, Duration}; -use tracing::{debug, error, info, instrument, Instrument, Span}; +use tracing::{debug, error, info, instrument, warn, Instrument, Span}; #[cfg(feature = "noise")] use crate::transport::NoiseTransport; @@ -26,7 +27,11 @@ use crate::transport::TlsTransport; use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT}; // The entrypoint of running a client -pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver) -> Result<()> { +pub async fn run_client( + config: &Config, + shutdown_rx: broadcast::Receiver, + service_rx: mpsc::Receiver, +) -> Result<()> { let config = match &config.client { Some(v) => v, None => { @@ -37,13 +42,13 @@ pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver) match config.transport.transport_type { TransportType::Tcp => { let mut client = Client::::from(config).await?; - client.run(shutdown_rx).await + client.run(shutdown_rx, service_rx).await } TransportType::Tls => { #[cfg(feature = "tls")] { let mut client = Client::::from(config).await?; - client.run(shutdown_rx).await + client.run(shutdown_rx, service_rx).await } #[cfg(not(feature = "tls"))] crate::helper::feature_not_compile("tls") @@ -52,7 +57,7 @@ pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver) #[cfg(feature = "noise")] { let mut client = Client::::from(config).await?; - client.run(shutdown_rx).await + client.run(shutdown_rx, service_rx).await } #[cfg(not(feature = "noise"))] crate::helper::feature_not_compile("noise") @@ -85,7 +90,11 @@ impl<'a, T: 'static + Transport> Client<'a, T> { } // The entrypoint of Client - async fn run(&mut self, mut shutdown_rx: broadcast::Receiver) -> Result<()> { + async fn run( + &mut self, + mut shutdown_rx: broadcast::Receiver, + mut service_rx: mpsc::Receiver, + ) -> Result<()> { for (name, config) in &self.config.services { // Create a control channel for each service defined let handle = ControlChannelHandle::new( @@ -96,7 +105,6 @@ impl<'a, T: 'static + Transport> Client<'a, T> { self.service_handles.insert(name.clone(), handle); } - // TODO: Maybe wait for a config change signal for hot reloading // Wait for the shutdown signal loop { tokio::select! { @@ -109,6 +117,25 @@ impl<'a, T: 'static + Transport> Client<'a, T> { } break; }, + e = service_rx.recv() => { + if let Some(e) = e { + match e { + ServiceChangeEvent::ClientAdd(s)=> { + let name = s.name.clone(); + let handle = ControlChannelHandle::new( + s, + self.config.remote_addr.clone(), + self.transport.clone(), + ); + let _ = self.service_handles.insert(name, handle); + }, + ServiceChangeEvent::ClientDelete(s)=> { + let _ = self.service_handles.remove(&s); + }, + _ => () + } + } + } } } @@ -399,7 +426,7 @@ impl ControlChannel { } }, _ = &mut self.shutdown_rx => { - info!( "Shutting down gracefully..."); + info!( "Control channel shutting down..."); break; } } @@ -433,6 +460,10 @@ impl ControlChannelHandle { .await .with_context(|| "Failed to run the control channel") { + if s.shutdown_rx.try_recv() != Err(oneshot::error::TryRecvError::Empty) { + break; + } + let duration = Duration::from_secs(1); error!("{:?}\n\nRetry in {:?}...", err, duration); time::sleep(duration).await; diff --git a/src/config_watcher.rs b/src/config_watcher.rs index 2e262be..bca14c4 100644 --- a/src/config_watcher.rs +++ b/src/config_watcher.rs @@ -3,23 +3,26 @@ use crate::{ Config, }; use anyhow::{Context, Result}; -use notify::{EventKind, RecursiveMode, Watcher}; -use std::path::PathBuf; +use notify::{event::ModifyKind, EventKind, RecursiveMode, Watcher}; +use std::{ + collections::HashMap, + path::{Path, PathBuf}, +}; use tokio::sync::{broadcast, mpsc}; use tracing::{error, info, instrument}; #[derive(Debug)] pub enum ConfigChangeEvent { - General(Config), // Trigger a full restart + General(Box), // Trigger a full restart ServiceChange(ServiceChangeEvent), } #[derive(Debug)] pub enum ServiceChangeEvent { - AddClientService(ClientServiceConfig), - DeleteClientService(ClientServiceConfig), - AddServerService(ServerServiceConfig), - DeleteServerService(ServerServiceConfig), + ClientAdd(ClientServiceConfig), + ClientDelete(String), + ServerAdd(ServerServiceConfig), + ServerDelete(String), } pub struct ConfigWatcherHandle { @@ -27,7 +30,7 @@ pub struct ConfigWatcherHandle { } impl ConfigWatcherHandle { - pub async fn new(path: &PathBuf, shutdown_rx: broadcast::Receiver) -> Result { + pub async fn new(path: &Path, shutdown_rx: broadcast::Receiver) -> Result { let (event_tx, event_rx) = mpsc::channel(16); let origin_cfg = Config::from_file(path).await?; @@ -43,7 +46,7 @@ impl ConfigWatcherHandle { } } -#[instrument(skip(shutdown_rx, cfg_event_tx))] +#[instrument(skip(shutdown_rx, cfg_event_tx, old))] async fn config_watcher( path: PathBuf, mut shutdown_rx: broadcast::Receiver, @@ -61,7 +64,7 @@ async fn config_watcher( // Initial start cfg_event_tx - .send(ConfigChangeEvent::General(old.clone())) + .send(ConfigChangeEvent::General(Box::new(old.clone()))) .await .unwrap(); @@ -73,9 +76,12 @@ async fn config_watcher( e = fevent_rx.recv() => { match e { Some(e) => { - match e.kind { - EventKind::Modify(_) => { - info!("Configuration modify event is detected"); + if let EventKind::Modify(kind) = e.kind { + match kind { + ModifyKind::Data(_) => (), + _ => continue + } + info!("Rescan the configuration"); let new = match Config::from_file(&path).await.with_context(|| "The changed configuration is invalid. Ignored") { Ok(v) => v, Err(e) => { @@ -90,9 +96,7 @@ async fn config_watcher( } old = new; - }, - _ => (), // Just ignore other events - } + } }, None => break } @@ -109,11 +113,71 @@ async fn config_watcher( fn calculate_event(old: &Config, new: &Config) -> Vec { let mut ret = Vec::new(); - if old == new { - return ret; - } + if old != new { + if old.server.is_some() && new.server.is_some() { + let mut e: Vec = calculate_service_delete_event( + &old.server.as_ref().unwrap().services, + &new.server.as_ref().unwrap().services, + ) + .into_iter() + .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ServerDelete(x))) + .collect(); + ret.append(&mut e); - ret.push(ConfigChangeEvent::General(new.to_owned())); + let mut e: Vec = calculate_service_add_event( + &old.server.as_ref().unwrap().services, + &new.server.as_ref().unwrap().services, + ) + .into_iter() + .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ServerAdd(x))) + .collect(); + + ret.append(&mut e); + } else if old.client.is_some() && new.client.is_some() { + let mut e: Vec = calculate_service_delete_event( + &old.client.as_ref().unwrap().services, + &new.client.as_ref().unwrap().services, + ) + .into_iter() + .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ClientDelete(x))) + .collect(); + ret.append(&mut e); + + let mut e: Vec = calculate_service_add_event( + &old.client.as_ref().unwrap().services, + &new.client.as_ref().unwrap().services, + ) + .into_iter() + .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ClientAdd(x))) + .collect(); + + ret.append(&mut e); + } else { + ret.push(ConfigChangeEvent::General(Box::new(new.clone()))); + } + } ret } + +fn calculate_service_delete_event( + old_services: &HashMap, + new_services: &HashMap, +) -> Vec { + old_services + .keys() + .filter(|&name| old_services.get(name) != new_services.get(name)) + .map(|x| x.to_owned()) + .collect() +} + +fn calculate_service_add_event( + old_services: &HashMap, + new_services: &HashMap, +) -> Vec { + new_services + .iter() + .filter(|(name, _)| old_services.get(*name) != new_services.get(*name)) + .map(|(_, c)| c.clone()) + .collect() +} diff --git a/src/lib.rs b/src/lib.rs index 3379d53..2dacfa9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,7 +15,7 @@ pub use constants::UDP_BUFFER_SIZE; use anyhow::Result; use tokio::sync::{broadcast, mpsc}; -use tracing::debug; +use tracing::{debug, info}; #[cfg(feature = "client")] mod client; @@ -82,12 +82,10 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<() while let Some(e) = cfg_watcher.event_rx.recv().await { match e { ConfigChangeEvent::General(config) => { - match last_instance { - Some((i, _)) => { - shutdown_tx.send(true)?; - i.await??; - } - None => (), + if let Some((i, _)) = last_instance { + info!("General configuration change detected. Restarting..."); + shutdown_tx.send(true)?; + i.await??; } debug!("{:?}", config); @@ -96,7 +94,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<() last_instance = Some(( tokio::spawn(run_instance( - config.clone(), + *(config.clone()), args.clone(), shutdown_tx.subscribe(), service_update_rx, @@ -105,6 +103,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<() )); } ConfigChangeEvent::ServiceChange(service_event) => { + info!("Service change detcted. {:?}", service_event); if let Some((_, service_update_tx)) = &last_instance { let _ = service_update_tx.send(service_event).await; } @@ -118,7 +117,7 @@ async fn run_instance( config: Config, args: Cli, shutdown_rx: broadcast::Receiver, - _service_update: mpsc::Receiver, + service_update: mpsc::Receiver, ) -> Result<()> { match determine_run_mode(&config, &args) { RunMode::Undetermine => panic!("Cannot determine running as a server or a client"), @@ -126,13 +125,13 @@ async fn run_instance( #[cfg(not(feature = "client"))] crate::helper::feature_not_compile("client"); #[cfg(feature = "client")] - run_client(&config, shutdown_rx).await + run_client(&config, shutdown_rx, service_update).await } RunMode::Server => { #[cfg(not(feature = "server"))] crate::helper::feature_not_compile("server"); #[cfg(feature = "server")] - run_server(&config, shutdown_rx).await + run_server(&config, shutdown_rx, service_update).await } } } diff --git a/src/server.rs b/src/server.rs index 7da2f98..f8bb2b2 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,4 +1,5 @@ use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType}; +use crate::config_watcher::ServiceChangeEvent; use crate::constants::{listen_backoff, UDP_BUFFER_SIZE}; use crate::multi_map::MultiMap; use crate::protocol::Hello::{ControlChannelHello, DataChannelHello}; @@ -35,7 +36,11 @@ const UDP_POOL_SIZE: usize = 2; // The number of cached connections for UDP serv const CHAN_SIZE: usize = 2048; // The capacity of various chans // The entrypoint of running a server -pub async fn run_server(config: &Config, shutdown_rx: broadcast::Receiver) -> Result<()> { +pub async fn run_server( + config: &Config, + shutdown_rx: broadcast::Receiver, + service_rx: mpsc::Receiver, +) -> Result<()> { let config = match &config.server { Some(config) => config, None => { @@ -47,13 +52,13 @@ pub async fn run_server(config: &Config, shutdown_rx: broadcast::Receiver) match config.transport.transport_type { TransportType::Tcp => { let mut server = Server::::from(config).await?; - server.run(shutdown_rx).await?; + server.run(shutdown_rx, service_rx).await?; } TransportType::Tls => { #[cfg(feature = "tls")] { let mut server = Server::::from(config).await?; - server.run(shutdown_rx).await?; + server.run(shutdown_rx, service_rx).await?; } #[cfg(not(feature = "tls"))] crate::helper::feature_not_compile("tls") @@ -62,7 +67,7 @@ pub async fn run_server(config: &Config, shutdown_rx: broadcast::Receiver) #[cfg(feature = "noise")] { let mut server = Server::::from(config).await?; - server.run(shutdown_rx).await?; + server.run(shutdown_rx, service_rx).await?; } #[cfg(not(feature = "noise"))] crate::helper::feature_not_compile("noise") @@ -114,7 +119,11 @@ impl<'a, T: 'static + Transport> Server<'a, T> { } // The entry point of Server - pub async fn run(&mut self, mut shutdown_rx: broadcast::Receiver) -> Result<()> { + pub async fn run( + &mut self, + mut shutdown_rx: broadcast::Receiver, + mut service_rx: mpsc::Receiver, + ) -> Result<()> { // Listen at `server.bind_addr` let l = self .transport @@ -172,12 +181,38 @@ impl<'a, T: 'static + Transport> Server<'a, T> { _ = shutdown_rx.recv() => { info!("Shuting down gracefully..."); break; + }, + e = service_rx.recv() => { + if let Some(e) = e { + self.handle_hot_reload(e).await; + } } } } Ok(()) } + + async fn handle_hot_reload(&mut self, e: ServiceChangeEvent) { + match e { + ServiceChangeEvent::ServerAdd(s) => { + let hash = protocol::digest(s.name.as_bytes()); + let mut wg = self.services.write().await; + let _ = wg.insert(hash, s); + + let mut wg = self.control_channels.write().await; + let _ = wg.remove1(&hash); + } + ServiceChangeEvent::ServerDelete(s) => { + let hash = protocol::digest(s.as_bytes()); + let _ = self.services.write().await.remove(&hash); + + let mut wg = self.control_channels.write().await; + let _ = wg.remove1(&hash); + } + _ => (), + } + } } // Handle connections to `server.bind_addr`