From b8e824849a433aee4f08c3dd68fe463b16738ba3 Mon Sep 17 00:00:00 2001 From: Yujia Qiao Date: Sun, 19 Dec 2021 11:31:58 +0800 Subject: [PATCH] refactor: facilitate tests --- src/client.rs | 22 +++++++++++++--------- src/lib.rs | 9 ++++----- src/main.rs | 21 ++++++++++++++++++++- src/server.rs | 12 ++++++------ 4 files changed, 43 insertions(+), 21 deletions(-) diff --git a/src/client.rs b/src/client.rs index 968a37f..32f6c50 100644 --- a/src/client.rs +++ b/src/client.rs @@ -13,12 +13,12 @@ use backoff::ExponentialBackoff; use tokio::io::{copy_bidirectional, AsyncWriteExt}; use tokio::net::TcpStream; -use tokio::sync::oneshot; +use tokio::sync::{broadcast, oneshot}; use tokio::time::{self, Duration}; use tracing::{debug, error, info, instrument, Instrument, Span}; // The entrypoint of running a client -pub async fn run_client(config: &Config) -> Result<()> { +pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver) -> Result<()> { let config = match &config.client { Some(v) => v, None => { @@ -29,11 +29,11 @@ pub async fn run_client(config: &Config) -> Result<()> { match config.transport.transport_type { TransportType::Tcp => { let mut client = Client::::from(config).await?; - client.run().await + client.run(shutdown_rx).await } TransportType::Tls => { let mut client = Client::::from(config).await?; - client.run().await + client.run(shutdown_rx).await } } } @@ -54,12 +54,16 @@ impl<'a, T: 'static + Transport> Client<'a, T> { Ok(Client { config, service_handles: HashMap::new(), - transport: Arc::new(*T::new(&config.transport).await?), + transport: Arc::new( + *T::new(&config.transport) + .await + .with_context(|| "Failed to create the transport")?, + ), }) } // The entrypoint of Client - async fn run(&mut self) -> Result<()> { + async fn run(&mut self, mut shutdown_rx: broadcast::Receiver) -> Result<()> { for (name, config) in &self.config.services { // Create a control channel for each service defined let handle = ControlChannelHandle::new( @@ -74,9 +78,9 @@ impl<'a, T: 'static + Transport> Client<'a, T> { // Wait for the shutdown signal loop { tokio::select! { - val = tokio::signal::ctrl_c() => { + val = shutdown_rx.recv() => { match val { - Ok(()) => {} + Ok(_) => {} Err(err) => { error!("Unable to listen for shutdown signal: {}", err); } @@ -258,7 +262,7 @@ impl ControlChannelHandle { .await .with_context(|| "Failed to run the control channel") { - let duration = Duration::from_secs(2); + let duration = Duration::from_secs(1); error!("{:?}\n\nRetry in {:?}...", err, duration); time::sleep(duration).await; } diff --git a/src/lib.rs b/src/lib.rs index 0f1d9a2..de79348 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,16 +11,15 @@ pub use cli::Cli; pub use config::Config; use anyhow::{anyhow, Result}; +use tokio::sync::broadcast; use tracing::debug; use client::run_client; use server::run_server; -pub async fn run(args: &Cli) -> Result<()> { +pub async fn run(args: &Cli, shutdown_rx: broadcast::Receiver) -> Result<()> { let config = Config::from_file(&args.config_path).await?; - tracing_subscriber::fmt::init(); - debug!("{:?}", config); // Raise `nofile` limit on linux and mac @@ -28,8 +27,8 @@ pub async fn run(args: &Cli) -> Result<()> { match determine_run_mode(&config, args) { RunMode::Undetermine => Err(anyhow!("Cannot determine running as a server or a client")), - RunMode::Client => run_client(&config).await, - RunMode::Server => run_server(&config).await, + RunMode::Client => run_client(&config, shutdown_rx).await, + RunMode::Server => run_server(&config, shutdown_rx).await, } } diff --git a/src/main.rs b/src/main.rs index 84c16d6..e8f4cde 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,28 @@ use anyhow::Result; use clap::Parser; use rathole::{run, Cli}; +use tokio::{signal, sync::broadcast}; #[tokio::main] async fn main() -> Result<()> { let args = Cli::parse(); - run(&args).await + + let (shutdown_tx, shutdown_rx) = broadcast::channel::(1); + tokio::spawn(async move { + if let Err(e) = signal::ctrl_c().await { + // Something really weird happened. So just panic + panic!("Failed to listen for the ctrl-c signal: {:?}", e); + } + + if let Err(e) = shutdown_tx.send(true) { + // shutdown signal must be catched and handle properly + // `rx` must not be dropped + panic!("Failed to send shutdown signal: {:?}", e); + } + }); + + // TODO: use level from config + tracing_subscriber::fmt::init(); + + run(&args, shutdown_rx).await } diff --git a/src/server.rs b/src/server.rs index 1668053..e74edd0 100644 --- a/src/server.rs +++ b/src/server.rs @@ -15,7 +15,7 @@ use std::sync::Arc; use std::time::Duration; use tokio::io::{self, copy_bidirectional, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::{mpsc, oneshot, RwLock}; +use tokio::sync::{broadcast, mpsc, oneshot, RwLock}; use tokio::time; use tracing::{debug, error, info, info_span, warn, Instrument}; @@ -26,7 +26,7 @@ const POOL_SIZE: usize = 64; // The number of cached connections const CHAN_SIZE: usize = 2048; // The capacity of various chans // The entrypoint of running a server -pub async fn run_server(config: &Config) -> Result<()> { +pub async fn run_server(config: &Config, shutdown_rx: broadcast::Receiver) -> Result<()> { let config = match &config.server { Some(config) => config, None => { @@ -38,11 +38,11 @@ pub async fn run_server(config: &Config) -> Result<()> { match config.transport.transport_type { TransportType::Tcp => { let mut server = Server::::from(config).await?; - server.run().await?; + server.run(shutdown_rx).await?; } TransportType::Tls => { let mut server = Server::::from(config).await?; - server.run().await?; + server.run(shutdown_rx).await?; } } @@ -91,7 +91,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> { } // The entry point of Server - pub async fn run(&mut self) -> Result<()> { + pub async fn run(&mut self, mut shutdown_rx: broadcast::Receiver) -> Result<()> { // Listen at `server.bind_addr` let l = self .transport @@ -146,7 +146,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> { } }, // Wait for the shutdown signal - _ = tokio::signal::ctrl_c() => { + _ = shutdown_rx.recv() => { info!("Shuting down gracefully..."); break; }