test: refactor and add tests for hot-reload

This commit is contained in:
Yujia Qiao 2021-12-26 22:59:12 +08:00 committed by Yujia Qiao
parent c8e679fa65
commit c8cb60708d
5 changed files with 278 additions and 120 deletions

View File

@ -1,5 +1,5 @@
use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType}; use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
use crate::config_watcher::ServiceChangeEvent; use crate::config_watcher::ServiceChange;
use crate::helper::udp_connect; use crate::helper::udp_connect;
use crate::protocol::Hello::{self, *}; use crate::protocol::Hello::{self, *};
use crate::protocol::{ use crate::protocol::{
@ -30,7 +30,7 @@ use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
pub async fn run_client( pub async fn run_client(
config: &Config, config: &Config,
shutdown_rx: broadcast::Receiver<bool>, shutdown_rx: broadcast::Receiver<bool>,
service_rx: mpsc::Receiver<ServiceChangeEvent>, service_rx: mpsc::Receiver<ServiceChange>,
) -> Result<()> { ) -> Result<()> {
let config = match &config.client { let config = match &config.client {
Some(v) => v, Some(v) => v,
@ -93,7 +93,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
async fn run( async fn run(
&mut self, &mut self,
mut shutdown_rx: broadcast::Receiver<bool>, mut shutdown_rx: broadcast::Receiver<bool>,
mut service_rx: mpsc::Receiver<ServiceChangeEvent>, mut service_rx: mpsc::Receiver<ServiceChange>,
) -> Result<()> { ) -> Result<()> {
for (name, config) in &self.config.services { for (name, config) in &self.config.services {
// Create a control channel for each service defined // Create a control channel for each service defined
@ -120,7 +120,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
e = service_rx.recv() => { e = service_rx.recv() => {
if let Some(e) = e { if let Some(e) = e {
match e { match e {
ServiceChangeEvent::ClientAdd(s)=> { ServiceChange::ClientAdd(s)=> {
let name = s.name.clone(); let name = s.name.clone();
let handle = ControlChannelHandle::new( let handle = ControlChannelHandle::new(
s, s,
@ -129,7 +129,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
); );
let _ = self.service_handles.insert(name, handle); let _ = self.service_handles.insert(name, handle);
}, },
ServiceChangeEvent::ClientDelete(s)=> { ServiceChange::ClientDelete(s)=> {
let _ = self.service_handles.remove(&s); let _ = self.service_handles.remove(&s);
}, },
_ => () _ => ()

View File

@ -38,11 +38,17 @@ pub enum ServiceType {
Udp, Udp,
} }
fn default_service_type() -> ServiceType { impl Default for ServiceType {
ServiceType::Tcp fn default() -> Self {
ServiceType::Tcp
}
} }
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] fn default_service_type() -> ServiceType {
Default::default()
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct ServerServiceConfig { pub struct ServerServiceConfig {
#[serde(rename = "type", default = "default_service_type")] #[serde(rename = "type", default = "default_service_type")]
pub service_type: ServiceType, pub service_type: ServiceType,

View File

@ -1,5 +1,5 @@
use crate::{ use crate::{
config::{ClientServiceConfig, ServerServiceConfig}, config::{ClientConfig, ClientServiceConfig, ServerConfig, ServerServiceConfig},
Config, Config,
}; };
use anyhow::{Context, Result}; use anyhow::{Context, Result};
@ -13,22 +13,87 @@ use tracing::{error, info, instrument};
#[cfg(feature = "notify")] #[cfg(feature = "notify")]
use notify::{event::ModifyKind, EventKind, RecursiveMode, Watcher}; use notify::{event::ModifyKind, EventKind, RecursiveMode, Watcher};
#[derive(Debug)] #[derive(Debug, PartialEq)]
pub enum ConfigChangeEvent { pub enum ConfigChange {
General(Box<Config>), // Trigger a full restart General(Box<Config>), // Trigger a full restart
ServiceChange(ServiceChangeEvent), ServiceChange(ServiceChange),
} }
#[derive(Debug)] #[derive(Debug, PartialEq)]
pub enum ServiceChangeEvent { pub enum ServiceChange {
ClientAdd(ClientServiceConfig), ClientAdd(ClientServiceConfig),
ClientDelete(String), ClientDelete(String),
ServerAdd(ServerServiceConfig), ServerAdd(ServerServiceConfig),
ServerDelete(String), ServerDelete(String),
} }
impl From<ClientServiceConfig> for ServiceChange {
fn from(c: ClientServiceConfig) -> Self {
ServiceChange::ClientAdd(c)
}
}
impl From<ServerServiceConfig> for ServiceChange {
fn from(c: ServerServiceConfig) -> Self {
ServiceChange::ServerAdd(c)
}
}
trait InstanceConfig: Clone {
type ServiceConfig: Into<ServiceChange> + PartialEq + Clone;
fn equal_without_service(&self, rhs: &Self) -> bool;
fn to_service_change_delete(s: String) -> ServiceChange;
fn get_services(&self) -> &HashMap<String, Self::ServiceConfig>;
}
impl InstanceConfig for ServerConfig {
type ServiceConfig = ServerServiceConfig;
fn equal_without_service(&self, rhs: &Self) -> bool {
let left = ServerConfig {
services: Default::default(),
..self.clone()
};
let right = ServerConfig {
services: Default::default(),
..rhs.clone()
};
left == right
}
fn to_service_change_delete(s: String) -> ServiceChange {
ServiceChange::ServerDelete(s)
}
fn get_services(&self) -> &HashMap<String, Self::ServiceConfig> {
&self.services
}
}
impl InstanceConfig for ClientConfig {
type ServiceConfig = ClientServiceConfig;
fn equal_without_service(&self, rhs: &Self) -> bool {
let left = ClientConfig {
services: Default::default(),
..self.clone()
};
let right = ClientConfig {
services: Default::default(),
..rhs.clone()
};
left == right
}
fn to_service_change_delete(s: String) -> ServiceChange {
ServiceChange::ClientDelete(s)
}
fn get_services(&self) -> &HashMap<String, Self::ServiceConfig> {
&self.services
}
}
pub struct ConfigWatcherHandle { pub struct ConfigWatcherHandle {
pub event_rx: mpsc::Receiver<ConfigChangeEvent>, pub event_rx: mpsc::Receiver<ConfigChange>,
} }
impl ConfigWatcherHandle { impl ConfigWatcherHandle {
@ -39,7 +104,7 @@ impl ConfigWatcherHandle {
// Initial start // Initial start
event_tx event_tx
.send(ConfigChangeEvent::General(Box::new(origin_cfg.clone()))) .send(ConfigChange::General(Box::new(origin_cfg.clone())))
.await .await
.unwrap(); .unwrap();
@ -59,30 +124,33 @@ impl ConfigWatcherHandle {
async fn config_watcher( async fn config_watcher(
_path: PathBuf, _path: PathBuf,
mut shutdown_rx: broadcast::Receiver<bool>, mut shutdown_rx: broadcast::Receiver<bool>,
_cfg_event_tx: mpsc::Sender<ConfigChangeEvent>, _event_tx: mpsc::Sender<ConfigChange>,
_old: Config, _old: Config,
) -> Result<()> { ) -> Result<()> {
// Do nothing except wating for ctrl-c // Do nothing except waiting for ctrl-c
let _ = shutdown_rx.recv().await; let _ = shutdown_rx.recv().await;
Ok(()) Ok(())
} }
#[cfg(feature = "notify")] #[cfg(feature = "notify")]
#[instrument(skip(shutdown_rx, cfg_event_tx, old))] #[instrument(skip(shutdown_rx, event_tx, old))]
async fn config_watcher( async fn config_watcher(
path: PathBuf, path: PathBuf,
mut shutdown_rx: broadcast::Receiver<bool>, mut shutdown_rx: broadcast::Receiver<bool>,
cfg_event_tx: mpsc::Sender<ConfigChangeEvent>, event_tx: mpsc::Sender<ConfigChange>,
mut old: Config, mut old: Config,
) -> Result<()> { ) -> Result<()> {
let (fevent_tx, mut fevent_rx) = mpsc::channel(16); let (fevent_tx, mut fevent_rx) = mpsc::channel(16);
let mut watcher = notify::recommended_watcher(move |res| match res { let mut watcher =
Ok(event) => { notify::recommended_watcher(move |res: Result<notify::Event, _>| match res {
let _ = fevent_tx.blocking_send(event); Ok(e) => {
} if let EventKind::Modify(ModifyKind::Data(_)) = e.kind {
Err(e) => error!("watch error: {:?}", e), let _ = fevent_tx.blocking_send(true);
})?; }
}
Err(e) => error!("watch error: {:?}", e),
})?;
watcher.watch(&path, RecursiveMode::NonRecursive)?; watcher.watch(&path, RecursiveMode::NonRecursive)?;
info!("Start watching the config"); info!("Start watching the config");
@ -91,12 +159,7 @@ async fn config_watcher(
tokio::select! { tokio::select! {
e = fevent_rx.recv() => { e = fevent_rx.recv() => {
match e { match e {
Some(e) => { Some(_) => {
if let EventKind::Modify(kind) = e.kind {
match kind {
ModifyKind::Data(_) => (),
_ => continue
}
info!("Rescan the configuration"); info!("Rescan the configuration");
let new = match Config::from_file(&path).await.with_context(|| "The changed configuration is invalid. Ignored") { let new = match Config::from_file(&path).await.with_context(|| "The changed configuration is invalid. Ignored") {
Ok(v) => v, Ok(v) => v,
@ -107,12 +170,11 @@ async fn config_watcher(
} }
}; };
for event in calculate_event(&old, &new) { for event in calculate_events(&old, &new) {
cfg_event_tx.send(event).await?; event_tx.send(event).await?;
} }
old = new; old = new;
}
}, },
None => break None => break
} }
@ -126,74 +188,170 @@ async fn config_watcher(
Ok(()) Ok(())
} }
fn calculate_event(old: &Config, new: &Config) -> Vec<ConfigChangeEvent> { fn calculate_events(old: &Config, new: &Config) -> Vec<ConfigChange> {
let mut ret = Vec::new(); if old == new {
vec![]
if old != new { } else if old.server != new.server {
if old.server.is_some() && new.server.is_some() { if old.server.is_some() != new.server.is_some() {
let mut e: Vec<ConfigChangeEvent> = calculate_service_delete_event( vec![ConfigChange::General(Box::new(new.clone()))]
&old.server.as_ref().unwrap().services,
&new.server.as_ref().unwrap().services,
)
.into_iter()
.map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ServerDelete(x)))
.collect();
ret.append(&mut e);
let mut e: Vec<ConfigChangeEvent> = calculate_service_add_event(
&old.server.as_ref().unwrap().services,
&new.server.as_ref().unwrap().services,
)
.into_iter()
.map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ServerAdd(x)))
.collect();
ret.append(&mut e);
} else if old.client.is_some() && new.client.is_some() {
let mut e: Vec<ConfigChangeEvent> = calculate_service_delete_event(
&old.client.as_ref().unwrap().services,
&new.client.as_ref().unwrap().services,
)
.into_iter()
.map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ClientDelete(x)))
.collect();
ret.append(&mut e);
let mut e: Vec<ConfigChangeEvent> = calculate_service_add_event(
&old.client.as_ref().unwrap().services,
&new.client.as_ref().unwrap().services,
)
.into_iter()
.map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ClientAdd(x)))
.collect();
ret.append(&mut e);
} else { } else {
ret.push(ConfigChangeEvent::General(Box::new(new.clone()))); match calculate_instance_config_events(
old.server.as_ref().unwrap(),
new.server.as_ref().unwrap(),
) {
Some(v) => v,
None => vec![ConfigChange::General(Box::new(new.clone()))],
}
} }
} else if old.client != new.client {
if old.client.is_some() != new.client.is_some() {
vec![ConfigChange::General(Box::new(new.clone()))]
} else {
match calculate_instance_config_events(
old.client.as_ref().unwrap(),
new.client.as_ref().unwrap(),
) {
Some(v) => v,
None => vec![ConfigChange::General(Box::new(new.clone()))],
}
}
} else {
vec![]
}
}
// None indicates a General change needed
fn calculate_instance_config_events<T: InstanceConfig>(
old: &T,
new: &T,
) -> Option<Vec<ConfigChange>> {
if !old.equal_without_service(new) {
return None;
} }
ret let old = old.get_services();
let new = new.get_services();
let mut v = vec![];
v.append(&mut calculate_service_delete_events::<T>(old, new));
v.append(&mut calculate_service_add_events(old, new));
Some(v.into_iter().map(ConfigChange::ServiceChange).collect())
} }
fn calculate_service_delete_event<T: PartialEq>( fn calculate_service_delete_events<T: InstanceConfig>(
old_services: &HashMap<String, T>, old: &HashMap<String, T::ServiceConfig>,
new_services: &HashMap<String, T>, new: &HashMap<String, T::ServiceConfig>,
) -> Vec<String> { ) -> Vec<ServiceChange> {
old_services old.keys()
.keys() .filter(|&name| new.get(name).is_none())
.filter(|&name| old_services.get(name) != new_services.get(name)) .map(|x| T::to_service_change_delete(x.to_owned()))
.map(|x| x.to_owned())
.collect() .collect()
} }
fn calculate_service_add_event<T: PartialEq + Clone>( fn calculate_service_add_events<T: PartialEq + Clone + Into<ServiceChange>>(
old_services: &HashMap<String, T>, old: &HashMap<String, T>,
new_services: &HashMap<String, T>, new: &HashMap<String, T>,
) -> Vec<T> { ) -> Vec<ServiceChange> {
new_services new.iter()
.iter() .filter(|(name, c)| old.get(*name) != Some(*c))
.filter(|(name, _)| old_services.get(*name) != new_services.get(*name)) .map(|(_, c)| c.clone().into())
.map(|(_, c)| c.clone())
.collect() .collect()
} }
#[cfg(test)]
mod test {
use crate::config::ServerConfig;
use super::*;
// macro to create map or set literal
macro_rules! collection {
// map-like
($($k:expr => $v:expr),* $(,)?) => {{
use std::iter::{Iterator, IntoIterator};
Iterator::collect(IntoIterator::into_iter([$(($k, $v),)*]))
}};
}
#[test]
fn test_calculate_events() {
struct Test {
old: Config,
new: Config,
}
let tests = [
Test {
old: Config {
server: Some(Default::default()),
client: None,
},
new: Config {
server: Some(Default::default()),
client: Some(Default::default()),
},
},
Test {
old: Config {
server: Some(ServerConfig {
bind_addr: String::from("127.0.0.1:2334"),
..Default::default()
}),
client: None,
},
new: Config {
server: Some(ServerConfig {
bind_addr: String::from("127.0.0.1:2333"),
services: collection!(String::from("foo") => Default::default()),
..Default::default()
}),
client: None,
},
},
Test {
old: Config {
server: Some(Default::default()),
client: None,
},
new: Config {
server: Some(ServerConfig {
services: collection!(String::from("foo") => Default::default()),
..Default::default()
}),
client: None,
},
},
Test {
old: Config {
server: Some(ServerConfig {
services: collection!(String::from("foo") => Default::default()),
..Default::default()
}),
client: None,
},
new: Config {
server: Some(Default::default()),
client: None,
},
},
];
let expected = [
vec![ConfigChange::General(Box::new(tests[0].new.clone()))],
vec![ConfigChange::General(Box::new(tests[1].new.clone()))],
vec![ConfigChange::ServiceChange(ServiceChange::ServerAdd(
Default::default(),
))],
vec![ConfigChange::ServiceChange(ServiceChange::ServerDelete(
String::from("foo"),
))],
];
assert_eq!(tests.len(), expected.len());
for i in 0..tests.len() {
let actual = calculate_events(&tests[i].old, &tests[i].new);
assert_eq!(actual, expected[i]);
}
}
}

View File

@ -10,7 +10,7 @@ mod transport;
pub use cli::Cli; pub use cli::Cli;
use cli::KeypairType; use cli::KeypairType;
pub use config::Config; pub use config::Config;
use config_watcher::ServiceChangeEvent; use config_watcher::ServiceChange;
pub use constants::UDP_BUFFER_SIZE; pub use constants::UDP_BUFFER_SIZE;
use anyhow::Result; use anyhow::Result;
@ -27,7 +27,7 @@ mod server;
#[cfg(feature = "server")] #[cfg(feature = "server")]
use server::run_server; use server::run_server;
use crate::config_watcher::{ConfigChangeEvent, ConfigWatcherHandle}; use crate::config_watcher::{ConfigChange, ConfigWatcherHandle};
const DEFAULT_CURVE: KeypairType = KeypairType::X25519; const DEFAULT_CURVE: KeypairType = KeypairType::X25519;
@ -76,12 +76,11 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
let (shutdown_tx, _) = broadcast::channel(1); let (shutdown_tx, _) = broadcast::channel(1);
// (The join handle of the last instance, The service update channel sender) // (The join handle of the last instance, The service update channel sender)
let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ServiceChangeEvent>)> = let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ServiceChange>)> = None;
None;
while let Some(e) = cfg_watcher.event_rx.recv().await { while let Some(e) = cfg_watcher.event_rx.recv().await {
match e { match e {
ConfigChangeEvent::General(config) => { ConfigChange::General(config) => {
if let Some((i, _)) = last_instance { if let Some((i, _)) = last_instance {
info!("General configuration change detected. Restarting..."); info!("General configuration change detected. Restarting...");
shutdown_tx.send(true)?; shutdown_tx.send(true)?;
@ -102,7 +101,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
service_update_tx, service_update_tx,
)); ));
} }
ConfigChangeEvent::ServiceChange(service_event) => { ConfigChange::ServiceChange(service_event) => {
info!("Service change detcted. {:?}", service_event); info!("Service change detcted. {:?}", service_event);
if let Some((_, service_update_tx)) = &last_instance { if let Some((_, service_update_tx)) = &last_instance {
let _ = service_update_tx.send(service_event).await; let _ = service_update_tx.send(service_event).await;
@ -110,6 +109,9 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
} }
} }
} }
let _ = shutdown_tx.send(true);
Ok(()) Ok(())
} }
@ -117,7 +119,7 @@ async fn run_instance(
config: Config, config: Config,
args: Cli, args: Cli,
shutdown_rx: broadcast::Receiver<bool>, shutdown_rx: broadcast::Receiver<bool>,
service_update: mpsc::Receiver<ServiceChangeEvent>, service_update: mpsc::Receiver<ServiceChange>,
) -> Result<()> { ) -> Result<()> {
match determine_run_mode(&config, &args) { match determine_run_mode(&config, &args) {
RunMode::Undetermine => panic!("Cannot determine running as a server or a client"), RunMode::Undetermine => panic!("Cannot determine running as a server or a client"),

View File

@ -1,5 +1,5 @@
use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType}; use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
use crate::config_watcher::ServiceChangeEvent; use crate::config_watcher::ServiceChange;
use crate::constants::{listen_backoff, UDP_BUFFER_SIZE}; use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
use crate::multi_map::MultiMap; use crate::multi_map::MultiMap;
use crate::protocol::Hello::{ControlChannelHello, DataChannelHello}; use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
@ -39,7 +39,7 @@ const CHAN_SIZE: usize = 2048; // The capacity of various chans
pub async fn run_server( pub async fn run_server(
config: &Config, config: &Config,
shutdown_rx: broadcast::Receiver<bool>, shutdown_rx: broadcast::Receiver<bool>,
service_rx: mpsc::Receiver<ServiceChangeEvent>, service_rx: mpsc::Receiver<ServiceChange>,
) -> Result<()> { ) -> Result<()> {
let config = match &config.server { let config = match &config.server {
Some(config) => config, Some(config) => config,
@ -122,7 +122,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
pub async fn run( pub async fn run(
&mut self, &mut self,
mut shutdown_rx: broadcast::Receiver<bool>, mut shutdown_rx: broadcast::Receiver<bool>,
mut service_rx: mpsc::Receiver<ServiceChangeEvent>, mut service_rx: mpsc::Receiver<ServiceChange>,
) -> Result<()> { ) -> Result<()> {
// Listen at `server.bind_addr` // Listen at `server.bind_addr`
let l = self let l = self
@ -193,9 +193,9 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
Ok(()) Ok(())
} }
async fn handle_hot_reload(&mut self, e: ServiceChangeEvent) { async fn handle_hot_reload(&mut self, e: ServiceChange) {
match e { match e {
ServiceChangeEvent::ServerAdd(s) => { ServiceChange::ServerAdd(s) => {
let hash = protocol::digest(s.name.as_bytes()); let hash = protocol::digest(s.name.as_bytes());
let mut wg = self.services.write().await; let mut wg = self.services.write().await;
let _ = wg.insert(hash, s); let _ = wg.insert(hash, s);
@ -203,7 +203,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
let mut wg = self.control_channels.write().await; let mut wg = self.control_channels.write().await;
let _ = wg.remove1(&hash); let _ = wg.remove1(&hash);
} }
ServiceChangeEvent::ServerDelete(s) => { ServiceChange::ServerDelete(s) => {
let hash = protocol::digest(s.as_bytes()); let hash = protocol::digest(s.as_bytes());
let _ = self.services.write().await.remove(&hash); let _ = self.services.write().await.remove(&hash);
@ -340,11 +340,8 @@ async fn do_data_channel_handshake<T: 'static + Transport>(
} }
pub struct ControlChannelHandle<T: Transport> { pub struct ControlChannelHandle<T: Transport> {
// Shutdown the control channel. // Shutdown the control channel by dropping it
// Not used for now, but can be used for hot reloading _shutdown_tx: broadcast::Sender<bool>,
#[allow(dead_code)]
shutdown_tx: broadcast::Sender<bool>,
//data_ch_req_tx: mpsc::Sender<bool>,
data_ch_tx: mpsc::Sender<T::Stream>, data_ch_tx: mpsc::Sender<T::Stream>,
} }
@ -359,7 +356,7 @@ where
// Save the name string for logging // Save the name string for logging
let name = service.name.clone(); let name = service.name.clone();
// Create a shutdown channel. The sender is not used for now, but for future use // Create a shutdown channel
let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1); let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
// Store data channels // Store data channels
@ -417,15 +414,10 @@ where
}); });
ControlChannelHandle { ControlChannelHandle {
shutdown_tx, _shutdown_tx: shutdown_tx,
data_ch_tx, data_ch_tx,
} }
} }
#[allow(dead_code)]
fn shutdown(self) {
let _ = self.shutdown_tx.send(true);
}
} }
// Control channel, using T as the transport layer. P is TcpStream or UdpTraffic // Control channel, using T as the transport layer. P is TcpStream or UdpTraffic