mirror of https://github.com/rapiz1/rathole.git
refactor: facilitate tests
This commit is contained in:
parent
776bce35cb
commit
b8e824849a
|
@ -13,12 +13,12 @@ use backoff::ExponentialBackoff;
|
|||
|
||||
use tokio::io::{copy_bidirectional, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::{broadcast, oneshot};
|
||||
use tokio::time::{self, Duration};
|
||||
use tracing::{debug, error, info, instrument, Instrument, Span};
|
||||
|
||||
// The entrypoint of running a client
|
||||
pub async fn run_client(config: &Config) -> Result<()> {
|
||||
pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
|
||||
let config = match &config.client {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
|
@ -29,11 +29,11 @@ pub async fn run_client(config: &Config) -> Result<()> {
|
|||
match config.transport.transport_type {
|
||||
TransportType::Tcp => {
|
||||
let mut client = Client::<TcpTransport>::from(config).await?;
|
||||
client.run().await
|
||||
client.run(shutdown_rx).await
|
||||
}
|
||||
TransportType::Tls => {
|
||||
let mut client = Client::<TlsTransport>::from(config).await?;
|
||||
client.run().await
|
||||
client.run(shutdown_rx).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -54,12 +54,16 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
|
|||
Ok(Client {
|
||||
config,
|
||||
service_handles: HashMap::new(),
|
||||
transport: Arc::new(*T::new(&config.transport).await?),
|
||||
transport: Arc::new(
|
||||
*T::new(&config.transport)
|
||||
.await
|
||||
.with_context(|| "Failed to create the transport")?,
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
// The entrypoint of Client
|
||||
async fn run(&mut self) -> Result<()> {
|
||||
async fn run(&mut self, mut shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
|
||||
for (name, config) in &self.config.services {
|
||||
// Create a control channel for each service defined
|
||||
let handle = ControlChannelHandle::new(
|
||||
|
@ -74,9 +78,9 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
|
|||
// Wait for the shutdown signal
|
||||
loop {
|
||||
tokio::select! {
|
||||
val = tokio::signal::ctrl_c() => {
|
||||
val = shutdown_rx.recv() => {
|
||||
match val {
|
||||
Ok(()) => {}
|
||||
Ok(_) => {}
|
||||
Err(err) => {
|
||||
error!("Unable to listen for shutdown signal: {}", err);
|
||||
}
|
||||
|
@ -258,7 +262,7 @@ impl ControlChannelHandle {
|
|||
.await
|
||||
.with_context(|| "Failed to run the control channel")
|
||||
{
|
||||
let duration = Duration::from_secs(2);
|
||||
let duration = Duration::from_secs(1);
|
||||
error!("{:?}\n\nRetry in {:?}...", err, duration);
|
||||
time::sleep(duration).await;
|
||||
}
|
||||
|
|
|
@ -11,16 +11,15 @@ pub use cli::Cli;
|
|||
pub use config::Config;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use tokio::sync::broadcast;
|
||||
use tracing::debug;
|
||||
|
||||
use client::run_client;
|
||||
use server::run_server;
|
||||
|
||||
pub async fn run(args: &Cli) -> Result<()> {
|
||||
pub async fn run(args: &Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
|
||||
let config = Config::from_file(&args.config_path).await?;
|
||||
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
debug!("{:?}", config);
|
||||
|
||||
// Raise `nofile` limit on linux and mac
|
||||
|
@ -28,8 +27,8 @@ pub async fn run(args: &Cli) -> Result<()> {
|
|||
|
||||
match determine_run_mode(&config, args) {
|
||||
RunMode::Undetermine => Err(anyhow!("Cannot determine running as a server or a client")),
|
||||
RunMode::Client => run_client(&config).await,
|
||||
RunMode::Server => run_server(&config).await,
|
||||
RunMode::Client => run_client(&config, shutdown_rx).await,
|
||||
RunMode::Server => run_server(&config, shutdown_rx).await,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
21
src/main.rs
21
src/main.rs
|
@ -1,9 +1,28 @@
|
|||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use rathole::{run, Cli};
|
||||
use tokio::{signal, sync::broadcast};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let args = Cli::parse();
|
||||
run(&args).await
|
||||
|
||||
let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = signal::ctrl_c().await {
|
||||
// Something really weird happened. So just panic
|
||||
panic!("Failed to listen for the ctrl-c signal: {:?}", e);
|
||||
}
|
||||
|
||||
if let Err(e) = shutdown_tx.send(true) {
|
||||
// shutdown signal must be catched and handle properly
|
||||
// `rx` must not be dropped
|
||||
panic!("Failed to send shutdown signal: {:?}", e);
|
||||
}
|
||||
});
|
||||
|
||||
// TODO: use level from config
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
run(&args, shutdown_rx).await
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ use std::sync::Arc;
|
|||
use std::time::Duration;
|
||||
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::sync::{mpsc, oneshot, RwLock};
|
||||
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
|
||||
use tokio::time;
|
||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
||||
|
||||
|
@ -26,7 +26,7 @@ const POOL_SIZE: usize = 64; // The number of cached connections
|
|||
const CHAN_SIZE: usize = 2048; // The capacity of various chans
|
||||
|
||||
// The entrypoint of running a server
|
||||
pub async fn run_server(config: &Config) -> Result<()> {
|
||||
pub async fn run_server(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
|
||||
let config = match &config.server {
|
||||
Some(config) => config,
|
||||
None => {
|
||||
|
@ -38,11 +38,11 @@ pub async fn run_server(config: &Config) -> Result<()> {
|
|||
match config.transport.transport_type {
|
||||
TransportType::Tcp => {
|
||||
let mut server = Server::<TcpTransport>::from(config).await?;
|
||||
server.run().await?;
|
||||
server.run(shutdown_rx).await?;
|
||||
}
|
||||
TransportType::Tls => {
|
||||
let mut server = Server::<TlsTransport>::from(config).await?;
|
||||
server.run().await?;
|
||||
server.run(shutdown_rx).await?;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -91,7 +91,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
|
|||
}
|
||||
|
||||
// The entry point of Server
|
||||
pub async fn run(&mut self) -> Result<()> {
|
||||
pub async fn run(&mut self, mut shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
|
||||
// Listen at `server.bind_addr`
|
||||
let l = self
|
||||
.transport
|
||||
|
@ -146,7 +146,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
|
|||
}
|
||||
},
|
||||
// Wait for the shutdown signal
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
_ = shutdown_rx.recv() => {
|
||||
info!("Shuting down gracefully...");
|
||||
break;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue