fix: cancel safety

This commit is contained in:
Yujia Qiao 2022-01-07 18:08:02 +08:00 committed by Yujia Qiao
parent a071b0786b
commit d0d4f61efd
3 changed files with 25 additions and 20 deletions

View File

@ -13,7 +13,7 @@ use bytes::{Bytes, BytesMut};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UdpSocket};
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use tokio::time::{self, Duration};

View File

@ -6,6 +6,7 @@ use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::trace;
type ProtocolVersion = u8;
const PROTO_V0: u8 = 0u8;
@ -70,12 +71,14 @@ pub struct UdpTraffic {
impl UdpTraffic {
pub async fn write<T: AsyncWrite + Unpin>(&self, writer: &mut T) -> Result<()> {
let v = bincode::serialize(&UdpHeader {
let hdr = UdpHeader {
from: self.from,
len: self.data.len() as UdpPacketLen,
})
.unwrap();
};
let v = bincode::serialize(&hdr).unwrap();
trace!("Write {:?} of length {}", hdr, v.len());
writer.write_u16(v.len() as u16).await?;
writer.write_all(&v).await?;
@ -90,12 +93,14 @@ impl UdpTraffic {
from: SocketAddr,
data: &[u8],
) -> Result<()> {
let v = bincode::serialize(&UdpHeader {
let hdr = UdpHeader {
from,
len: data.len() as UdpPacketLen,
})
.unwrap();
};
let v = bincode::serialize(&hdr).unwrap();
trace!("Write {:?} of length {}", hdr, v.len());
writer.write_u16(v.len() as u16).await?;
writer.write_all(&v).await?;
@ -104,24 +109,25 @@ impl UdpTraffic {
Ok(())
}
pub async fn read<T: AsyncRead + Unpin>(reader: &mut T) -> Result<UdpTraffic> {
let len = reader.read_u16().await? as usize;
pub async fn read<T: AsyncRead + Unpin>(reader: &mut T, hdr_len: u16) -> Result<UdpTraffic> {
let mut buf = Vec::new();
buf.resize(len, 0);
buf.resize(hdr_len as usize, 0);
reader
.read_exact(&mut buf)
.await
.with_context(|| "Failed to read udp header")?;
let header: UdpHeader =
bincode::deserialize(&buf).with_context(|| "Failed to deserialize udp header")?;
let hdr: UdpHeader =
bincode::deserialize(&buf).with_context(|| "Failed to deserialize UdpHeader")?;
trace!("hdr {:?}", hdr);
let mut data = BytesMut::new();
data.resize(header.len as usize, 0);
data.resize(hdr.len as usize, 0);
reader.read_exact(&mut data).await?;
Ok(UdpTraffic {
from: header.from,
from: hdr.from,
data: data.freeze(),
})
}

View File

@ -14,10 +14,9 @@ use backoff::ExponentialBackoff;
use rand::RngCore;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::time;
@ -618,10 +617,10 @@ async fn run_udp_connection_pool<T: Transport>(
},
// Forward outbound traffic from the client to the visitor
t = UdpTraffic::read(&mut conn) => {
let t = t?;
hdr_len = conn.read_u16() => {
let t = UdpTraffic::read(&mut conn, hdr_len?).await?;
l.send_to(&t.data, t.from).await?;
},
}
_ = shutdown_rx.recv() => {
break;