mirror of https://github.com/rapiz1/rathole.git
feat: application layer heartbeat (#136)
* feat: application layer heartbeat * feat: make heartbeat configurable * fix: update keepalive params * docs: update about heartbeat
This commit is contained in:
parent
1ef7747019
commit
2746a0ea88
10
README.md
10
README.md
|
@ -105,6 +105,7 @@ Here is the full configuration specification:
|
|||
[client]
|
||||
remote_addr = "example.com:2333" # Necessary. The address of the server
|
||||
default_token = "default_token_if_not_specify" # Optional. The default token of services, if they don't define their own ones
|
||||
heartbeat_timeout = 40 # Optional. Set to 0 to disable the application-layer heartbeat test. The value must be greater than `server.heartbeat_interval`. Default: 40 secs
|
||||
|
||||
[client.transport] # The whole block is optional. Specify which transport to use
|
||||
type = "tcp" # Optional. Possible values: ["tcp", "tls", "noise"]. Default: "tcp"
|
||||
|
@ -112,8 +113,8 @@ type = "tcp" # Optional. Possible values: ["tcp", "tls", "noise"]. Default: "tcp
|
|||
[client.transport.tcp] # Optional
|
||||
proxy = "socks5://user:passwd@127.0.0.1:1080" # Optional. Use the proxy to connect to the server
|
||||
nodelay = false # Optional. Determine whether to enable TCP_NODELAY, if applicable, to improve the latency but decrease the bandwidth. Default: false
|
||||
keepalive_secs = 10 # Optional. Specify `tcp_keepalive_time` in `tcp(7)`, if applicable. Default: 10 seconds
|
||||
keepalive_interval = 5 # Optional. Specify `tcp_keepalive_intvl` in `tcp(7)`, if applicable. Default: 5 seconds
|
||||
keepalive_secs = 20 # Optional. Specify `tcp_keepalive_time` in `tcp(7)`, if applicable. Default: 20 seconds
|
||||
keepalive_interval = 8 # Optional. Specify `tcp_keepalive_intvl` in `tcp(7)`, if applicable. Default: 8 seconds
|
||||
|
||||
[client.transport.tls] # Necessary if `type` is "tls"
|
||||
trusted_root = "ca.pem" # Necessary. The certificate of CA that signed the server's certificate
|
||||
|
@ -136,12 +137,13 @@ local_addr = "127.0.0.1:1082"
|
|||
[server]
|
||||
bind_addr = "0.0.0.0:2333" # Necessary. The address that the server listens for clients. Generally only the port needs to be change.
|
||||
default_token = "default_token_if_not_specify" # Optional
|
||||
heartbeat_interval = 30 # Optional. The interval between two application-layer heartbeat. Set to 0 to disable sending heartbeat. Default: 30 secs
|
||||
|
||||
[server.transport] # Same as `[client.transport]`
|
||||
type = "tcp"
|
||||
nodelay = false
|
||||
keepalive_secs = 10
|
||||
keepalive_interval = 5
|
||||
keepalive_secs = 20
|
||||
keepalive_interval = 8
|
||||
|
||||
[server.transport.tls] # Necessary if `type` is "tls"
|
||||
pkcs12 = "identify.pfx" # Necessary. pkcs12 file of server's certificate and private key
|
||||
|
|
|
@ -29,11 +29,11 @@ use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE
|
|||
|
||||
// The entrypoint of running a client
|
||||
pub async fn run_client(
|
||||
config: &Config,
|
||||
config: Config,
|
||||
shutdown_rx: broadcast::Receiver<bool>,
|
||||
service_rx: mpsc::Receiver<ServiceChange>,
|
||||
) -> Result<()> {
|
||||
let config = config.client.as_ref().ok_or(anyhow!(
|
||||
let config = config.client.ok_or(anyhow!(
|
||||
"Try to run as a client, but the configuration is missing. Please add the `[client]` block"
|
||||
))?;
|
||||
|
||||
|
@ -67,21 +67,21 @@ type ServiceDigest = protocol::Digest;
|
|||
type Nonce = protocol::Digest;
|
||||
|
||||
// Holds the state of a client
|
||||
struct Client<'a, T: Transport> {
|
||||
config: &'a ClientConfig,
|
||||
struct Client<T: Transport> {
|
||||
config: ClientConfig,
|
||||
service_handles: HashMap<String, ControlChannelHandle>,
|
||||
transport: Arc<T>,
|
||||
}
|
||||
|
||||
impl<'a, T: 'static + Transport> Client<'a, T> {
|
||||
impl<T: 'static + Transport> Client<T> {
|
||||
// Create a Client from `[client]` config block
|
||||
async fn from(config: &'a ClientConfig) -> Result<Client<'a, T>> {
|
||||
async fn from(config: ClientConfig) -> Result<Client<T>> {
|
||||
let transport =
|
||||
Arc::new(T::new(&config.transport).with_context(|| "Failed to create the transport")?);
|
||||
Ok(Client {
|
||||
config,
|
||||
service_handles: HashMap::new(),
|
||||
transport: Arc::new(
|
||||
T::new(&config.transport).with_context(|| "Failed to create the transport")?,
|
||||
),
|
||||
transport,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -97,6 +97,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
|
|||
(*config).clone(),
|
||||
self.config.remote_addr.clone(),
|
||||
self.transport.clone(),
|
||||
self.config.heartbeat_timeout,
|
||||
);
|
||||
self.service_handles.insert(name.clone(), handle);
|
||||
}
|
||||
|
@ -122,6 +123,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
|
|||
s,
|
||||
self.config.remote_addr.clone(),
|
||||
self.transport.clone(),
|
||||
self.config.heartbeat_timeout
|
||||
);
|
||||
let _ = self.service_handles.insert(name, handle);
|
||||
},
|
||||
|
@ -369,6 +371,7 @@ struct ControlChannel<T: Transport> {
|
|||
shutdown_rx: oneshot::Receiver<u8>, // Receives the shutdown signal
|
||||
remote_addr: String, // `client.remote_addr`
|
||||
transport: Arc<T>, // Wrapper around the transport layer
|
||||
heartbeat_timeout: u64, // Application layer heartbeat timeout in secs
|
||||
}
|
||||
|
||||
// Handle of a control channel
|
||||
|
@ -451,9 +454,14 @@ impl<T: 'static + Transport> ControlChannel<T> {
|
|||
warn!("{:#}", e);
|
||||
}
|
||||
}.instrument(Span::current()));
|
||||
}
|
||||
},
|
||||
ControlChannelCmd::HeartBeat => ()
|
||||
}
|
||||
},
|
||||
_ = time::sleep(Duration::from_secs(self.heartbeat_timeout)), if self.heartbeat_timeout != 0 => {
|
||||
warn!("Heartbeat timed out");
|
||||
break;
|
||||
}
|
||||
_ = &mut self.shutdown_rx => {
|
||||
break;
|
||||
}
|
||||
|
@ -471,6 +479,7 @@ impl ControlChannelHandle {
|
|||
service: ClientServiceConfig,
|
||||
remote_addr: String,
|
||||
transport: Arc<T>,
|
||||
heartbeat_timeout: u64,
|
||||
) -> ControlChannelHandle {
|
||||
let digest = protocol::digest(service.name.as_bytes());
|
||||
|
||||
|
@ -482,6 +491,7 @@ impl ControlChannelHandle {
|
|||
shutdown_rx,
|
||||
remote_addr,
|
||||
transport,
|
||||
heartbeat_timeout,
|
||||
};
|
||||
|
||||
tokio::spawn(
|
||||
|
|
|
@ -9,6 +9,10 @@ use url::Url;
|
|||
|
||||
use crate::transport::{DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_KEEPALIVE_SECS, DEFAULT_NODELAY};
|
||||
|
||||
/// Application-layer heartbeat interval in secs
|
||||
const DEFAULT_HEARTBEAT_INTERVAL_SECS: u64 = 30;
|
||||
const DEFAULT_HEARTBEAT_TIMEOUT_SECS: u64 = 40;
|
||||
|
||||
/// String with Debug implementation that emits "MASKED"
|
||||
/// Used to mask sensitive strings when logging
|
||||
#[derive(Serialize, Deserialize, Default, PartialEq, Clone)]
|
||||
|
@ -177,6 +181,10 @@ pub struct TransportConfig {
|
|||
pub noise: Option<NoiseConfig>,
|
||||
}
|
||||
|
||||
fn default_heartbeat_timeout() -> u64 {
|
||||
DEFAULT_HEARTBEAT_TIMEOUT_SECS
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct ClientConfig {
|
||||
|
@ -185,6 +193,12 @@ pub struct ClientConfig {
|
|||
pub services: HashMap<String, ClientServiceConfig>,
|
||||
#[serde(default)]
|
||||
pub transport: TransportConfig,
|
||||
#[serde(default = "default_heartbeat_timeout")]
|
||||
pub heartbeat_timeout: u64,
|
||||
}
|
||||
|
||||
fn default_heartbeat_interval() -> u64 {
|
||||
DEFAULT_HEARTBEAT_INTERVAL_SECS
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
|
||||
|
@ -195,6 +209,8 @@ pub struct ServerConfig {
|
|||
pub services: HashMap<String, ServerServiceConfig>,
|
||||
#[serde(default)]
|
||||
pub transport: TransportConfig,
|
||||
#[serde(default = "default_heartbeat_interval")]
|
||||
pub heartbeat_interval: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
|
||||
|
|
|
@ -93,7 +93,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
|
|||
|
||||
last_instance = Some((
|
||||
tokio::spawn(run_instance(
|
||||
*(config.clone()),
|
||||
*config,
|
||||
args.clone(),
|
||||
shutdown_tx.subscribe(),
|
||||
service_update_rx,
|
||||
|
@ -127,13 +127,13 @@ async fn run_instance(
|
|||
#[cfg(not(feature = "client"))]
|
||||
crate::helper::feature_not_compile("client");
|
||||
#[cfg(feature = "client")]
|
||||
run_client(&config, shutdown_rx, service_update).await
|
||||
run_client(config, shutdown_rx, service_update).await
|
||||
}
|
||||
RunMode::Server => {
|
||||
#[cfg(not(feature = "server"))]
|
||||
crate::helper::feature_not_compile("server");
|
||||
#[cfg(feature = "server")]
|
||||
run_server(&config, shutdown_rx, service_update).await
|
||||
run_server(config, shutdown_rx, service_update).await
|
||||
}
|
||||
};
|
||||
ret.unwrap();
|
||||
|
|
|
@ -9,9 +9,10 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
|||
use tracing::trace;
|
||||
|
||||
type ProtocolVersion = u8;
|
||||
const PROTO_V0: u8 = 0u8;
|
||||
const _PROTO_V0: u8 = 0u8;
|
||||
const PROTO_V1: u8 = 1u8;
|
||||
|
||||
pub const CURRENT_PROTO_VERSION: ProtocolVersion = PROTO_V0;
|
||||
pub const CURRENT_PROTO_VERSION: ProtocolVersion = PROTO_V1;
|
||||
|
||||
pub type Digest = [u8; HASH_WIDTH_IN_BYTES];
|
||||
|
||||
|
@ -48,6 +49,7 @@ impl std::fmt::Display for Ack {
|
|||
#[derive(Deserialize, Serialize, Debug)]
|
||||
pub enum ControlChannelCmd {
|
||||
CreateDataChannel,
|
||||
HeartBeat,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
|
|
|
@ -38,11 +38,11 @@ const HANDSHAKE_TIMEOUT: u64 = 5; // Timeout for transport handshake
|
|||
|
||||
// The entrypoint of running a server
|
||||
pub async fn run_server(
|
||||
config: &Config,
|
||||
config: Config,
|
||||
shutdown_rx: broadcast::Receiver<bool>,
|
||||
service_rx: mpsc::Receiver<ServiceChange>,
|
||||
) -> Result<()> {
|
||||
let config = match &config.server {
|
||||
let config = match config.server {
|
||||
Some(config) => config,
|
||||
None => {
|
||||
return Err(anyhow!("Try to run as a server, but the configuration is missing. Please add the `[server]` block"))
|
||||
|
@ -82,9 +82,9 @@ pub async fn run_server(
|
|||
type ControlChannelMap<T> = MultiMap<ServiceDigest, Nonce, ControlChannelHandle<T>>;
|
||||
|
||||
// Server holds all states of running a server
|
||||
struct Server<'a, T: Transport> {
|
||||
struct Server<T: Transport> {
|
||||
// `[server]` config
|
||||
config: &'a ServerConfig,
|
||||
config: Arc<ServerConfig>,
|
||||
|
||||
// `[server.services]` config, indexed by ServiceDigest
|
||||
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
|
||||
|
@ -105,14 +105,18 @@ fn generate_service_hashmap(
|
|||
ret
|
||||
}
|
||||
|
||||
impl<'a, T: 'static + Transport> Server<'a, T> {
|
||||
impl<T: 'static + Transport> Server<T> {
|
||||
// Create a server from `[server]`
|
||||
pub async fn from(config: &'a ServerConfig) -> Result<Server<'a, T>> {
|
||||
pub async fn from(config: ServerConfig) -> Result<Server<T>> {
|
||||
let config = Arc::new(config);
|
||||
let services = Arc::new(RwLock::new(generate_service_hashmap(&config)));
|
||||
let control_channels = Arc::new(RwLock::new(ControlChannelMap::new()));
|
||||
let transport = Arc::new(T::new(&config.transport)?);
|
||||
Ok(Server {
|
||||
config,
|
||||
services: Arc::new(RwLock::new(generate_service_hashmap(config))),
|
||||
control_channels: Arc::new(RwLock::new(ControlChannelMap::new())),
|
||||
transport: Arc::new(T::new(&config.transport)?),
|
||||
services,
|
||||
control_channels,
|
||||
transport,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -171,8 +175,9 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
|
|||
Ok(conn) => {
|
||||
let services = self.services.clone();
|
||||
let control_channels = self.control_channels.clone();
|
||||
let server_config = self.config.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = handle_connection(conn, services, control_channels).await {
|
||||
if let Err(err) = handle_connection(conn, services, control_channels, server_config).await {
|
||||
error!("{:#}", err);
|
||||
}
|
||||
}.instrument(info_span!("connection", %addr)));
|
||||
|
@ -233,12 +238,20 @@ async fn handle_connection<T: 'static + Transport>(
|
|||
mut conn: T::Stream,
|
||||
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
|
||||
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
|
||||
server_config: Arc<ServerConfig>,
|
||||
) -> Result<()> {
|
||||
// Read hello
|
||||
let hello = read_hello(&mut conn).await?;
|
||||
match hello {
|
||||
ControlChannelHello(_, service_digest) => {
|
||||
do_control_channel_handshake(conn, services, control_channels, service_digest).await?;
|
||||
do_control_channel_handshake(
|
||||
conn,
|
||||
services,
|
||||
control_channels,
|
||||
service_digest,
|
||||
server_config,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
DataChannelHello(_, nonce) => {
|
||||
do_data_channel_handshake(conn, control_channels, nonce).await?;
|
||||
|
@ -252,6 +265,7 @@ async fn do_control_channel_handshake<T: 'static + Transport>(
|
|||
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
|
||||
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
|
||||
service_digest: ServiceDigest,
|
||||
server_config: Arc<ServerConfig>,
|
||||
) -> Result<()> {
|
||||
info!("Try to handshake a control channel");
|
||||
|
||||
|
@ -321,7 +335,8 @@ async fn do_control_channel_handshake<T: 'static + Transport>(
|
|||
conn.flush().await?;
|
||||
|
||||
info!(service = %service_config.name, "Control channel established");
|
||||
let handle = ControlChannelHandle::new(conn, service_config);
|
||||
let handle =
|
||||
ControlChannelHandle::new(conn, service_config, server_config.heartbeat_interval);
|
||||
|
||||
// Insert the new handle
|
||||
let _ = h.insert(service_digest, session_key, handle);
|
||||
|
@ -371,7 +386,11 @@ where
|
|||
// Create a control channel handle, where the control channel handling task
|
||||
// and the connection pool task are created.
|
||||
#[instrument(name = "handle", skip_all, fields(service = %service.name))]
|
||||
fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
|
||||
fn new(
|
||||
conn: T::Stream,
|
||||
service: ServerServiceConfig,
|
||||
heartbeat_interval: u64,
|
||||
) -> ControlChannelHandle<T> {
|
||||
// Create a shutdown channel
|
||||
let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
|
||||
|
||||
|
@ -435,6 +454,7 @@ where
|
|||
conn,
|
||||
shutdown_rx,
|
||||
data_ch_req_rx,
|
||||
heartbeat_interval,
|
||||
};
|
||||
|
||||
// Run the control channel
|
||||
|
@ -460,13 +480,26 @@ struct ControlChannel<T: Transport> {
|
|||
conn: T::Stream, // The connection of control channel
|
||||
shutdown_rx: broadcast::Receiver<bool>, // Receives the shutdown signal
|
||||
data_ch_req_rx: mpsc::UnboundedReceiver<bool>, // Receives visitor connections
|
||||
heartbeat_interval: u64, // Application-layer heartbeat interval in secs
|
||||
}
|
||||
|
||||
impl<T: Transport> ControlChannel<T> {
|
||||
async fn write_and_flush(&mut self, data: &[u8]) -> Result<()> {
|
||||
self.conn
|
||||
.write_all(data)
|
||||
.await
|
||||
.with_context(|| "Failed to write control cmds")?;
|
||||
self.conn
|
||||
.flush()
|
||||
.await
|
||||
.with_context(|| "Failed to flush control cmds")?;
|
||||
Ok(())
|
||||
}
|
||||
// Run a control channel
|
||||
#[instrument(skip_all)]
|
||||
async fn run(mut self) -> Result<()> {
|
||||
let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
|
||||
let create_ch_cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
|
||||
let heartbeat = bincode::serialize(&ControlChannelCmd::HeartBeat).unwrap();
|
||||
|
||||
// Wait for data channel requests and the shutdown signal
|
||||
loop {
|
||||
|
@ -474,11 +507,7 @@ impl<T: Transport> ControlChannel<T> {
|
|||
val = self.data_ch_req_rx.recv() => {
|
||||
match val {
|
||||
Some(_) => {
|
||||
if let Err(e) = self.conn.write_all(&cmd).await.with_context(||"Failed to write control cmds") {
|
||||
error!("{:#}", e);
|
||||
break;
|
||||
}
|
||||
if let Err(e) = self.conn.flush().await.with_context(|| "Failed to flush control cmds") {
|
||||
if let Err(e) = self.write_and_flush(&create_ch_cmd).await {
|
||||
error!("{:#}", e);
|
||||
break;
|
||||
}
|
||||
|
@ -488,6 +517,12 @@ impl<T: Transport> ControlChannel<T> {
|
|||
}
|
||||
}
|
||||
},
|
||||
_ = time::sleep(Duration::from_secs(self.heartbeat_interval)), if self.heartbeat_interval != 0 => {
|
||||
if let Err(e) = self.write_and_flush(&heartbeat).await {
|
||||
error!("{:#}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Wait for the shutdown signal
|
||||
_ = self.shutdown_rx.recv() => {
|
||||
break;
|
||||
|
|
|
@ -9,10 +9,10 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
|||
use tokio::net::{TcpStream, ToSocketAddrs};
|
||||
use tracing::{error, trace};
|
||||
|
||||
pub static DEFAULT_NODELAY: bool = false;
|
||||
pub const DEFAULT_NODELAY: bool = false;
|
||||
|
||||
pub static DEFAULT_KEEPALIVE_SECS: u64 = 10;
|
||||
pub static DEFAULT_KEEPALIVE_INTERVAL: u64 = 3;
|
||||
pub const DEFAULT_KEEPALIVE_SECS: u64 = 20;
|
||||
pub const DEFAULT_KEEPALIVE_INTERVAL: u64 = 8;
|
||||
|
||||
/// Specify a transport layer, like TCP, TLS
|
||||
#[async_trait]
|
||||
|
|
Loading…
Reference in New Issue