fix: throw errors when the service type or protocol version doesn't match (#112)

* fix: print errors when service types don't match

* fix: validate the protocol version when handshake
This commit is contained in:
Yujia Qiao 2022-01-21 14:35:32 +08:00 committed by GitHub
parent a66502d33b
commit cdbf8781e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 7 deletions

View File

@ -1,4 +1,4 @@
use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType}; use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType};
use crate::config_watcher::ServiceChange; use crate::config_watcher::ServiceChange;
use crate::helper::udp_connect; use crate::helper::udp_connect;
use crate::protocol::Hello::{self, *}; use crate::protocol::Hello::{self, *};
@ -150,9 +150,9 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
struct RunDataChannelArgs<T: Transport> { struct RunDataChannelArgs<T: Transport> {
session_key: Nonce, session_key: Nonce,
remote_addr: String, remote_addr: String,
local_addr: String,
connector: Arc<T>, connector: Arc<T>,
socket_opts: SocketOpts, socket_opts: SocketOpts,
service: ClientServiceConfig,
} }
async fn do_data_channel_handshake<T: Transport>( async fn do_data_channel_handshake<T: Transport>(
@ -201,10 +201,16 @@ async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Res
// Forward // Forward
match read_data_cmd(&mut conn).await? { match read_data_cmd(&mut conn).await? {
DataChannelCmd::StartForwardTcp => { DataChannelCmd::StartForwardTcp => {
run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?; if args.service.service_type != ServiceType::Tcp {
bail!("Expect TCP traffic. Please check the configuration.")
}
run_data_channel_for_tcp::<T>(conn, &args.service.local_addr).await?;
} }
DataChannelCmd::StartForwardUdp => { DataChannelCmd::StartForwardUdp => {
run_data_channel_for_udp::<T>(conn, &args.local_addr).await?; if args.service.service_type != ServiceType::Udp {
bail!("Expect UDP traffic. Please check the configuration.")
}
run_data_channel_for_udp::<T>(conn, &args.service.local_addr).await?;
} }
} }
Ok(()) Ok(())
@ -427,15 +433,14 @@ impl<T: 'static + Transport> ControlChannel<T> {
info!("Control channel established"); info!("Control channel established");
let remote_addr = self.remote_addr.clone(); let remote_addr = self.remote_addr.clone();
let local_addr = self.service.local_addr.clone();
// Socket options for the data channel // Socket options for the data channel
let socket_opts = SocketOpts::from_client_cfg(&self.service); let socket_opts = SocketOpts::from_client_cfg(&self.service);
let data_ch_args = Arc::new(RunDataChannelArgs { let data_ch_args = Arc::new(RunDataChannelArgs {
session_key, session_key,
remote_addr, remote_addr,
local_addr,
connector: self.transport.clone(), connector: self.transport.clone(),
socket_opts, socket_opts,
service: self.service.clone(),
}); });
loop { loop {

View File

@ -1,6 +1,6 @@
pub const HASH_WIDTH_IN_BYTES: usize = 32; pub const HASH_WIDTH_IN_BYTES: usize = 32;
use anyhow::{Context, Result}; use anyhow::{bail, Context, Result};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -180,6 +180,24 @@ pub async fn read_hello<T: AsyncRead + AsyncWrite + Unpin>(conn: &mut T) -> Resu
.await .await
.with_context(|| "Failed to read hello")?; .with_context(|| "Failed to read hello")?;
let hello = bincode::deserialize(&buf).with_context(|| "Failed to deserialize hello")?; let hello = bincode::deserialize(&buf).with_context(|| "Failed to deserialize hello")?;
match hello {
Hello::ControlChannelHello(v, _) => {
if v != CURRENT_PROTO_VERSION {
bail!(
"Protocol version mismatched. Expected {}, got {}. Please update `rathole`.",
CURRENT_PROTO_VERSION,
v
);
}
}
Hello::DataChannelHello(v, _) => {
// This assert should not fail because the version has already been
// checked by ControlChannelHello.
assert_eq!(v, CURRENT_PROTO_VERSION);
}
}
Ok(hello) Ok(hello)
} }