mirror of https://github.com/rapiz1/rathole.git
fix: throw errors when the service type or protocol version doesn't match (#112)
* fix: print errors when service types don't match * fix: validate the protocol version when handshake
This commit is contained in:
parent
a66502d33b
commit
cdbf8781e4
|
@ -1,4 +1,4 @@
|
||||||
use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
|
use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType};
|
||||||
use crate::config_watcher::ServiceChange;
|
use crate::config_watcher::ServiceChange;
|
||||||
use crate::helper::udp_connect;
|
use crate::helper::udp_connect;
|
||||||
use crate::protocol::Hello::{self, *};
|
use crate::protocol::Hello::{self, *};
|
||||||
|
@ -150,9 +150,9 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
|
||||||
struct RunDataChannelArgs<T: Transport> {
|
struct RunDataChannelArgs<T: Transport> {
|
||||||
session_key: Nonce,
|
session_key: Nonce,
|
||||||
remote_addr: String,
|
remote_addr: String,
|
||||||
local_addr: String,
|
|
||||||
connector: Arc<T>,
|
connector: Arc<T>,
|
||||||
socket_opts: SocketOpts,
|
socket_opts: SocketOpts,
|
||||||
|
service: ClientServiceConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn do_data_channel_handshake<T: Transport>(
|
async fn do_data_channel_handshake<T: Transport>(
|
||||||
|
@ -201,10 +201,16 @@ async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Res
|
||||||
// Forward
|
// Forward
|
||||||
match read_data_cmd(&mut conn).await? {
|
match read_data_cmd(&mut conn).await? {
|
||||||
DataChannelCmd::StartForwardTcp => {
|
DataChannelCmd::StartForwardTcp => {
|
||||||
run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?;
|
if args.service.service_type != ServiceType::Tcp {
|
||||||
|
bail!("Expect TCP traffic. Please check the configuration.")
|
||||||
|
}
|
||||||
|
run_data_channel_for_tcp::<T>(conn, &args.service.local_addr).await?;
|
||||||
}
|
}
|
||||||
DataChannelCmd::StartForwardUdp => {
|
DataChannelCmd::StartForwardUdp => {
|
||||||
run_data_channel_for_udp::<T>(conn, &args.local_addr).await?;
|
if args.service.service_type != ServiceType::Udp {
|
||||||
|
bail!("Expect UDP traffic. Please check the configuration.")
|
||||||
|
}
|
||||||
|
run_data_channel_for_udp::<T>(conn, &args.service.local_addr).await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -427,15 +433,14 @@ impl<T: 'static + Transport> ControlChannel<T> {
|
||||||
info!("Control channel established");
|
info!("Control channel established");
|
||||||
|
|
||||||
let remote_addr = self.remote_addr.clone();
|
let remote_addr = self.remote_addr.clone();
|
||||||
let local_addr = self.service.local_addr.clone();
|
|
||||||
// Socket options for the data channel
|
// Socket options for the data channel
|
||||||
let socket_opts = SocketOpts::from_client_cfg(&self.service);
|
let socket_opts = SocketOpts::from_client_cfg(&self.service);
|
||||||
let data_ch_args = Arc::new(RunDataChannelArgs {
|
let data_ch_args = Arc::new(RunDataChannelArgs {
|
||||||
session_key,
|
session_key,
|
||||||
remote_addr,
|
remote_addr,
|
||||||
local_addr,
|
|
||||||
connector: self.transport.clone(),
|
connector: self.transport.clone(),
|
||||||
socket_opts,
|
socket_opts,
|
||||||
|
service: self.service.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
pub const HASH_WIDTH_IN_BYTES: usize = 32;
|
pub const HASH_WIDTH_IN_BYTES: usize = 32;
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{bail, Context, Result};
|
||||||
use bytes::{Bytes, BytesMut};
|
use bytes::{Bytes, BytesMut};
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -180,6 +180,24 @@ pub async fn read_hello<T: AsyncRead + AsyncWrite + Unpin>(conn: &mut T) -> Resu
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to read hello")?;
|
.with_context(|| "Failed to read hello")?;
|
||||||
let hello = bincode::deserialize(&buf).with_context(|| "Failed to deserialize hello")?;
|
let hello = bincode::deserialize(&buf).with_context(|| "Failed to deserialize hello")?;
|
||||||
|
|
||||||
|
match hello {
|
||||||
|
Hello::ControlChannelHello(v, _) => {
|
||||||
|
if v != CURRENT_PROTO_VERSION {
|
||||||
|
bail!(
|
||||||
|
"Protocol version mismatched. Expected {}, got {}. Please update `rathole`.",
|
||||||
|
CURRENT_PROTO_VERSION,
|
||||||
|
v
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Hello::DataChannelHello(v, _) => {
|
||||||
|
// This assert should not fail because the version has already been
|
||||||
|
// checked by ControlChannelHello.
|
||||||
|
assert_eq!(v, CURRENT_PROTO_VERSION);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Ok(hello)
|
Ok(hello)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue