mirror of https://github.com/rapiz1/rathole.git
fix: throw errors when the service type or protocol version doesn't match (#112)
* fix: print errors when service types don't match * fix: validate the protocol version when handshake
This commit is contained in:
parent
a66502d33b
commit
cdbf8781e4
|
@ -1,4 +1,4 @@
|
|||
use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
|
||||
use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType};
|
||||
use crate::config_watcher::ServiceChange;
|
||||
use crate::helper::udp_connect;
|
||||
use crate::protocol::Hello::{self, *};
|
||||
|
@ -150,9 +150,9 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
|
|||
struct RunDataChannelArgs<T: Transport> {
|
||||
session_key: Nonce,
|
||||
remote_addr: String,
|
||||
local_addr: String,
|
||||
connector: Arc<T>,
|
||||
socket_opts: SocketOpts,
|
||||
service: ClientServiceConfig,
|
||||
}
|
||||
|
||||
async fn do_data_channel_handshake<T: Transport>(
|
||||
|
@ -201,10 +201,16 @@ async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Res
|
|||
// Forward
|
||||
match read_data_cmd(&mut conn).await? {
|
||||
DataChannelCmd::StartForwardTcp => {
|
||||
run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?;
|
||||
if args.service.service_type != ServiceType::Tcp {
|
||||
bail!("Expect TCP traffic. Please check the configuration.")
|
||||
}
|
||||
run_data_channel_for_tcp::<T>(conn, &args.service.local_addr).await?;
|
||||
}
|
||||
DataChannelCmd::StartForwardUdp => {
|
||||
run_data_channel_for_udp::<T>(conn, &args.local_addr).await?;
|
||||
if args.service.service_type != ServiceType::Udp {
|
||||
bail!("Expect UDP traffic. Please check the configuration.")
|
||||
}
|
||||
run_data_channel_for_udp::<T>(conn, &args.service.local_addr).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
@ -427,15 +433,14 @@ impl<T: 'static + Transport> ControlChannel<T> {
|
|||
info!("Control channel established");
|
||||
|
||||
let remote_addr = self.remote_addr.clone();
|
||||
let local_addr = self.service.local_addr.clone();
|
||||
// Socket options for the data channel
|
||||
let socket_opts = SocketOpts::from_client_cfg(&self.service);
|
||||
let data_ch_args = Arc::new(RunDataChannelArgs {
|
||||
session_key,
|
||||
remote_addr,
|
||||
local_addr,
|
||||
connector: self.transport.clone(),
|
||||
socket_opts,
|
||||
service: self.service.clone(),
|
||||
});
|
||||
|
||||
loop {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
pub const HASH_WIDTH_IN_BYTES: usize = 32;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use lazy_static::lazy_static;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -180,6 +180,24 @@ pub async fn read_hello<T: AsyncRead + AsyncWrite + Unpin>(conn: &mut T) -> Resu
|
|||
.await
|
||||
.with_context(|| "Failed to read hello")?;
|
||||
let hello = bincode::deserialize(&buf).with_context(|| "Failed to deserialize hello")?;
|
||||
|
||||
match hello {
|
||||
Hello::ControlChannelHello(v, _) => {
|
||||
if v != CURRENT_PROTO_VERSION {
|
||||
bail!(
|
||||
"Protocol version mismatched. Expected {}, got {}. Please update `rathole`.",
|
||||
CURRENT_PROTO_VERSION,
|
||||
v
|
||||
);
|
||||
}
|
||||
}
|
||||
Hello::DataChannelHello(v, _) => {
|
||||
// This assert should not fail because the version has already been
|
||||
// checked by ControlChannelHello.
|
||||
assert_eq!(v, CURRENT_PROTO_VERSION);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hello)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue