From c8cb60708d51f17d893defc587dc2165c109db2c Mon Sep 17 00:00:00 2001 From: Yujia Qiao Date: Sun, 26 Dec 2021 22:59:12 +0800 Subject: [PATCH] test: refactor and add tests for hot-reload --- src/client.rs | 10 +- src/config.rs | 12 +- src/config_watcher.rs | 332 +++++++++++++++++++++++++++++++----------- src/lib.rs | 16 +- src/server.rs | 28 ++-- 5 files changed, 278 insertions(+), 120 deletions(-) diff --git a/src/client.rs b/src/client.rs index 7591ab1..9f1bdcb 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,5 +1,5 @@ use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType}; -use crate::config_watcher::ServiceChangeEvent; +use crate::config_watcher::ServiceChange; use crate::helper::udp_connect; use crate::protocol::Hello::{self, *}; use crate::protocol::{ @@ -30,7 +30,7 @@ use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT}; pub async fn run_client( config: &Config, shutdown_rx: broadcast::Receiver, - service_rx: mpsc::Receiver, + service_rx: mpsc::Receiver, ) -> Result<()> { let config = match &config.client { Some(v) => v, @@ -93,7 +93,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> { async fn run( &mut self, mut shutdown_rx: broadcast::Receiver, - mut service_rx: mpsc::Receiver, + mut service_rx: mpsc::Receiver, ) -> Result<()> { for (name, config) in &self.config.services { // Create a control channel for each service defined @@ -120,7 +120,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> { e = service_rx.recv() => { if let Some(e) = e { match e { - ServiceChangeEvent::ClientAdd(s)=> { + ServiceChange::ClientAdd(s)=> { let name = s.name.clone(); let handle = ControlChannelHandle::new( s, @@ -129,7 +129,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> { ); let _ = self.service_handles.insert(name, handle); }, - ServiceChangeEvent::ClientDelete(s)=> { + ServiceChange::ClientDelete(s)=> { let _ = self.service_handles.remove(&s); }, _ => () diff --git a/src/config.rs b/src/config.rs index af09176..9fead6c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -38,11 +38,17 @@ pub enum ServiceType { Udp, } -fn default_service_type() -> ServiceType { - ServiceType::Tcp +impl Default for ServiceType { + fn default() -> Self { + ServiceType::Tcp + } } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +fn default_service_type() -> ServiceType { + Default::default() +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] pub struct ServerServiceConfig { #[serde(rename = "type", default = "default_service_type")] pub service_type: ServiceType, diff --git a/src/config_watcher.rs b/src/config_watcher.rs index 79e84f4..c7c2ff5 100644 --- a/src/config_watcher.rs +++ b/src/config_watcher.rs @@ -1,5 +1,5 @@ use crate::{ - config::{ClientServiceConfig, ServerServiceConfig}, + config::{ClientConfig, ClientServiceConfig, ServerConfig, ServerServiceConfig}, Config, }; use anyhow::{Context, Result}; @@ -13,22 +13,87 @@ use tracing::{error, info, instrument}; #[cfg(feature = "notify")] use notify::{event::ModifyKind, EventKind, RecursiveMode, Watcher}; -#[derive(Debug)] -pub enum ConfigChangeEvent { +#[derive(Debug, PartialEq)] +pub enum ConfigChange { General(Box), // Trigger a full restart - ServiceChange(ServiceChangeEvent), + ServiceChange(ServiceChange), } -#[derive(Debug)] -pub enum ServiceChangeEvent { +#[derive(Debug, PartialEq)] +pub enum ServiceChange { ClientAdd(ClientServiceConfig), ClientDelete(String), ServerAdd(ServerServiceConfig), ServerDelete(String), } +impl From for ServiceChange { + fn from(c: ClientServiceConfig) -> Self { + ServiceChange::ClientAdd(c) + } +} + +impl From for ServiceChange { + fn from(c: ServerServiceConfig) -> Self { + ServiceChange::ServerAdd(c) + } +} + +trait InstanceConfig: Clone { + type ServiceConfig: Into + PartialEq + Clone; + fn equal_without_service(&self, rhs: &Self) -> bool; + fn to_service_change_delete(s: String) -> ServiceChange; + fn get_services(&self) -> &HashMap; +} + +impl InstanceConfig for ServerConfig { + type ServiceConfig = ServerServiceConfig; + fn equal_without_service(&self, rhs: &Self) -> bool { + let left = ServerConfig { + services: Default::default(), + ..self.clone() + }; + + let right = ServerConfig { + services: Default::default(), + ..rhs.clone() + }; + + left == right + } + fn to_service_change_delete(s: String) -> ServiceChange { + ServiceChange::ServerDelete(s) + } + fn get_services(&self) -> &HashMap { + &self.services + } +} + +impl InstanceConfig for ClientConfig { + type ServiceConfig = ClientServiceConfig; + fn equal_without_service(&self, rhs: &Self) -> bool { + let left = ClientConfig { + services: Default::default(), + ..self.clone() + }; + + let right = ClientConfig { + services: Default::default(), + ..rhs.clone() + }; + + left == right + } + fn to_service_change_delete(s: String) -> ServiceChange { + ServiceChange::ClientDelete(s) + } + fn get_services(&self) -> &HashMap { + &self.services + } +} + pub struct ConfigWatcherHandle { - pub event_rx: mpsc::Receiver, + pub event_rx: mpsc::Receiver, } impl ConfigWatcherHandle { @@ -39,7 +104,7 @@ impl ConfigWatcherHandle { // Initial start event_tx - .send(ConfigChangeEvent::General(Box::new(origin_cfg.clone()))) + .send(ConfigChange::General(Box::new(origin_cfg.clone()))) .await .unwrap(); @@ -59,30 +124,33 @@ impl ConfigWatcherHandle { async fn config_watcher( _path: PathBuf, mut shutdown_rx: broadcast::Receiver, - _cfg_event_tx: mpsc::Sender, + _event_tx: mpsc::Sender, _old: Config, ) -> Result<()> { - // Do nothing except wating for ctrl-c + // Do nothing except waiting for ctrl-c let _ = shutdown_rx.recv().await; Ok(()) } #[cfg(feature = "notify")] -#[instrument(skip(shutdown_rx, cfg_event_tx, old))] +#[instrument(skip(shutdown_rx, event_tx, old))] async fn config_watcher( path: PathBuf, mut shutdown_rx: broadcast::Receiver, - cfg_event_tx: mpsc::Sender, + event_tx: mpsc::Sender, mut old: Config, ) -> Result<()> { let (fevent_tx, mut fevent_rx) = mpsc::channel(16); - let mut watcher = notify::recommended_watcher(move |res| match res { - Ok(event) => { - let _ = fevent_tx.blocking_send(event); - } - Err(e) => error!("watch error: {:?}", e), - })?; + let mut watcher = + notify::recommended_watcher(move |res: Result| match res { + Ok(e) => { + if let EventKind::Modify(ModifyKind::Data(_)) = e.kind { + let _ = fevent_tx.blocking_send(true); + } + } + Err(e) => error!("watch error: {:?}", e), + })?; watcher.watch(&path, RecursiveMode::NonRecursive)?; info!("Start watching the config"); @@ -91,12 +159,7 @@ async fn config_watcher( tokio::select! { e = fevent_rx.recv() => { match e { - Some(e) => { - if let EventKind::Modify(kind) = e.kind { - match kind { - ModifyKind::Data(_) => (), - _ => continue - } + Some(_) => { info!("Rescan the configuration"); let new = match Config::from_file(&path).await.with_context(|| "The changed configuration is invalid. Ignored") { Ok(v) => v, @@ -107,12 +170,11 @@ async fn config_watcher( } }; - for event in calculate_event(&old, &new) { - cfg_event_tx.send(event).await?; + for event in calculate_events(&old, &new) { + event_tx.send(event).await?; } old = new; - } }, None => break } @@ -126,74 +188,170 @@ async fn config_watcher( Ok(()) } -fn calculate_event(old: &Config, new: &Config) -> Vec { - let mut ret = Vec::new(); - - if old != new { - if old.server.is_some() && new.server.is_some() { - let mut e: Vec = calculate_service_delete_event( - &old.server.as_ref().unwrap().services, - &new.server.as_ref().unwrap().services, - ) - .into_iter() - .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ServerDelete(x))) - .collect(); - ret.append(&mut e); - - let mut e: Vec = calculate_service_add_event( - &old.server.as_ref().unwrap().services, - &new.server.as_ref().unwrap().services, - ) - .into_iter() - .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ServerAdd(x))) - .collect(); - - ret.append(&mut e); - } else if old.client.is_some() && new.client.is_some() { - let mut e: Vec = calculate_service_delete_event( - &old.client.as_ref().unwrap().services, - &new.client.as_ref().unwrap().services, - ) - .into_iter() - .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ClientDelete(x))) - .collect(); - ret.append(&mut e); - - let mut e: Vec = calculate_service_add_event( - &old.client.as_ref().unwrap().services, - &new.client.as_ref().unwrap().services, - ) - .into_iter() - .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ClientAdd(x))) - .collect(); - - ret.append(&mut e); +fn calculate_events(old: &Config, new: &Config) -> Vec { + if old == new { + vec![] + } else if old.server != new.server { + if old.server.is_some() != new.server.is_some() { + vec![ConfigChange::General(Box::new(new.clone()))] } else { - ret.push(ConfigChangeEvent::General(Box::new(new.clone()))); + match calculate_instance_config_events( + old.server.as_ref().unwrap(), + new.server.as_ref().unwrap(), + ) { + Some(v) => v, + None => vec![ConfigChange::General(Box::new(new.clone()))], + } } + } else if old.client != new.client { + if old.client.is_some() != new.client.is_some() { + vec![ConfigChange::General(Box::new(new.clone()))] + } else { + match calculate_instance_config_events( + old.client.as_ref().unwrap(), + new.client.as_ref().unwrap(), + ) { + Some(v) => v, + None => vec![ConfigChange::General(Box::new(new.clone()))], + } + } + } else { + vec![] + } +} + +// None indicates a General change needed +fn calculate_instance_config_events( + old: &T, + new: &T, +) -> Option> { + if !old.equal_without_service(new) { + return None; } - ret + let old = old.get_services(); + let new = new.get_services(); + + let mut v = vec![]; + v.append(&mut calculate_service_delete_events::(old, new)); + v.append(&mut calculate_service_add_events(old, new)); + + Some(v.into_iter().map(ConfigChange::ServiceChange).collect()) } -fn calculate_service_delete_event( - old_services: &HashMap, - new_services: &HashMap, -) -> Vec { - old_services - .keys() - .filter(|&name| old_services.get(name) != new_services.get(name)) - .map(|x| x.to_owned()) +fn calculate_service_delete_events( + old: &HashMap, + new: &HashMap, +) -> Vec { + old.keys() + .filter(|&name| new.get(name).is_none()) + .map(|x| T::to_service_change_delete(x.to_owned())) .collect() } -fn calculate_service_add_event( - old_services: &HashMap, - new_services: &HashMap, -) -> Vec { - new_services - .iter() - .filter(|(name, _)| old_services.get(*name) != new_services.get(*name)) - .map(|(_, c)| c.clone()) +fn calculate_service_add_events>( + old: &HashMap, + new: &HashMap, +) -> Vec { + new.iter() + .filter(|(name, c)| old.get(*name) != Some(*c)) + .map(|(_, c)| c.clone().into()) .collect() } + +#[cfg(test)] +mod test { + use crate::config::ServerConfig; + + use super::*; + + // macro to create map or set literal + macro_rules! collection { + // map-like + ($($k:expr => $v:expr),* $(,)?) => {{ + use std::iter::{Iterator, IntoIterator}; + Iterator::collect(IntoIterator::into_iter([$(($k, $v),)*])) + }}; + } + + #[test] + fn test_calculate_events() { + struct Test { + old: Config, + new: Config, + } + + let tests = [ + Test { + old: Config { + server: Some(Default::default()), + client: None, + }, + new: Config { + server: Some(Default::default()), + client: Some(Default::default()), + }, + }, + Test { + old: Config { + server: Some(ServerConfig { + bind_addr: String::from("127.0.0.1:2334"), + ..Default::default() + }), + client: None, + }, + new: Config { + server: Some(ServerConfig { + bind_addr: String::from("127.0.0.1:2333"), + services: collection!(String::from("foo") => Default::default()), + ..Default::default() + }), + client: None, + }, + }, + Test { + old: Config { + server: Some(Default::default()), + client: None, + }, + new: Config { + server: Some(ServerConfig { + services: collection!(String::from("foo") => Default::default()), + ..Default::default() + }), + client: None, + }, + }, + Test { + old: Config { + server: Some(ServerConfig { + services: collection!(String::from("foo") => Default::default()), + ..Default::default() + }), + client: None, + }, + new: Config { + server: Some(Default::default()), + client: None, + }, + }, + ]; + let expected = [ + vec![ConfigChange::General(Box::new(tests[0].new.clone()))], + vec![ConfigChange::General(Box::new(tests[1].new.clone()))], + vec![ConfigChange::ServiceChange(ServiceChange::ServerAdd( + Default::default(), + ))], + vec![ConfigChange::ServiceChange(ServiceChange::ServerDelete( + String::from("foo"), + ))], + ]; + + assert_eq!(tests.len(), expected.len()); + + for i in 0..tests.len() { + let actual = calculate_events(&tests[i].old, &tests[i].new); + assert_eq!(actual, expected[i]); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 2dacfa9..2097c16 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,7 @@ mod transport; pub use cli::Cli; use cli::KeypairType; pub use config::Config; -use config_watcher::ServiceChangeEvent; +use config_watcher::ServiceChange; pub use constants::UDP_BUFFER_SIZE; use anyhow::Result; @@ -27,7 +27,7 @@ mod server; #[cfg(feature = "server")] use server::run_server; -use crate::config_watcher::{ConfigChangeEvent, ConfigWatcherHandle}; +use crate::config_watcher::{ConfigChange, ConfigWatcherHandle}; const DEFAULT_CURVE: KeypairType = KeypairType::X25519; @@ -76,12 +76,11 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<() let (shutdown_tx, _) = broadcast::channel(1); // (The join handle of the last instance, The service update channel sender) - let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender)> = - None; + let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender)> = None; while let Some(e) = cfg_watcher.event_rx.recv().await { match e { - ConfigChangeEvent::General(config) => { + ConfigChange::General(config) => { if let Some((i, _)) = last_instance { info!("General configuration change detected. Restarting..."); shutdown_tx.send(true)?; @@ -102,7 +101,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<() service_update_tx, )); } - ConfigChangeEvent::ServiceChange(service_event) => { + ConfigChange::ServiceChange(service_event) => { info!("Service change detcted. {:?}", service_event); if let Some((_, service_update_tx)) = &last_instance { let _ = service_update_tx.send(service_event).await; @@ -110,6 +109,9 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<() } } } + + let _ = shutdown_tx.send(true); + Ok(()) } @@ -117,7 +119,7 @@ async fn run_instance( config: Config, args: Cli, shutdown_rx: broadcast::Receiver, - service_update: mpsc::Receiver, + service_update: mpsc::Receiver, ) -> Result<()> { match determine_run_mode(&config, &args) { RunMode::Undetermine => panic!("Cannot determine running as a server or a client"), diff --git a/src/server.rs b/src/server.rs index f8bb2b2..07b3e6a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,5 +1,5 @@ use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType}; -use crate::config_watcher::ServiceChangeEvent; +use crate::config_watcher::ServiceChange; use crate::constants::{listen_backoff, UDP_BUFFER_SIZE}; use crate::multi_map::MultiMap; use crate::protocol::Hello::{ControlChannelHello, DataChannelHello}; @@ -39,7 +39,7 @@ const CHAN_SIZE: usize = 2048; // The capacity of various chans pub async fn run_server( config: &Config, shutdown_rx: broadcast::Receiver, - service_rx: mpsc::Receiver, + service_rx: mpsc::Receiver, ) -> Result<()> { let config = match &config.server { Some(config) => config, @@ -122,7 +122,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> { pub async fn run( &mut self, mut shutdown_rx: broadcast::Receiver, - mut service_rx: mpsc::Receiver, + mut service_rx: mpsc::Receiver, ) -> Result<()> { // Listen at `server.bind_addr` let l = self @@ -193,9 +193,9 @@ impl<'a, T: 'static + Transport> Server<'a, T> { Ok(()) } - async fn handle_hot_reload(&mut self, e: ServiceChangeEvent) { + async fn handle_hot_reload(&mut self, e: ServiceChange) { match e { - ServiceChangeEvent::ServerAdd(s) => { + ServiceChange::ServerAdd(s) => { let hash = protocol::digest(s.name.as_bytes()); let mut wg = self.services.write().await; let _ = wg.insert(hash, s); @@ -203,7 +203,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> { let mut wg = self.control_channels.write().await; let _ = wg.remove1(&hash); } - ServiceChangeEvent::ServerDelete(s) => { + ServiceChange::ServerDelete(s) => { let hash = protocol::digest(s.as_bytes()); let _ = self.services.write().await.remove(&hash); @@ -340,11 +340,8 @@ async fn do_data_channel_handshake( } pub struct ControlChannelHandle { - // Shutdown the control channel. - // Not used for now, but can be used for hot reloading - #[allow(dead_code)] - shutdown_tx: broadcast::Sender, - //data_ch_req_tx: mpsc::Sender, + // Shutdown the control channel by dropping it + _shutdown_tx: broadcast::Sender, data_ch_tx: mpsc::Sender, } @@ -359,7 +356,7 @@ where // Save the name string for logging let name = service.name.clone(); - // Create a shutdown channel. The sender is not used for now, but for future use + // Create a shutdown channel let (shutdown_tx, shutdown_rx) = broadcast::channel::(1); // Store data channels @@ -417,15 +414,10 @@ where }); ControlChannelHandle { - shutdown_tx, + _shutdown_tx: shutdown_tx, data_ch_tx, } } - - #[allow(dead_code)] - fn shutdown(self) { - let _ = self.shutdown_tx.send(true); - } } // Control channel, using T as the transport layer. P is TcpStream or UdpTraffic