From cdbf8781e41702805d55c07ec941c8345afcd4f4 Mon Sep 17 00:00:00 2001 From: Yujia Qiao Date: Fri, 21 Jan 2022 14:35:32 +0800 Subject: [PATCH] fix: throw errors when the service type or protocol version doesn't match (#112) * fix: print errors when service types don't match * fix: validate the protocol version when handshake --- src/client.rs | 17 +++++++++++------ src/protocol.rs | 20 +++++++++++++++++++- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/src/client.rs b/src/client.rs index 0c19519..3501b6f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,4 @@ -use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType}; +use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType}; use crate::config_watcher::ServiceChange; use crate::helper::udp_connect; use crate::protocol::Hello::{self, *}; @@ -150,9 +150,9 @@ impl<'a, T: 'static + Transport> Client<'a, T> { struct RunDataChannelArgs { session_key: Nonce, remote_addr: String, - local_addr: String, connector: Arc, socket_opts: SocketOpts, + service: ClientServiceConfig, } async fn do_data_channel_handshake( @@ -201,10 +201,16 @@ async fn run_data_channel(args: Arc>) -> Res // Forward match read_data_cmd(&mut conn).await? { DataChannelCmd::StartForwardTcp => { - run_data_channel_for_tcp::(conn, &args.local_addr).await?; + if args.service.service_type != ServiceType::Tcp { + bail!("Expect TCP traffic. Please check the configuration.") + } + run_data_channel_for_tcp::(conn, &args.service.local_addr).await?; } DataChannelCmd::StartForwardUdp => { - run_data_channel_for_udp::(conn, &args.local_addr).await?; + if args.service.service_type != ServiceType::Udp { + bail!("Expect UDP traffic. Please check the configuration.") + } + run_data_channel_for_udp::(conn, &args.service.local_addr).await?; } } Ok(()) @@ -427,15 +433,14 @@ impl ControlChannel { info!("Control channel established"); let remote_addr = self.remote_addr.clone(); - let local_addr = self.service.local_addr.clone(); // Socket options for the data channel let socket_opts = SocketOpts::from_client_cfg(&self.service); let data_ch_args = Arc::new(RunDataChannelArgs { session_key, remote_addr, - local_addr, connector: self.transport.clone(), socket_opts, + service: self.service.clone(), }); loop { diff --git a/src/protocol.rs b/src/protocol.rs index 3288621..d766c89 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -1,6 +1,6 @@ pub const HASH_WIDTH_IN_BYTES: usize = 32; -use anyhow::{Context, Result}; +use anyhow::{bail, Context, Result}; use bytes::{Bytes, BytesMut}; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; @@ -180,6 +180,24 @@ pub async fn read_hello(conn: &mut T) -> Resu .await .with_context(|| "Failed to read hello")?; let hello = bincode::deserialize(&buf).with_context(|| "Failed to deserialize hello")?; + + match hello { + Hello::ControlChannelHello(v, _) => { + if v != CURRENT_PROTO_VERSION { + bail!( + "Protocol version mismatched. Expected {}, got {}. Please update `rathole`.", + CURRENT_PROTO_VERSION, + v + ); + } + } + Hello::DataChannelHello(v, _) => { + // This assert should not fail because the version has already been + // checked by ControlChannelHello. + assert_eq!(v, CURRENT_PROTO_VERSION); + } + } + Ok(hello) }