From d0d4f61efd7b6f7cd30a6f74d03f138c9bb55b4c Mon Sep 17 00:00:00 2001 From: Yujia Qiao Date: Fri, 7 Jan 2022 18:08:02 +0800 Subject: [PATCH] fix: cancel safety --- src/client.rs | 2 +- src/protocol.rs | 34 ++++++++++++++++++++-------------- src/server.rs | 9 ++++----- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/client.rs b/src/client.rs index 6ee34d5..dc6710c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -13,7 +13,7 @@ use bytes::{Bytes, BytesMut}; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; -use tokio::io::{self, copy_bidirectional, AsyncWriteExt}; +use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpStream, UdpSocket}; use tokio::sync::{broadcast, mpsc, oneshot, RwLock}; use tokio::time::{self, Duration}; diff --git a/src/protocol.rs b/src/protocol.rs index 883b654..9be620e 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -6,6 +6,7 @@ use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use std::net::SocketAddr; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tracing::trace; type ProtocolVersion = u8; const PROTO_V0: u8 = 0u8; @@ -70,12 +71,14 @@ pub struct UdpTraffic { impl UdpTraffic { pub async fn write(&self, writer: &mut T) -> Result<()> { - let v = bincode::serialize(&UdpHeader { + let hdr = UdpHeader { from: self.from, len: self.data.len() as UdpPacketLen, - }) - .unwrap(); + }; + let v = bincode::serialize(&hdr).unwrap(); + + trace!("Write {:?} of length {}", hdr, v.len()); writer.write_u16(v.len() as u16).await?; writer.write_all(&v).await?; @@ -90,12 +93,14 @@ impl UdpTraffic { from: SocketAddr, data: &[u8], ) -> Result<()> { - let v = bincode::serialize(&UdpHeader { + let hdr = UdpHeader { from, len: data.len() as UdpPacketLen, - }) - .unwrap(); + }; + let v = bincode::serialize(&hdr).unwrap(); + + trace!("Write {:?} of length {}", hdr, v.len()); writer.write_u16(v.len() as u16).await?; writer.write_all(&v).await?; @@ -104,24 +109,25 @@ impl UdpTraffic { Ok(()) } - pub async fn read(reader: &mut T) -> Result { - let len = reader.read_u16().await? as usize; - + pub async fn read(reader: &mut T, hdr_len: u16) -> Result { let mut buf = Vec::new(); - buf.resize(len, 0); + buf.resize(hdr_len as usize, 0); reader .read_exact(&mut buf) .await .with_context(|| "Failed to read udp header")?; - let header: UdpHeader = - bincode::deserialize(&buf).with_context(|| "Failed to deserialize udp header")?; + + let hdr: UdpHeader = + bincode::deserialize(&buf).with_context(|| "Failed to deserialize UdpHeader")?; + + trace!("hdr {:?}", hdr); let mut data = BytesMut::new(); - data.resize(header.len as usize, 0); + data.resize(hdr.len as usize, 0); reader.read_exact(&mut data).await?; Ok(UdpTraffic { - from: header.from, + from: hdr.from, data: data.freeze(), }) } diff --git a/src/server.rs b/src/server.rs index ef2c85d..0003c9b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -14,10 +14,9 @@ use backoff::ExponentialBackoff; use rand::RngCore; use std::collections::HashMap; -use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; -use tokio::io::{self, copy_bidirectional, AsyncWriteExt}; +use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream, UdpSocket}; use tokio::sync::{broadcast, mpsc, RwLock}; use tokio::time; @@ -618,10 +617,10 @@ async fn run_udp_connection_pool( }, // Forward outbound traffic from the client to the visitor - t = UdpTraffic::read(&mut conn) => { - let t = t?; + hdr_len = conn.read_u16() => { + let t = UdpTraffic::read(&mut conn, hdr_len?).await?; l.send_to(&t.data, t.from).await?; - }, + } _ = shutdown_rx.recv() => { break;