mirror of https://github.com/rapiz1/rathole.git
fix: cancel safety
This commit is contained in:
parent
a071b0786b
commit
d0d4f61efd
|
@ -13,7 +13,7 @@ use bytes::{Bytes, BytesMut};
|
|||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
|
||||
use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpStream, UdpSocket};
|
||||
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
|
||||
use tokio::time::{self, Duration};
|
||||
|
|
|
@ -6,6 +6,7 @@ use lazy_static::lazy_static;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::net::SocketAddr;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tracing::trace;
|
||||
|
||||
type ProtocolVersion = u8;
|
||||
const PROTO_V0: u8 = 0u8;
|
||||
|
@ -70,12 +71,14 @@ pub struct UdpTraffic {
|
|||
|
||||
impl UdpTraffic {
|
||||
pub async fn write<T: AsyncWrite + Unpin>(&self, writer: &mut T) -> Result<()> {
|
||||
let v = bincode::serialize(&UdpHeader {
|
||||
let hdr = UdpHeader {
|
||||
from: self.from,
|
||||
len: self.data.len() as UdpPacketLen,
|
||||
})
|
||||
.unwrap();
|
||||
};
|
||||
|
||||
let v = bincode::serialize(&hdr).unwrap();
|
||||
|
||||
trace!("Write {:?} of length {}", hdr, v.len());
|
||||
writer.write_u16(v.len() as u16).await?;
|
||||
writer.write_all(&v).await?;
|
||||
|
||||
|
@ -90,12 +93,14 @@ impl UdpTraffic {
|
|||
from: SocketAddr,
|
||||
data: &[u8],
|
||||
) -> Result<()> {
|
||||
let v = bincode::serialize(&UdpHeader {
|
||||
let hdr = UdpHeader {
|
||||
from,
|
||||
len: data.len() as UdpPacketLen,
|
||||
})
|
||||
.unwrap();
|
||||
};
|
||||
|
||||
let v = bincode::serialize(&hdr).unwrap();
|
||||
|
||||
trace!("Write {:?} of length {}", hdr, v.len());
|
||||
writer.write_u16(v.len() as u16).await?;
|
||||
writer.write_all(&v).await?;
|
||||
|
||||
|
@ -104,24 +109,25 @@ impl UdpTraffic {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn read<T: AsyncRead + Unpin>(reader: &mut T) -> Result<UdpTraffic> {
|
||||
let len = reader.read_u16().await? as usize;
|
||||
|
||||
pub async fn read<T: AsyncRead + Unpin>(reader: &mut T, hdr_len: u16) -> Result<UdpTraffic> {
|
||||
let mut buf = Vec::new();
|
||||
buf.resize(len, 0);
|
||||
buf.resize(hdr_len as usize, 0);
|
||||
reader
|
||||
.read_exact(&mut buf)
|
||||
.await
|
||||
.with_context(|| "Failed to read udp header")?;
|
||||
let header: UdpHeader =
|
||||
bincode::deserialize(&buf).with_context(|| "Failed to deserialize udp header")?;
|
||||
|
||||
let hdr: UdpHeader =
|
||||
bincode::deserialize(&buf).with_context(|| "Failed to deserialize UdpHeader")?;
|
||||
|
||||
trace!("hdr {:?}", hdr);
|
||||
|
||||
let mut data = BytesMut::new();
|
||||
data.resize(header.len as usize, 0);
|
||||
data.resize(hdr.len as usize, 0);
|
||||
reader.read_exact(&mut data).await?;
|
||||
|
||||
Ok(UdpTraffic {
|
||||
from: header.from,
|
||||
from: hdr.from,
|
||||
data: data.freeze(),
|
||||
})
|
||||
}
|
||||
|
|
|
@ -14,10 +14,9 @@ use backoff::ExponentialBackoff;
|
|||
|
||||
use rand::RngCore;
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
|
||||
use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream, UdpSocket};
|
||||
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||
use tokio::time;
|
||||
|
@ -618,10 +617,10 @@ async fn run_udp_connection_pool<T: Transport>(
|
|||
},
|
||||
|
||||
// Forward outbound traffic from the client to the visitor
|
||||
t = UdpTraffic::read(&mut conn) => {
|
||||
let t = t?;
|
||||
hdr_len = conn.read_u16() => {
|
||||
let t = UdpTraffic::read(&mut conn, hdr_len?).await?;
|
||||
l.send_to(&t.data, t.from).await?;
|
||||
},
|
||||
}
|
||||
|
||||
_ = shutdown_rx.recv() => {
|
||||
break;
|
||||
|
|
Loading…
Reference in New Issue