test: refactor and add tests for hot-reload

This commit is contained in:
Yujia Qiao 2021-12-26 22:59:12 +08:00 committed by Yujia Qiao
parent c8e679fa65
commit c8cb60708d
5 changed files with 278 additions and 120 deletions

View File

@ -1,5 +1,5 @@
use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
use crate::config_watcher::ServiceChangeEvent;
use crate::config_watcher::ServiceChange;
use crate::helper::udp_connect;
use crate::protocol::Hello::{self, *};
use crate::protocol::{
@ -30,7 +30,7 @@ use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
pub async fn run_client(
config: &Config,
shutdown_rx: broadcast::Receiver<bool>,
service_rx: mpsc::Receiver<ServiceChangeEvent>,
service_rx: mpsc::Receiver<ServiceChange>,
) -> Result<()> {
let config = match &config.client {
Some(v) => v,
@ -93,7 +93,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
async fn run(
&mut self,
mut shutdown_rx: broadcast::Receiver<bool>,
mut service_rx: mpsc::Receiver<ServiceChangeEvent>,
mut service_rx: mpsc::Receiver<ServiceChange>,
) -> Result<()> {
for (name, config) in &self.config.services {
// Create a control channel for each service defined
@ -120,7 +120,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
e = service_rx.recv() => {
if let Some(e) = e {
match e {
ServiceChangeEvent::ClientAdd(s)=> {
ServiceChange::ClientAdd(s)=> {
let name = s.name.clone();
let handle = ControlChannelHandle::new(
s,
@ -129,7 +129,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
);
let _ = self.service_handles.insert(name, handle);
},
ServiceChangeEvent::ClientDelete(s)=> {
ServiceChange::ClientDelete(s)=> {
let _ = self.service_handles.remove(&s);
},
_ => ()

View File

@ -38,11 +38,17 @@ pub enum ServiceType {
Udp,
}
fn default_service_type() -> ServiceType {
ServiceType::Tcp
impl Default for ServiceType {
fn default() -> Self {
ServiceType::Tcp
}
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
fn default_service_type() -> ServiceType {
Default::default()
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct ServerServiceConfig {
#[serde(rename = "type", default = "default_service_type")]
pub service_type: ServiceType,

View File

@ -1,5 +1,5 @@
use crate::{
config::{ClientServiceConfig, ServerServiceConfig},
config::{ClientConfig, ClientServiceConfig, ServerConfig, ServerServiceConfig},
Config,
};
use anyhow::{Context, Result};
@ -13,22 +13,87 @@ use tracing::{error, info, instrument};
#[cfg(feature = "notify")]
use notify::{event::ModifyKind, EventKind, RecursiveMode, Watcher};
#[derive(Debug)]
pub enum ConfigChangeEvent {
#[derive(Debug, PartialEq)]
pub enum ConfigChange {
General(Box<Config>), // Trigger a full restart
ServiceChange(ServiceChangeEvent),
ServiceChange(ServiceChange),
}
#[derive(Debug)]
pub enum ServiceChangeEvent {
#[derive(Debug, PartialEq)]
pub enum ServiceChange {
ClientAdd(ClientServiceConfig),
ClientDelete(String),
ServerAdd(ServerServiceConfig),
ServerDelete(String),
}
impl From<ClientServiceConfig> for ServiceChange {
fn from(c: ClientServiceConfig) -> Self {
ServiceChange::ClientAdd(c)
}
}
impl From<ServerServiceConfig> for ServiceChange {
fn from(c: ServerServiceConfig) -> Self {
ServiceChange::ServerAdd(c)
}
}
trait InstanceConfig: Clone {
type ServiceConfig: Into<ServiceChange> + PartialEq + Clone;
fn equal_without_service(&self, rhs: &Self) -> bool;
fn to_service_change_delete(s: String) -> ServiceChange;
fn get_services(&self) -> &HashMap<String, Self::ServiceConfig>;
}
impl InstanceConfig for ServerConfig {
type ServiceConfig = ServerServiceConfig;
fn equal_without_service(&self, rhs: &Self) -> bool {
let left = ServerConfig {
services: Default::default(),
..self.clone()
};
let right = ServerConfig {
services: Default::default(),
..rhs.clone()
};
left == right
}
fn to_service_change_delete(s: String) -> ServiceChange {
ServiceChange::ServerDelete(s)
}
fn get_services(&self) -> &HashMap<String, Self::ServiceConfig> {
&self.services
}
}
impl InstanceConfig for ClientConfig {
type ServiceConfig = ClientServiceConfig;
fn equal_without_service(&self, rhs: &Self) -> bool {
let left = ClientConfig {
services: Default::default(),
..self.clone()
};
let right = ClientConfig {
services: Default::default(),
..rhs.clone()
};
left == right
}
fn to_service_change_delete(s: String) -> ServiceChange {
ServiceChange::ClientDelete(s)
}
fn get_services(&self) -> &HashMap<String, Self::ServiceConfig> {
&self.services
}
}
pub struct ConfigWatcherHandle {
pub event_rx: mpsc::Receiver<ConfigChangeEvent>,
pub event_rx: mpsc::Receiver<ConfigChange>,
}
impl ConfigWatcherHandle {
@ -39,7 +104,7 @@ impl ConfigWatcherHandle {
// Initial start
event_tx
.send(ConfigChangeEvent::General(Box::new(origin_cfg.clone())))
.send(ConfigChange::General(Box::new(origin_cfg.clone())))
.await
.unwrap();
@ -59,30 +124,33 @@ impl ConfigWatcherHandle {
async fn config_watcher(
_path: PathBuf,
mut shutdown_rx: broadcast::Receiver<bool>,
_cfg_event_tx: mpsc::Sender<ConfigChangeEvent>,
_event_tx: mpsc::Sender<ConfigChange>,
_old: Config,
) -> Result<()> {
// Do nothing except wating for ctrl-c
// Do nothing except waiting for ctrl-c
let _ = shutdown_rx.recv().await;
Ok(())
}
#[cfg(feature = "notify")]
#[instrument(skip(shutdown_rx, cfg_event_tx, old))]
#[instrument(skip(shutdown_rx, event_tx, old))]
async fn config_watcher(
path: PathBuf,
mut shutdown_rx: broadcast::Receiver<bool>,
cfg_event_tx: mpsc::Sender<ConfigChangeEvent>,
event_tx: mpsc::Sender<ConfigChange>,
mut old: Config,
) -> Result<()> {
let (fevent_tx, mut fevent_rx) = mpsc::channel(16);
let mut watcher = notify::recommended_watcher(move |res| match res {
Ok(event) => {
let _ = fevent_tx.blocking_send(event);
}
Err(e) => error!("watch error: {:?}", e),
})?;
let mut watcher =
notify::recommended_watcher(move |res: Result<notify::Event, _>| match res {
Ok(e) => {
if let EventKind::Modify(ModifyKind::Data(_)) = e.kind {
let _ = fevent_tx.blocking_send(true);
}
}
Err(e) => error!("watch error: {:?}", e),
})?;
watcher.watch(&path, RecursiveMode::NonRecursive)?;
info!("Start watching the config");
@ -91,12 +159,7 @@ async fn config_watcher(
tokio::select! {
e = fevent_rx.recv() => {
match e {
Some(e) => {
if let EventKind::Modify(kind) = e.kind {
match kind {
ModifyKind::Data(_) => (),
_ => continue
}
Some(_) => {
info!("Rescan the configuration");
let new = match Config::from_file(&path).await.with_context(|| "The changed configuration is invalid. Ignored") {
Ok(v) => v,
@ -107,12 +170,11 @@ async fn config_watcher(
}
};
for event in calculate_event(&old, &new) {
cfg_event_tx.send(event).await?;
for event in calculate_events(&old, &new) {
event_tx.send(event).await?;
}
old = new;
}
},
None => break
}
@ -126,74 +188,170 @@ async fn config_watcher(
Ok(())
}
fn calculate_event(old: &Config, new: &Config) -> Vec<ConfigChangeEvent> {
let mut ret = Vec::new();
if old != new {
if old.server.is_some() && new.server.is_some() {
let mut e: Vec<ConfigChangeEvent> = calculate_service_delete_event(
&old.server.as_ref().unwrap().services,
&new.server.as_ref().unwrap().services,
)
.into_iter()
.map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ServerDelete(x)))
.collect();
ret.append(&mut e);
let mut e: Vec<ConfigChangeEvent> = calculate_service_add_event(
&old.server.as_ref().unwrap().services,
&new.server.as_ref().unwrap().services,
)
.into_iter()
.map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ServerAdd(x)))
.collect();
ret.append(&mut e);
} else if old.client.is_some() && new.client.is_some() {
let mut e: Vec<ConfigChangeEvent> = calculate_service_delete_event(
&old.client.as_ref().unwrap().services,
&new.client.as_ref().unwrap().services,
)
.into_iter()
.map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ClientDelete(x)))
.collect();
ret.append(&mut e);
let mut e: Vec<ConfigChangeEvent> = calculate_service_add_event(
&old.client.as_ref().unwrap().services,
&new.client.as_ref().unwrap().services,
)
.into_iter()
.map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ClientAdd(x)))
.collect();
ret.append(&mut e);
fn calculate_events(old: &Config, new: &Config) -> Vec<ConfigChange> {
if old == new {
vec![]
} else if old.server != new.server {
if old.server.is_some() != new.server.is_some() {
vec![ConfigChange::General(Box::new(new.clone()))]
} else {
ret.push(ConfigChangeEvent::General(Box::new(new.clone())));
match calculate_instance_config_events(
old.server.as_ref().unwrap(),
new.server.as_ref().unwrap(),
) {
Some(v) => v,
None => vec![ConfigChange::General(Box::new(new.clone()))],
}
}
} else if old.client != new.client {
if old.client.is_some() != new.client.is_some() {
vec![ConfigChange::General(Box::new(new.clone()))]
} else {
match calculate_instance_config_events(
old.client.as_ref().unwrap(),
new.client.as_ref().unwrap(),
) {
Some(v) => v,
None => vec![ConfigChange::General(Box::new(new.clone()))],
}
}
} else {
vec![]
}
}
// None indicates a General change needed
fn calculate_instance_config_events<T: InstanceConfig>(
old: &T,
new: &T,
) -> Option<Vec<ConfigChange>> {
if !old.equal_without_service(new) {
return None;
}
ret
let old = old.get_services();
let new = new.get_services();
let mut v = vec![];
v.append(&mut calculate_service_delete_events::<T>(old, new));
v.append(&mut calculate_service_add_events(old, new));
Some(v.into_iter().map(ConfigChange::ServiceChange).collect())
}
fn calculate_service_delete_event<T: PartialEq>(
old_services: &HashMap<String, T>,
new_services: &HashMap<String, T>,
) -> Vec<String> {
old_services
.keys()
.filter(|&name| old_services.get(name) != new_services.get(name))
.map(|x| x.to_owned())
fn calculate_service_delete_events<T: InstanceConfig>(
old: &HashMap<String, T::ServiceConfig>,
new: &HashMap<String, T::ServiceConfig>,
) -> Vec<ServiceChange> {
old.keys()
.filter(|&name| new.get(name).is_none())
.map(|x| T::to_service_change_delete(x.to_owned()))
.collect()
}
fn calculate_service_add_event<T: PartialEq + Clone>(
old_services: &HashMap<String, T>,
new_services: &HashMap<String, T>,
) -> Vec<T> {
new_services
.iter()
.filter(|(name, _)| old_services.get(*name) != new_services.get(*name))
.map(|(_, c)| c.clone())
fn calculate_service_add_events<T: PartialEq + Clone + Into<ServiceChange>>(
old: &HashMap<String, T>,
new: &HashMap<String, T>,
) -> Vec<ServiceChange> {
new.iter()
.filter(|(name, c)| old.get(*name) != Some(*c))
.map(|(_, c)| c.clone().into())
.collect()
}
#[cfg(test)]
mod test {
use crate::config::ServerConfig;
use super::*;
// macro to create map or set literal
macro_rules! collection {
// map-like
($($k:expr => $v:expr),* $(,)?) => {{
use std::iter::{Iterator, IntoIterator};
Iterator::collect(IntoIterator::into_iter([$(($k, $v),)*]))
}};
}
#[test]
fn test_calculate_events() {
struct Test {
old: Config,
new: Config,
}
let tests = [
Test {
old: Config {
server: Some(Default::default()),
client: None,
},
new: Config {
server: Some(Default::default()),
client: Some(Default::default()),
},
},
Test {
old: Config {
server: Some(ServerConfig {
bind_addr: String::from("127.0.0.1:2334"),
..Default::default()
}),
client: None,
},
new: Config {
server: Some(ServerConfig {
bind_addr: String::from("127.0.0.1:2333"),
services: collection!(String::from("foo") => Default::default()),
..Default::default()
}),
client: None,
},
},
Test {
old: Config {
server: Some(Default::default()),
client: None,
},
new: Config {
server: Some(ServerConfig {
services: collection!(String::from("foo") => Default::default()),
..Default::default()
}),
client: None,
},
},
Test {
old: Config {
server: Some(ServerConfig {
services: collection!(String::from("foo") => Default::default()),
..Default::default()
}),
client: None,
},
new: Config {
server: Some(Default::default()),
client: None,
},
},
];
let expected = [
vec![ConfigChange::General(Box::new(tests[0].new.clone()))],
vec![ConfigChange::General(Box::new(tests[1].new.clone()))],
vec![ConfigChange::ServiceChange(ServiceChange::ServerAdd(
Default::default(),
))],
vec![ConfigChange::ServiceChange(ServiceChange::ServerDelete(
String::from("foo"),
))],
];
assert_eq!(tests.len(), expected.len());
for i in 0..tests.len() {
let actual = calculate_events(&tests[i].old, &tests[i].new);
assert_eq!(actual, expected[i]);
}
}
}

View File

@ -10,7 +10,7 @@ mod transport;
pub use cli::Cli;
use cli::KeypairType;
pub use config::Config;
use config_watcher::ServiceChangeEvent;
use config_watcher::ServiceChange;
pub use constants::UDP_BUFFER_SIZE;
use anyhow::Result;
@ -27,7 +27,7 @@ mod server;
#[cfg(feature = "server")]
use server::run_server;
use crate::config_watcher::{ConfigChangeEvent, ConfigWatcherHandle};
use crate::config_watcher::{ConfigChange, ConfigWatcherHandle};
const DEFAULT_CURVE: KeypairType = KeypairType::X25519;
@ -76,12 +76,11 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
let (shutdown_tx, _) = broadcast::channel(1);
// (The join handle of the last instance, The service update channel sender)
let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ServiceChangeEvent>)> =
None;
let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ServiceChange>)> = None;
while let Some(e) = cfg_watcher.event_rx.recv().await {
match e {
ConfigChangeEvent::General(config) => {
ConfigChange::General(config) => {
if let Some((i, _)) = last_instance {
info!("General configuration change detected. Restarting...");
shutdown_tx.send(true)?;
@ -102,7 +101,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
service_update_tx,
));
}
ConfigChangeEvent::ServiceChange(service_event) => {
ConfigChange::ServiceChange(service_event) => {
info!("Service change detcted. {:?}", service_event);
if let Some((_, service_update_tx)) = &last_instance {
let _ = service_update_tx.send(service_event).await;
@ -110,6 +109,9 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
}
}
}
let _ = shutdown_tx.send(true);
Ok(())
}
@ -117,7 +119,7 @@ async fn run_instance(
config: Config,
args: Cli,
shutdown_rx: broadcast::Receiver<bool>,
service_update: mpsc::Receiver<ServiceChangeEvent>,
service_update: mpsc::Receiver<ServiceChange>,
) -> Result<()> {
match determine_run_mode(&config, &args) {
RunMode::Undetermine => panic!("Cannot determine running as a server or a client"),

View File

@ -1,5 +1,5 @@
use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
use crate::config_watcher::ServiceChangeEvent;
use crate::config_watcher::ServiceChange;
use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
use crate::multi_map::MultiMap;
use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
@ -39,7 +39,7 @@ const CHAN_SIZE: usize = 2048; // The capacity of various chans
pub async fn run_server(
config: &Config,
shutdown_rx: broadcast::Receiver<bool>,
service_rx: mpsc::Receiver<ServiceChangeEvent>,
service_rx: mpsc::Receiver<ServiceChange>,
) -> Result<()> {
let config = match &config.server {
Some(config) => config,
@ -122,7 +122,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
pub async fn run(
&mut self,
mut shutdown_rx: broadcast::Receiver<bool>,
mut service_rx: mpsc::Receiver<ServiceChangeEvent>,
mut service_rx: mpsc::Receiver<ServiceChange>,
) -> Result<()> {
// Listen at `server.bind_addr`
let l = self
@ -193,9 +193,9 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
Ok(())
}
async fn handle_hot_reload(&mut self, e: ServiceChangeEvent) {
async fn handle_hot_reload(&mut self, e: ServiceChange) {
match e {
ServiceChangeEvent::ServerAdd(s) => {
ServiceChange::ServerAdd(s) => {
let hash = protocol::digest(s.name.as_bytes());
let mut wg = self.services.write().await;
let _ = wg.insert(hash, s);
@ -203,7 +203,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
let mut wg = self.control_channels.write().await;
let _ = wg.remove1(&hash);
}
ServiceChangeEvent::ServerDelete(s) => {
ServiceChange::ServerDelete(s) => {
let hash = protocol::digest(s.as_bytes());
let _ = self.services.write().await.remove(&hash);
@ -340,11 +340,8 @@ async fn do_data_channel_handshake<T: 'static + Transport>(
}
pub struct ControlChannelHandle<T: Transport> {
// Shutdown the control channel.
// Not used for now, but can be used for hot reloading
#[allow(dead_code)]
shutdown_tx: broadcast::Sender<bool>,
//data_ch_req_tx: mpsc::Sender<bool>,
// Shutdown the control channel by dropping it
_shutdown_tx: broadcast::Sender<bool>,
data_ch_tx: mpsc::Sender<T::Stream>,
}
@ -359,7 +356,7 @@ where
// Save the name string for logging
let name = service.name.clone();
// Create a shutdown channel. The sender is not used for now, but for future use
// Create a shutdown channel
let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
// Store data channels
@ -417,15 +414,10 @@ where
});
ControlChannelHandle {
shutdown_tx,
_shutdown_tx: shutdown_tx,
data_ch_tx,
}
}
#[allow(dead_code)]
fn shutdown(self) {
let _ = self.shutdown_tx.send(true);
}
}
// Control channel, using T as the transport layer. P is TcpStream or UdpTraffic