fix: throw errors when the service type or protocol version doesn't match (#112)

* fix: print errors when service types don't match

* fix: validate the protocol version when handshake
This commit is contained in:
Yujia Qiao 2022-01-21 14:35:32 +08:00 committed by GitHub
parent a66502d33b
commit cdbf8781e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 7 deletions

View File

@ -1,4 +1,4 @@
use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType};
use crate::config_watcher::ServiceChange;
use crate::helper::udp_connect;
use crate::protocol::Hello::{self, *};
@ -150,9 +150,9 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
struct RunDataChannelArgs<T: Transport> {
session_key: Nonce,
remote_addr: String,
local_addr: String,
connector: Arc<T>,
socket_opts: SocketOpts,
service: ClientServiceConfig,
}
async fn do_data_channel_handshake<T: Transport>(
@ -201,10 +201,16 @@ async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Res
// Forward
match read_data_cmd(&mut conn).await? {
DataChannelCmd::StartForwardTcp => {
run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?;
if args.service.service_type != ServiceType::Tcp {
bail!("Expect TCP traffic. Please check the configuration.")
}
run_data_channel_for_tcp::<T>(conn, &args.service.local_addr).await?;
}
DataChannelCmd::StartForwardUdp => {
run_data_channel_for_udp::<T>(conn, &args.local_addr).await?;
if args.service.service_type != ServiceType::Udp {
bail!("Expect UDP traffic. Please check the configuration.")
}
run_data_channel_for_udp::<T>(conn, &args.service.local_addr).await?;
}
}
Ok(())
@ -427,15 +433,14 @@ impl<T: 'static + Transport> ControlChannel<T> {
info!("Control channel established");
let remote_addr = self.remote_addr.clone();
let local_addr = self.service.local_addr.clone();
// Socket options for the data channel
let socket_opts = SocketOpts::from_client_cfg(&self.service);
let data_ch_args = Arc::new(RunDataChannelArgs {
session_key,
remote_addr,
local_addr,
connector: self.transport.clone(),
socket_opts,
service: self.service.clone(),
});
loop {

View File

@ -1,6 +1,6 @@
pub const HASH_WIDTH_IN_BYTES: usize = 32;
use anyhow::{Context, Result};
use anyhow::{bail, Context, Result};
use bytes::{Bytes, BytesMut};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
@ -180,6 +180,24 @@ pub async fn read_hello<T: AsyncRead + AsyncWrite + Unpin>(conn: &mut T) -> Resu
.await
.with_context(|| "Failed to read hello")?;
let hello = bincode::deserialize(&buf).with_context(|| "Failed to deserialize hello")?;
match hello {
Hello::ControlChannelHello(v, _) => {
if v != CURRENT_PROTO_VERSION {
bail!(
"Protocol version mismatched. Expected {}, got {}. Please update `rathole`.",
CURRENT_PROTO_VERSION,
v
);
}
}
Hello::DataChannelHello(v, _) => {
// This assert should not fail because the version has already been
// checked by ControlChannelHello.
assert_eq!(v, CURRENT_PROTO_VERSION);
}
}
Ok(hello)
}