refactor: facilitate tests

This commit is contained in:
Yujia Qiao 2021-12-19 11:31:58 +08:00
parent 776bce35cb
commit b8e824849a
No known key found for this signature in database
GPG Key ID: DC129173B148701B
4 changed files with 43 additions and 21 deletions

View File

@ -13,12 +13,12 @@ use backoff::ExponentialBackoff;
use tokio::io::{copy_bidirectional, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tokio::sync::{broadcast, oneshot};
use tokio::time::{self, Duration};
use tracing::{debug, error, info, instrument, Instrument, Span};
// The entrypoint of running a client
pub async fn run_client(config: &Config) -> Result<()> {
pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
let config = match &config.client {
Some(v) => v,
None => {
@ -29,11 +29,11 @@ pub async fn run_client(config: &Config) -> Result<()> {
match config.transport.transport_type {
TransportType::Tcp => {
let mut client = Client::<TcpTransport>::from(config).await?;
client.run().await
client.run(shutdown_rx).await
}
TransportType::Tls => {
let mut client = Client::<TlsTransport>::from(config).await?;
client.run().await
client.run(shutdown_rx).await
}
}
}
@ -54,12 +54,16 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
Ok(Client {
config,
service_handles: HashMap::new(),
transport: Arc::new(*T::new(&config.transport).await?),
transport: Arc::new(
*T::new(&config.transport)
.await
.with_context(|| "Failed to create the transport")?,
),
})
}
// The entrypoint of Client
async fn run(&mut self) -> Result<()> {
async fn run(&mut self, mut shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
for (name, config) in &self.config.services {
// Create a control channel for each service defined
let handle = ControlChannelHandle::new(
@ -74,9 +78,9 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
// Wait for the shutdown signal
loop {
tokio::select! {
val = tokio::signal::ctrl_c() => {
val = shutdown_rx.recv() => {
match val {
Ok(()) => {}
Ok(_) => {}
Err(err) => {
error!("Unable to listen for shutdown signal: {}", err);
}
@ -258,7 +262,7 @@ impl ControlChannelHandle {
.await
.with_context(|| "Failed to run the control channel")
{
let duration = Duration::from_secs(2);
let duration = Duration::from_secs(1);
error!("{:?}\n\nRetry in {:?}...", err, duration);
time::sleep(duration).await;
}

View File

@ -11,16 +11,15 @@ pub use cli::Cli;
pub use config::Config;
use anyhow::{anyhow, Result};
use tokio::sync::broadcast;
use tracing::debug;
use client::run_client;
use server::run_server;
pub async fn run(args: &Cli) -> Result<()> {
pub async fn run(args: &Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
let config = Config::from_file(&args.config_path).await?;
tracing_subscriber::fmt::init();
debug!("{:?}", config);
// Raise `nofile` limit on linux and mac
@ -28,8 +27,8 @@ pub async fn run(args: &Cli) -> Result<()> {
match determine_run_mode(&config, args) {
RunMode::Undetermine => Err(anyhow!("Cannot determine running as a server or a client")),
RunMode::Client => run_client(&config).await,
RunMode::Server => run_server(&config).await,
RunMode::Client => run_client(&config, shutdown_rx).await,
RunMode::Server => run_server(&config, shutdown_rx).await,
}
}

View File

@ -1,9 +1,28 @@
use anyhow::Result;
use clap::Parser;
use rathole::{run, Cli};
use tokio::{signal, sync::broadcast};
#[tokio::main]
async fn main() -> Result<()> {
let args = Cli::parse();
run(&args).await
let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
tokio::spawn(async move {
if let Err(e) = signal::ctrl_c().await {
// Something really weird happened. So just panic
panic!("Failed to listen for the ctrl-c signal: {:?}", e);
}
if let Err(e) = shutdown_tx.send(true) {
// shutdown signal must be catched and handle properly
// `rx` must not be dropped
panic!("Failed to send shutdown signal: {:?}", e);
}
});
// TODO: use level from config
tracing_subscriber::fmt::init();
run(&args, shutdown_rx).await
}

View File

@ -15,7 +15,7 @@ use std::sync::Arc;
use std::time::Duration;
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, oneshot, RwLock};
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use tokio::time;
use tracing::{debug, error, info, info_span, warn, Instrument};
@ -26,7 +26,7 @@ const POOL_SIZE: usize = 64; // The number of cached connections
const CHAN_SIZE: usize = 2048; // The capacity of various chans
// The entrypoint of running a server
pub async fn run_server(config: &Config) -> Result<()> {
pub async fn run_server(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
let config = match &config.server {
Some(config) => config,
None => {
@ -38,11 +38,11 @@ pub async fn run_server(config: &Config) -> Result<()> {
match config.transport.transport_type {
TransportType::Tcp => {
let mut server = Server::<TcpTransport>::from(config).await?;
server.run().await?;
server.run(shutdown_rx).await?;
}
TransportType::Tls => {
let mut server = Server::<TlsTransport>::from(config).await?;
server.run().await?;
server.run(shutdown_rx).await?;
}
}
@ -91,7 +91,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
}
// The entry point of Server
pub async fn run(&mut self) -> Result<()> {
pub async fn run(&mut self, mut shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
// Listen at `server.bind_addr`
let l = self
.transport
@ -146,7 +146,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
}
},
// Wait for the shutdown signal
_ = tokio::signal::ctrl_c() => {
_ = shutdown_rx.recv() => {
info!("Shuting down gracefully...");
break;
}