mirror of https://github.com/rapiz1/rathole.git
refactor: facilitate tests
This commit is contained in:
parent
776bce35cb
commit
b8e824849a
|
@ -13,12 +13,12 @@ use backoff::ExponentialBackoff;
|
||||||
|
|
||||||
use tokio::io::{copy_bidirectional, AsyncWriteExt};
|
use tokio::io::{copy_bidirectional, AsyncWriteExt};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::{broadcast, oneshot};
|
||||||
use tokio::time::{self, Duration};
|
use tokio::time::{self, Duration};
|
||||||
use tracing::{debug, error, info, instrument, Instrument, Span};
|
use tracing::{debug, error, info, instrument, Instrument, Span};
|
||||||
|
|
||||||
// The entrypoint of running a client
|
// The entrypoint of running a client
|
||||||
pub async fn run_client(config: &Config) -> Result<()> {
|
pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
|
||||||
let config = match &config.client {
|
let config = match &config.client {
|
||||||
Some(v) => v,
|
Some(v) => v,
|
||||||
None => {
|
None => {
|
||||||
|
@ -29,11 +29,11 @@ pub async fn run_client(config: &Config) -> Result<()> {
|
||||||
match config.transport.transport_type {
|
match config.transport.transport_type {
|
||||||
TransportType::Tcp => {
|
TransportType::Tcp => {
|
||||||
let mut client = Client::<TcpTransport>::from(config).await?;
|
let mut client = Client::<TcpTransport>::from(config).await?;
|
||||||
client.run().await
|
client.run(shutdown_rx).await
|
||||||
}
|
}
|
||||||
TransportType::Tls => {
|
TransportType::Tls => {
|
||||||
let mut client = Client::<TlsTransport>::from(config).await?;
|
let mut client = Client::<TlsTransport>::from(config).await?;
|
||||||
client.run().await
|
client.run(shutdown_rx).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -54,12 +54,16 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
|
||||||
Ok(Client {
|
Ok(Client {
|
||||||
config,
|
config,
|
||||||
service_handles: HashMap::new(),
|
service_handles: HashMap::new(),
|
||||||
transport: Arc::new(*T::new(&config.transport).await?),
|
transport: Arc::new(
|
||||||
|
*T::new(&config.transport)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to create the transport")?,
|
||||||
|
),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// The entrypoint of Client
|
// The entrypoint of Client
|
||||||
async fn run(&mut self) -> Result<()> {
|
async fn run(&mut self, mut shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
|
||||||
for (name, config) in &self.config.services {
|
for (name, config) in &self.config.services {
|
||||||
// Create a control channel for each service defined
|
// Create a control channel for each service defined
|
||||||
let handle = ControlChannelHandle::new(
|
let handle = ControlChannelHandle::new(
|
||||||
|
@ -74,9 +78,9 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
|
||||||
// Wait for the shutdown signal
|
// Wait for the shutdown signal
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
val = tokio::signal::ctrl_c() => {
|
val = shutdown_rx.recv() => {
|
||||||
match val {
|
match val {
|
||||||
Ok(()) => {}
|
Ok(_) => {}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
error!("Unable to listen for shutdown signal: {}", err);
|
error!("Unable to listen for shutdown signal: {}", err);
|
||||||
}
|
}
|
||||||
|
@ -258,7 +262,7 @@ impl ControlChannelHandle {
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to run the control channel")
|
.with_context(|| "Failed to run the control channel")
|
||||||
{
|
{
|
||||||
let duration = Duration::from_secs(2);
|
let duration = Duration::from_secs(1);
|
||||||
error!("{:?}\n\nRetry in {:?}...", err, duration);
|
error!("{:?}\n\nRetry in {:?}...", err, duration);
|
||||||
time::sleep(duration).await;
|
time::sleep(duration).await;
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,16 +11,15 @@ pub use cli::Cli;
|
||||||
pub use config::Config;
|
pub use config::Config;
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
|
use tokio::sync::broadcast;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
|
|
||||||
use client::run_client;
|
use client::run_client;
|
||||||
use server::run_server;
|
use server::run_server;
|
||||||
|
|
||||||
pub async fn run(args: &Cli) -> Result<()> {
|
pub async fn run(args: &Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
|
||||||
let config = Config::from_file(&args.config_path).await?;
|
let config = Config::from_file(&args.config_path).await?;
|
||||||
|
|
||||||
tracing_subscriber::fmt::init();
|
|
||||||
|
|
||||||
debug!("{:?}", config);
|
debug!("{:?}", config);
|
||||||
|
|
||||||
// Raise `nofile` limit on linux and mac
|
// Raise `nofile` limit on linux and mac
|
||||||
|
@ -28,8 +27,8 @@ pub async fn run(args: &Cli) -> Result<()> {
|
||||||
|
|
||||||
match determine_run_mode(&config, args) {
|
match determine_run_mode(&config, args) {
|
||||||
RunMode::Undetermine => Err(anyhow!("Cannot determine running as a server or a client")),
|
RunMode::Undetermine => Err(anyhow!("Cannot determine running as a server or a client")),
|
||||||
RunMode::Client => run_client(&config).await,
|
RunMode::Client => run_client(&config, shutdown_rx).await,
|
||||||
RunMode::Server => run_server(&config).await,
|
RunMode::Server => run_server(&config, shutdown_rx).await,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
21
src/main.rs
21
src/main.rs
|
@ -1,9 +1,28 @@
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use rathole::{run, Cli};
|
use rathole::{run, Cli};
|
||||||
|
use tokio::{signal, sync::broadcast};
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<()> {
|
async fn main() -> Result<()> {
|
||||||
let args = Cli::parse();
|
let args = Cli::parse();
|
||||||
run(&args).await
|
|
||||||
|
let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = signal::ctrl_c().await {
|
||||||
|
// Something really weird happened. So just panic
|
||||||
|
panic!("Failed to listen for the ctrl-c signal: {:?}", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = shutdown_tx.send(true) {
|
||||||
|
// shutdown signal must be catched and handle properly
|
||||||
|
// `rx` must not be dropped
|
||||||
|
panic!("Failed to send shutdown signal: {:?}", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// TODO: use level from config
|
||||||
|
tracing_subscriber::fmt::init();
|
||||||
|
|
||||||
|
run(&args, shutdown_rx).await
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
|
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio::sync::{mpsc, oneshot, RwLock};
|
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
|
||||||
use tokio::time;
|
use tokio::time;
|
||||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
use tracing::{debug, error, info, info_span, warn, Instrument};
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ const POOL_SIZE: usize = 64; // The number of cached connections
|
||||||
const CHAN_SIZE: usize = 2048; // The capacity of various chans
|
const CHAN_SIZE: usize = 2048; // The capacity of various chans
|
||||||
|
|
||||||
// The entrypoint of running a server
|
// The entrypoint of running a server
|
||||||
pub async fn run_server(config: &Config) -> Result<()> {
|
pub async fn run_server(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
|
||||||
let config = match &config.server {
|
let config = match &config.server {
|
||||||
Some(config) => config,
|
Some(config) => config,
|
||||||
None => {
|
None => {
|
||||||
|
@ -38,11 +38,11 @@ pub async fn run_server(config: &Config) -> Result<()> {
|
||||||
match config.transport.transport_type {
|
match config.transport.transport_type {
|
||||||
TransportType::Tcp => {
|
TransportType::Tcp => {
|
||||||
let mut server = Server::<TcpTransport>::from(config).await?;
|
let mut server = Server::<TcpTransport>::from(config).await?;
|
||||||
server.run().await?;
|
server.run(shutdown_rx).await?;
|
||||||
}
|
}
|
||||||
TransportType::Tls => {
|
TransportType::Tls => {
|
||||||
let mut server = Server::<TlsTransport>::from(config).await?;
|
let mut server = Server::<TlsTransport>::from(config).await?;
|
||||||
server.run().await?;
|
server.run(shutdown_rx).await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// The entry point of Server
|
// The entry point of Server
|
||||||
pub async fn run(&mut self) -> Result<()> {
|
pub async fn run(&mut self, mut shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
|
||||||
// Listen at `server.bind_addr`
|
// Listen at `server.bind_addr`
|
||||||
let l = self
|
let l = self
|
||||||
.transport
|
.transport
|
||||||
|
@ -146,7 +146,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// Wait for the shutdown signal
|
// Wait for the shutdown signal
|
||||||
_ = tokio::signal::ctrl_c() => {
|
_ = shutdown_rx.recv() => {
|
||||||
info!("Shuting down gracefully...");
|
info!("Shuting down gracefully...");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue